
Commit 7a9c004

[BACKEND] Make sure tmem load sink pattern converges (#7627)
When multiple tmem loads are sunk to the same point, each rewrite can keep re-ordering them without making progress, so the greedy driver re-applies the pattern forever and never converges. Fix this by treating a load as already sunk when the only ops between it and the dominating op are other tmem loads (a toy model of the non-convergence follows the C++ hunk below).
Parent: 991152f

2 files changed: 30 additions, 1 deletion


lib/Dialect/TritonGPU/Transforms/HoistTMEMAlloc.cpp

Lines changed: 7 additions & 1 deletion
@@ -141,7 +141,13 @@ class SinkTMEMLoad : public OpRewritePattern<ttng::TMEMLoadOp> {
           return postDomInfo.properlyPostDominates(use->getOwner(), domOp);
         }))
       return failure();
-    if (domOp == load->getNextNode()) {
+    // To avoid re-ordering multiple tmem loads in a loop, don't sink if all
+    // the ops between the load and the domOp are tmem loads.
+    Operation *nextNode = load->getNextNode();
+    while (auto tmemLoad = dyn_cast<ttng::TMEMLoadOp>(nextNode)) {
+      nextNode = tmemLoad->getNextNode();
+    }
+    if (domOp == nextNode) {
       // The load wasn't moved.
       return failure();
     }
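For intuition, here is a minimal, self-contained C++ toy model of the convergence issue. It is not Triton's real rewrite-driver API: sinkLoad, runDriver, and the string-based op list are hypothetical stand-ins that only mimic the shape of the pattern. With the old check, two loads that both sink to the same point swap places on every pass, so every pass reports a change; skipping over the run of neighboring tmem loads (as the fix above does) makes the pattern report failure, and the driver can stop.

#include <iostream>
#include <string>
#include <vector>

using Block = std::vector<std::string>; // a basic block as a flat op list

bool isTmemLoad(const std::string &op) { return op.rfind("tmem_load", 0) == 0; }

// Sink the load at loadIdx to just before the op at domIdx. Old check: the
// load "didn't move" only when it sits immediately before domOp. New check:
// a run of tmem loads between the load and domOp also counts as "already
// sunk", mirroring the while loop in the real pattern above.
bool sinkLoad(Block &block, size_t loadIdx, size_t domIdx, bool skipLoadRun) {
  size_t next = loadIdx + 1;
  if (skipLoadRun)
    while (next < domIdx && isTmemLoad(block[next]))
      ++next;
  if (next == domIdx)
    return false; // report failure so the greedy driver can converge
  std::string op = block[loadIdx];
  block.erase(block.begin() + loadIdx);
  block.insert(block.begin() + (domIdx - 1), op);
  return true;
}

// Run passes until one makes no change; cap the count so the buggy variant
// cannot actually spin forever.
int runDriver(Block block, bool skipLoadRun, int maxIters = 8) {
  for (int iter = 1; iter <= maxIters; ++iter) {
    bool changed = false;
    size_t domIdx = block.size() - 1; // "use" post-dominates both loads
    for (size_t i = 0; i + 1 < block.size(); ++i)
      if (isTmemLoad(block[i]))
        changed |= sinkLoad(block, i, domIdx, skipLoadRun);
    if (!changed)
      return iter;
  }
  return maxIters; // hit the cap: no convergence
}

int main() {
  Block block = {"tmem_load_A", "tmem_load_B", "use"};
  std::cout << "old check: stopped after " << runDriver(block, false)
            << " passes (hit the cap, i.e. no convergence)\n";
  std::cout << "new check: converged after " << runDriver(block, true)
            << " pass(es)\n";
}

With the old check the two loads ping-pong ({A, B, use} -> {B, A, use} -> {A, B, use} -> ...) and the driver only stops because of the artificial cap; with the new check the first pass already reports no change, which is exactly the convergence the regression test below locks in.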

test/TritonGPU/hoist-tmem-alloc.mlir

Lines changed: 23 additions & 0 deletions
@@ -353,3 +353,26 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
     tt.return %result1, %token1 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token
   }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: @sink_multiple_tmem_load
+  tt.func public @sink_multiple_tmem_load(%m: !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, %t: !ttg.async.token) -> (tensor<128x128xf32, #blocked>, tensor<128x128xf32, #blocked>) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
+    %c0_i32 = arith.constant 0 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %c2_i32 = arith.constant 2 : i32
+    %res:2 = scf.for %i = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%init0 = %cst, %init1 = %cst) -> (tensor<128x128xf32, #blocked>, tensor<128x128xf32, #blocked>) : i32 {
+      // Any order is fine, just make sure we don't reorder them in an infinite loop.
+      // CHECK-COUNT-2: ttng.tmem_load
+      // CHECK: scf.yield
+      %l0, %token_1 = ttng.tmem_load %m[%t] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>
+      %l1, %token_2 = ttng.tmem_load %m[%t] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>
+      scf.yield %l0, %l1 : tensor<128x128xf32, #blocked>, tensor<128x128xf32, #blocked>
+    } {tt.scheduled_max_stage = 3 : i32}
+    tt.return %res#0, %res#1 : tensor<128x128xf32, #blocked>, tensor<128x128xf32, #blocked>
+  }
+}
