Skip to content

Commit 6377474

Browse files
authored
[Backend] Codegen warpId to 0 when there 1 contextual warp (#6823)
This allows a bunch of code to fold away trivially, especially in the MMA and load partitions of warp specialized kernels.
1 parent 4b9efc5 commit 6377474

File tree

3 files changed

+52
-2
lines changed

3 files changed

+52
-2
lines changed

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,9 +207,16 @@ std::pair<Value, Value> getLaneAndWarpId(OpBuilder &rewriter, Location loc) {
207207
Value tid = getThreadId(rewriter, loc);
208208
int threadsPerWarp = triton::gpu::lookupThreadsPerWarp(rewriter);
209209
Value warpSizeVal = b.i32_val(threadsPerWarp);
210-
211210
Value laneId = b.urem(tid, warpSizeVal);
212-
Value warpId = b.udiv(tid, warpSizeVal);
211+
212+
// If there is only one warp, the warp ID is always 0.
213+
Operation *lookupPt = &rewriter.getInsertionBlock()->front();
214+
Value warpId;
215+
if (triton::gpu::lookupNumWarps(lookupPt) == 1)
216+
warpId = b.i32_val(0);
217+
else
218+
warpId = b.udiv(tid, warpSizeVal);
219+
213220
return {laneId, warpId};
214221
}
215222

test/Conversion/nvgpu_to_llvm.mlir

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,3 +211,40 @@ llvm.func @warpid_warp_specialize() {
211211
}
212212

213213
}
214+
215+
// -----
216+
217+
module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} {
218+
219+
// CHECK-LABEL: @one_warp
220+
tt.func @one_warp() -> i32 {
221+
// CHECK-NEXT: [[C0:%.*]] = llvm.mlir.constant(0 : i32)
222+
%0 = nvgpu.warp_id
223+
// CHECK-NEXT: return [[C0]]
224+
tt.return %0 : i32
225+
}
226+
227+
}
228+
229+
// -----
230+
231+
module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
232+
233+
// CHECK-LABEL: @one_contextual_warp
234+
tt.func @one_contextual_warp() {
235+
ttg.warp_specialize()
236+
default {
237+
ttg.warp_yield
238+
}
239+
// CHECK: partition0
240+
partition0() num_warps(1) {
241+
// CHECK-NEXT: [[C0:%.*]] = llvm.mlir.constant(0 : i32)
242+
%0 = nvgpu.warp_id
243+
// CHECK-NEXT: "use"([[C0]])
244+
"use"(%0) : (i32) -> ()
245+
ttg.warp_return
246+
} : () -> ()
247+
tt.return
248+
}
249+
250+
}

third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,12 @@ class WarpIdOpPattern : public OpRewritePattern<ttn::WarpIdOp> {
232232
auto loc = op.getLoc();
233233
auto b = TritonLLVMOpBuilder(loc, rewriter);
234234

235+
if (triton::gpu::lookupNumWarps(op) == 1) {
236+
// If there is only one warp, the warp ID is always 0.
237+
rewriter.replaceOp(op, b.i32_val(0));
238+
return success();
239+
}
240+
235241
// If this is inside a warp specialize op, compute the relative thread ID
236242
// within the warp group.
237243
Value tid = rewriter.create<NVVM::ThreadIdXOp>(loc, i32_ty);

0 commit comments

Comments
 (0)