File tree Expand file tree Collapse file tree 1 file changed +4
-8
lines changed
third_party/nvidia/lib/TritonNVIDIAGPUToLLVM Expand file tree Collapse file tree 1 file changed +4
-8
lines changed Original file line number Diff line number Diff line change @@ -308,14 +308,10 @@ void calculateAddressAndEmitTmemMessage(
308
308
309
309
TritonLLVMOpBuilder b (loc, rewriter);
310
310
Value warpId = rewriter.create <nvgpu::WarpIdOp>(loc);
311
- Value warpIdInGroup, warpGroupId;
312
- if (info.numWarpGroups == 1 ) {
313
- warpIdInGroup = warpId;
314
- warpGroupId = b.i32_val (0 );
315
- } else {
316
- warpIdInGroup = b.urem (warpId, b.i32_val (4 ));
317
- warpGroupId = b.udiv (warpId, b.i32_val (4 ));
318
- }
311
+ // Note: optimizing this when we know `info.numWarpGroups` is 1 can result in
312
+ // performance regressions.
313
+ Value warpIdInGroup = b.urem (warpId, b.i32_val (4 ));
314
+ Value warpGroupId = b.udiv (warpId, b.i32_val (4 ));
319
315
320
316
// When split along M, blockM=128 and num_warps=8, and a strided message is
321
317
// selected such that all 8 warps read 16 rows of a block at a time.
You can’t perform that action at this time.
0 commit comments