Commit 1de9428
[BACKEND] Do not drop local_load's AsyncToken in CanonicalizeConvertFromConvert (triton-lang#6326)
Changes `CanonicalizeConvertFromConvert` to include the `AsyncToken` when rewriting the `local_load`.
1 parent: 8922df9
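For context, this pattern folds a `ttg.convert_layout` whose operand comes from a `ttg.local_load` into a single `ttg.local_load` that produces the target layout directly. A minimal before/after sketch in TritonGPU IR (hypothetical values and layouts for illustration, not taken from this commit):

    // Before: load into #blocked, then convert to #blocked1. The load is
    // ordered after an asynchronous copy via the token %tok.
    %v = ttg.local_load %buf token %tok : !ttg.memdesc<256xi32, #shared, #smem, mutable> -> tensor<256xi32, #blocked>
    %c = ttg.convert_layout %v : tensor<256xi32, #blocked> -> tensor<256xi32, #blocked1>

    // After (with this fix): a single load in the target layout, token preserved.
    %c = ttg.local_load %buf token %tok : !ttg.memdesc<256xi32, #shared, #smem, mutable> -> tensor<256xi32, #blocked1>

Before this fix, the rewritten load omitted `token %tok`, silently dropping the dependence the token encodes.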

File tree: 2 files changed, +5 -4 lines


lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 2 additions & 1 deletion
@@ -273,7 +273,8 @@ struct CanonicalizeConvertFromConvert
     // memory side-effects between the LocalLoad op and the ConvertLayout op
     rewriter.setInsertionPoint(arg);
     rewriter.replaceOpWithNewOp<LocalLoadOp>(op, op->getResult(0).getType(),
-                                             sharedLoad.getSrc());
+                                             sharedLoad.getSrc(),
+                                             sharedLoad.getToken());

     return success();
   }
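Why the token matters: `ttg.local_load` reads from shared memory that is typically filled by asynchronous copies, and the optional `!ttg.async.token` operand is the SSA edge that orders the load after the copy completes. `sharedLoad.getToken()` forwards that operand and is null when the load carries no token, so the rewrite remains a no-op for token-less loads. A schematic producer chain, with op names from the TritonGPU dialect but the exact printed syntax approximated rather than copied from the repository:

    // Kick off an async global->shared copy, commit it, and wait on it;
    // the wait yields the token that orders the subsequent local_load.
    %c0 = ttg.async_copy_global_to_local %ptrs, %buf : tensor<256x!tt.ptr<i32>, #blocked> -> !ttg.memdesc<256xi32, #shared, #smem, mutable>
    %g0 = ttg.async_commit_group %c0
    %tok = ttg.async_wait %g0 {num = 0 : i32}
    %v = ttg.local_load %buf token %tok : !ttg.memdesc<256xi32, #shared, #smem, mutable> -> tensor<256xi32, #blocked>

Dropping the token in the canonicalized IR would leave the new `local_load` with no use-def link to the `async_wait`, allowing later passes to reorder it before the copy has finished.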

test/TritonGPU/canonicalize.mlir

Lines changed: 3 additions & 3 deletions
@@ -102,7 +102,7 @@ tt.func @test_canonicalize_convert_histogram(%arg0: tensor<256xi32, #blocked1>)

 // CHECK-LABEL: @test_canonicalize_convert_local_load
 // CHECK-NOT: gpu.barrier
-// CHECK: %[[V:.+]] = ttg.local_load
+// CHECK: %[[V:.+]] = ttg.local_load {{.*}} token %arg0
 // CHECK-NEXT: gpu.barrier
 // CHECK-NEXT: tt.return %[[V]]

@@ -111,9 +111,9 @@ tt.func @test_canonicalize_convert_histogram(%arg0: tensor<256xi32, #blocked1>)
 #shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.compute-capability" = 80} {
-tt.func @test_canonicalize_convert_local_load() -> tensor<256xi32, #blocked1> {
+tt.func @test_canonicalize_convert_local_load(%arg0: !ttg.async.token) -> tensor<256xi32, #blocked1> {
   %0 = ttg.local_alloc : () -> !ttg.memdesc<256xi32, #shared, #smem, mutable>
-  %1 = ttg.local_load %0 : !ttg.memdesc<256xi32, #shared, #smem, mutable> -> tensor<256xi32, #blocked>
+  %1 = ttg.local_load %0 token %arg0 : !ttg.memdesc<256xi32, #shared, #smem, mutable> -> tensor<256xi32, #blocked>
   gpu.barrier
   %2 = ttg.convert_layout %1 : tensor<256xi32, #blocked> -> tensor<256xi32, #blocked1>
   tt.return %2 : tensor<256xi32, #blocked1>
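As a usage note, this is a lit/FileCheck test; assuming the file's usual RUN line (the RUN line itself is not part of this diff, so the exact flags here are an assumption), the updated checks can be exercised with something like:

    // RUN: triton-opt %s -split-input-file -canonicalize | FileCheck %s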
