Skip to content

Commit 4d2434b

Browse files
authored
[Blackwell] Fix the math to calculate num reg for tmem load/store (#5991)
Tweak a bit the heuristic picking the tmem messages. The num reg estimation was not considering the number of warpgroups. This fixes performance regressions.
1 parent c1ed673 commit 4d2434b

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

test/Conversion/tritongpu_to_llvm_blackwell.mlir

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
344344

345345
// -----
346346

347-
348347
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
349348
#blocked1 = #ttg.blocked<{sizePerThread = [1, 256], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
350349
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 256, unpacked = true>
@@ -368,3 +367,23 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
368367
tt.return
369368
}
370369
}
370+
371+
// -----
372+
373+
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 4], order = [1, 0]}>
374+
#blocked1 = #ttg.blocked<{sizePerThread = [1, 256], threadsPerWarp = [32, 1], warpsPerCTA = [4, 2], order = [0, 1]}>
375+
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 256, unpacked = true>
376+
377+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:100", ttg.tensor_memory_size = 128 : i32, "ttg.threads-per-warp" = 32 : i32} {
378+
// CHECK-LABEL: @tensor_memory_ld_128x256_8_warps
379+
// CHECK: tcgen05.st.sync.aligned.32x32b.x128.b32
380+
// CHECK: tcgen05.wait::st.sync.aligned
381+
// CHECK: tcgen05.ld.sync.aligned.32x32b.x128.b32
382+
// CHECK: tcgen05.wait::ld.sync.aligned
383+
tt.func public @tensor_memory_ld_128x256_8_warps(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>) {
384+
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked1>
385+
%0 = ttng.tmem_alloc %cst_0 {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : (tensor<128x256xf32, #blocked1>) -> !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>
386+
%20 = ttng.tmem_load %0 : !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x256xf32, #blocked1>
387+
tt.return
388+
}
389+
}

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,8 @@ TMemMessageTraits selectTMemMessage(const TMemRuntimeInfo &info) {
341341
auto atom = info.useStridedMessage ? TMemAccess16x32bx2 : TMemAccess32x32b;
342342

343343
int totalRegsNeeded =
344-
getEffectiveRegs(info.unpackedb16, info.useStridedMessage, info.numCols);
344+
getEffectiveRegs(info.unpackedb16, info.useStridedMessage,
345+
info.numCols / info.numWarpGroups);
345346
int narrowingFactor = getTMemMessageNarrowingFactor(totalRegsNeeded);
346347
auto narrowedMessage = getTMemMessageFromAtom(atom, narrowingFactor);
347348
narrowedMessage = constrainMessageFromWorkload(narrowedMessage, info,

0 commit comments

Comments
 (0)