Skip to content

Commit a2d179d

Browse files
authored
[Blackwell] Fix perf regression (#7643)
Somehow, optimizing this code when we know the number of warp groups is 1 results in major performance regressions...
1 parent cf399b4 commit a2d179d

File tree

1 file changed

+4
-8
lines changed

1 file changed

+4
-8
lines changed

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp

Lines changed: 4 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -308,14 +308,10 @@ void calculateAddressAndEmitTmemMessage(
308308

309309
TritonLLVMOpBuilder b(loc, rewriter);
310310
Value warpId = rewriter.create<nvgpu::WarpIdOp>(loc);
311-
Value warpIdInGroup, warpGroupId;
312-
if (info.numWarpGroups == 1) {
313-
warpIdInGroup = warpId;
314-
warpGroupId = b.i32_val(0);
315-
} else {
316-
warpIdInGroup = b.urem(warpId, b.i32_val(4));
317-
warpGroupId = b.udiv(warpId, b.i32_val(4));
318-
}
311+
// Note: optimizing this when we know `info.numWarpGroups` is 1 can result in
312+
// performance regressions.
313+
Value warpIdInGroup = b.urem(warpId, b.i32_val(4));
314+
Value warpGroupId = b.udiv(warpId, b.i32_val(4));
319315

320316
// When split along M, blockM=128 and num_warps=8, and a strided message is
321317
// selected such that all 8 warps read a 16 rows of a block at a time.

0 commit comments

Comments
 (0)