Commit 12419f6

[Blackwell] Fix the TMEM message heuristic (#6692)
Based on feedback from @csullivan: the heuristic is also supposed to avoid using two `.x128` messages when the total workload size is 256 elements per thread. Account for that, in addition to the register size of each individual message.
1 parent 576e889 commit 12419f6
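
For intuition, here is a minimal sketch of the register-budget arithmetic behind the fix. All numbers are assumptions chosen for illustration (maxnreg of 256 and the 256 elements-per-thread workload mentioned above), not values taken from the lowering:

#include <cstdio>

int main() {
  // Illustrative assumptions, not values read from the compiler.
  const int maxnreg = 256;         // registers available per thread
  const int budget = maxnreg / 2;  // heuristic: use at most half of them
  const int workloadRegs = 256;    // 256 elements per thread, one reg each
  const int x128Regs = 128;        // regs held live by one .x128 message
  const int x64Regs = 64;          // regs held live by one .x64 message

  // A single .x128 message fits the budget on its own (128 <= 128), which is
  // all the old heuristic checked, so it emitted two .x128 messages.
  printf(".x128 per-message fit: %s\n", x128Regs <= budget ? "yes" : "no");

  // But the workload as a whole needs 256 regs (> 128), so the fixed
  // heuristic narrows the message and emits four smaller ones instead.
  printf("workload fits budget:  %s\n", workloadRegs <= budget ? "yes" : "no");
  printf("messages emitted: %d x .x64\n", workloadRegs / x64Regs);
  return 0;
}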

2 files changed (+17, -8 lines)

test/Conversion/tritongpu_to_llvm_blackwell.mlir

Lines changed: 2 additions & 2 deletions
@@ -389,10 +389,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:100", ttg.tensor_memory_size = 128 : i32, "ttg.threads-per-warp" = 32 : i32} {
   // CHECK-LABEL: @tensor_memory_ld_128x256
-  // CHECK-COUNT-2: tcgen05.st.sync.aligned.32x32b.x128.b32
+  // CHECK-COUNT-4: tcgen05.st.sync.aligned.32x32b.x64.b32
   // CHECK-NOT: tcgen05.st
   // CHECK: tcgen05.wait::st.sync.aligned
-  // CHECK-COUNT-2: tcgen05.ld.sync.aligned.32x32b.x128.b32
+  // CHECK-COUNT-4: tcgen05.ld.sync.aligned.32x32b.x64.b32
   // CHECK-NOT: tcgen05.ld
   // CHECK: tcgen05.wait::ld.sync.aligned
   tt.func public @tensor_memory_ld_128x256(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>) {

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp

Lines changed: 15 additions & 6 deletions
@@ -134,14 +134,19 @@ TMemMessageTraits getTMemMessageFromAtom(const TMemAccessAtom &atom,
   return m;
 }
 
-// Only allows half of the thread registers to be used for tensor memory access
-// to avoid register pressure. This ensures the largest tmem message width is
-// used for the workload without inducing spills.
-int getTMemMessageNarrowingFactor(const TMemAccessAtom &atom, int maxnreg) {
+// Narrow the TMEM message by reducing the number of registers per TMEM
+// instruction such that:
+// - No instruction uses more than half the available registers at a time.
+// - If the total number of registers required by the workload is more than half
+//   of the available registers, don't use the largest TMEM message.
+int getTMemMessageNarrowingFactor(const TMemAccessAtom &atom,
+                                  int workloadThreadRegs, int maxnreg) {
   const int allowedRegUsage = maxnreg / 2;
   int narrowingFactor = 1;
   while (getTMemMessageFromAtom(atom, narrowingFactor).numRegs >
-             allowedRegUsage) {
+             allowedRegUsage ||
+         workloadThreadRegs > allowedRegUsage) {
+    workloadThreadRegs /= 2;
     narrowingFactor *= 2;
   }
   return narrowingFactor;
@@ -381,7 +386,11 @@ void createWaitOpSt(Location loc, ConversionPatternRewriter &rewriter) {
 TMemMessageTraits selectTMemMessage(const TMemRuntimeInfo &info, int maxnreg) {
   auto atom = info.useStridedMessage ? TMemAccess16x32bx2 : TMemAccess32x32b;
 
-  int narrowingFactor = getTMemMessageNarrowingFactor(atom, maxnreg);
+  int totalRegsNeeded =
+      getEffectiveRegs(info.unpackedb16, info.useStridedMessage,
+                       info.numCols / info.numWarpGroups);
+  int narrowingFactor =
+      getTMemMessageNarrowingFactor(atom, totalRegsNeeded, maxnreg);
   auto narrowedMessage = getTMemMessageFromAtom(atom, narrowingFactor);
   narrowedMessage = constrainMessageFromWorkload(narrowedMessage, info,
                                                  narrowedMessage.numRegs);
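
As a rough illustration of how the new loop behaves for the 128x256 test above, here is a self-contained sketch. messageRegs is a hypothetical stand-in for getTMemMessageFromAtom(atom, factor).numRegs, and the concrete numbers (256 workload registers per thread, maxnreg of 256) are assumptions, not values traced from the compiler:

#include <cstdio>

// Hypothetical stand-in for getTMemMessageFromAtom(...).numRegs: assume the
// widest 32x32b message (.x128) holds 128 regs and each doubling of the
// narrowing factor halves that (.x64 -> 64, .x32 -> 32, ...).
static int messageRegs(int narrowingFactor) { return 128 / narrowingFactor; }

// Mirrors the narrowing loop from the patch, outside the compiler.
static int narrowingFactorFor(int workloadThreadRegs, int maxnreg) {
  const int allowedRegUsage = maxnreg / 2;
  int narrowingFactor = 1;
  while (messageRegs(narrowingFactor) > allowedRegUsage ||
         workloadThreadRegs > allowedRegUsage) {
    workloadThreadRegs /= 2;
    narrowingFactor *= 2;
  }
  return narrowingFactor;
}

int main() {
  // Assumed inputs for the 128x256 case: 256 regs of workload per thread
  // and maxnreg = 256, i.e. a budget of 128 regs.
  int factor = narrowingFactorFor(/*workloadThreadRegs=*/256, /*maxnreg=*/256);
  // Prints 2: the .x128 atom is narrowed once, to .x64, so the 256-reg
  // workload is covered by four .x64 messages (matching the updated test).
  printf("narrowing factor = %d\n", factor);
  return 0;
}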
