Skip to content

Commit 1f797a5

Browse files
authored
[AMD] Fix pingpong ChainedDot for empty second memory cluster (#7694)
triton-lang/triton#7638 introduced a null pointer access (during review adjustments) if the second memory cluster is empty or if there are no memory clusters at all. Added a lit test to catch it and revert to the old logic.
1 parent d792765 commit 1f797a5

File tree

2 files changed

+29
-3
lines changed

2 files changed

+29
-3
lines changed

test/TritonGPU/amd/amd-block-pingpong-chained-dots.mlir

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,12 +137,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
137137
#smem = #ttg.shared_memory
138138
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
139139

140-
// CHECK-LABEL: reject_chained_dots_empty_mem_cluster
140+
// CHECK-LABEL: reject_chained_dots_empty_mem_cluster_1
141141

142142
// CHECK-NOT: setprio
143143
// CHECK-NOT: barrier
144144

145-
tt.func @reject_chained_dots_empty_mem_cluster(%arg0: tensor<64x16xf16, #blocked>, %arg1: tensor<64x16x!tt.ptr<f16>, #blocked>, %arg2: i32, %arg3: i32, %arg4: tensor<128x16xf32, #mma>, %arg5: tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>, %arg6: i32, %arg7: tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, %arg8: tensor<128x16xf32, #mma>, %arg9: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg10: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>, %arg11: i32, %arg12: i32, %arg13: tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>) -> tensor<128x16xf32, #mma> {
145+
tt.func @reject_chained_dots_empty_mem_cluster_1(%arg0: tensor<64x16xf16, #blocked>, %arg1: tensor<64x16x!tt.ptr<f16>, #blocked>, %arg2: i32, %arg3: i32, %arg4: tensor<128x16xf32, #mma>, %arg5: tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>, %arg6: i32, %arg7: tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, %arg8: tensor<128x16xf32, #mma>, %arg9: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg10: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>, %arg11: i32, %arg12: i32, %arg13: tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>) -> tensor<128x16xf32, #mma> {
146146
%c1_i32 = arith.constant 1 : i32
147147
%c0_i32 = arith.constant 0 : i32
148148
%0 = ttg.local_alloc : () -> !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable>
@@ -164,3 +164,29 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
164164
tt.return %5#0 : tensor<128x16xf32, #mma>
165165
}
166166
}
167+
168+
// -----
169+
170+
#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}>
171+
#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
172+
#shared = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 8, order = [0, 1]}>
173+
#smem = #ttg.shared_memory
174+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
175+
176+
// CHECK-LABEL: reject_chained_dots_empty_mem_cluster_2
177+
178+
// CHECK-NOT: setprio
179+
// CHECK-NOT: barrier
180+
181+
tt.func @reject_chained_dots_empty_mem_cluster_2(%memdesc1: !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, %memdesc2: !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, %alloc1: !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable>, %alloc2: !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable>, %arg0: tensor<64x16xf16, #blocked>, %arg1: tensor<64x16x!tt.ptr<f16>, #blocked>, %arg2: i32, %arg3: i32, %arg4: tensor<128x16xf32, #mma>, %arg5: tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>, %arg6: i32, %arg7: tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, %arg8: tensor<128x16xf32, #mma>, %arg9: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg10: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>, %arg11: i32, %arg12: i32, %arg13: tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>) -> tensor<128x16xf32, #mma> {
182+
%5:8 = scf.for %arg14 = %arg3 to %arg2 step %arg3 iter_args(%arg15 = %arg4, %arg16 = %arg4, %arg17 = %arg7, %arg18 = %memdesc1, %arg19 = %memdesc1, %arg20 = %memdesc2, %arg21 = %arg0, %arg22 = %arg0) -> (tensor<128x16xf32, #mma>, tensor<128x16xf32, #mma>, tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, tensor<64x16xf16, #blocked>, tensor<64x16xf16, #blocked>) : i32 {
183+
%6 = tt.dot %arg10, %arg17, %arg15 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma>
184+
ttg.local_store %arg22, %arg20 : tensor<64x16xf16, #blocked> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>
185+
%11 = ttg.local_load %arg20 : !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
186+
%13 = tt.load %arg1 : tensor<64x16x!tt.ptr<f16>, #blocked>
187+
%10 = tt.dot %arg10, %arg17, %arg16 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma>
188+
scf.yield %10, %6, %11, %arg19, %arg20, %arg20, %13, %13 : tensor<128x16xf32, #mma>, tensor<128x16xf32, #mma>, tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, tensor<64x16xf16, #blocked>, tensor<64x16xf16, #blocked>
189+
}
190+
tt.return %5#0 : tensor<128x16xf32, #mma>
191+
}
192+
}

third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,7 @@ LogicalResult Pingponger::transformChainedDotSchedule(OpBuilder &builder,
678678

679679
// Memory clusters start with either ttg.async_wait or ttg.local_store
680680
auto findNextMemoryCluster = [](Operation *op) {
681-
while (!llvm::isa_and_nonnull<ttg::AsyncWaitOp, ttg::LocalStoreOp>(op)) {
681+
while (op && !llvm::isa<ttg::AsyncWaitOp, ttg::LocalStoreOp>(op)) {
682682
op = op->getNextNode();
683683
}
684684
return op;

0 commit comments

Comments
 (0)