Skip to content

Commit 1bd811a

Browse files
authored
[BACKEND] Fix wrong check in tmem_alloc canonicalization pattern (#7719)
1 parent cfe3dd0 commit 1bd811a

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

lib/Dialect/TritonGPU/Transforms/HoistTMEMAlloc.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ class CombineTMEMStoreAndAlloc : public OpRewritePattern<ttng::TMEMStoreOp> {
177177
auto alloc = store.getDep().getDefiningOp<TMEMTokenAllocOp>();
178178
if (!alloc)
179179
return failure();
180-
if (store.getSrc() != alloc.getResult())
180+
if (store.getDst() != alloc.getResult())
181181
return failure();
182182
if (alloc->getBlock() != store->getBlock())
183183
return failure();

test/TritonGPU/hoist-tmem-alloc.mlir

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,13 +341,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
341341
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
342342
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
343343
tt.func public @forward_tmem_load(%m: !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, %t: !ttg.async.token) -> (!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token) {
344+
%true = arith.constant true
344345
%result, %token0 = ttng.tmem_load %m[%t] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>
345346
// HOIST-IF-LABEL: @forward_tmem_load
346347
// HOIST-IF-SAME: %[[ARG0:.+]]: !ttg.memdesc<128x128xf32,
347348
// HOIST-IF-SAME: %[[ARG1:.+]]: !ttg.async.token
348349
// HOIST-IF-NEXT: tt.return %[[ARG0]], %[[ARG1]]
349-
%result1, %token1 = ttng.tmem_alloc %result : (tensor<128x128xf32, #blocked>) -> (!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token)
350-
tt.return %result1, %token1 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token
350+
%result1, %token1 = ttng.tmem_alloc : () -> (!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token)
351+
%token2 = ttng.tmem_store %result, %result1[%token1], %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
352+
tt.return %result1, %token2 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token
351353
}
352354
}
353355

0 commit comments

Comments
 (0)