Skip to content

Commit 09675e5

Browse files
authored
[AMD] Disallow reorder tt.load over gpu.barrier (triton-lang#4735)
1 parent c99c214 commit 09675e5

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

test/TritonGPU/amd/amd-reorder-instructions.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -830,3 +830,23 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
830830
tt.return
831831
}
832832
}
833+
834+
// -----
835+
836+
// CHECK-LABEL: anchor_barrier
837+
// CHECK: gpu.barrier
838+
// CHECK: tt.load %arg0 : tensor<32x32x!tt.ptr<f16>, #blocked>
839+
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
840+
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
841+
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}>
842+
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
843+
tt.func public @anchor_barrier(%arg0: tensor<32x32x!tt.ptr<f16>, #blocked>) attributes {noinline = false} {
844+
%0 = triton_gpu.local_alloc : () -> !tt.memdesc<4x128x64xf16, #shared, mutable>
845+
gpu.barrier
846+
%2 = tt.load %arg0 : tensor<32x32x!tt.ptr<f16>, #blocked>
847+
%1 = triton_gpu.local_alloc %2 : (tensor<32x32xf16, #blocked>) -> !tt.memdesc<4x128x64xf16, #shared, mutable>
848+
triton_gpu.local_dealloc %0 : !tt.memdesc<4x128x64xf16, #shared, mutable>
849+
triton_gpu.local_dealloc %1 : !tt.memdesc<4x128x64xf16, #shared, mutable>
850+
tt.return
851+
}
852+
}

third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ findEarlyInsertionPoint(Block *block, Operation *move) {
5050
// Atomics used for global synchronization.
5151
if (isa<triton::AtomicRMWOp, triton::AtomicCASOp>(wop))
5252
ipnt = bi;
53+
// Break at barriers: a tt.load must not be moved above a gpu.barrier.
54+
if (isa<gpu::BarrierOp>(wop))
55+
ipnt = bi;
5356
// Break at loops.
5457
if (isa<scf::ForOp, scf::WhileOp>(wop))
5558
ipnt = bi;

0 commit comments

Comments
 (0)