Commit f35eb5b
[AMD] Add two-cluster pingpong transform and selecting logic (#5526)
- Added a two-cluster pingpong transform, covering medium-sized tiles.
- Added logic to select whether a pingpong transform should be applied, and which variant, for the given conditions (a hypothetical sketch of the idea follows).
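The selection heuristic itself is not in the test file shown below; it presumably lives in the pass implementation, the other changed file in this commit, which is not included in this section. Purely as an illustration of the idea, a tile-size dispatch could look like the C++ sketch below. The names (PingpongVariant, selectPingpongVariant, tileM, tileN) and the thresholds are assumptions, not the actual implementation; the only data point grounded in this commit is that the new two-cluster schedule targets medium tiles such as the 256x128 case in the pingpong_medium test added here.

// Hypothetical sketch only: names and thresholds are illustrative assumptions,
// not the logic of the actual Triton pass touched by this commit.
enum class PingpongVariant { None, TwoCluster, FourCluster };

PingpongVariant selectPingpongVariant(long tileM, long tileN) {
  long tileElems = tileM * tileN;
  if (tileElems >= 256 * 256)
    return PingpongVariant::FourCluster; // large tiles: four dot clusters
  if (tileElems >= 256 * 128)
    return PingpongVariant::TwoCluster;  // medium tiles (e.g. 256x128): two clusters
  return PingpongVariant::None;          // smaller tiles: leave the loop unchanged
}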
1 parent dcad5ac commit f35eb5b

File tree

2 files changed (+248, -55 lines)


test/TritonGPU/amd/amd-block-pingpong.mlir

Lines changed: 110 additions & 11 deletions
@@ -86,43 +86,43 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 // CHECK: tt.load
 // CHECK: %[[SLICEA0:.+]] = ttg.local_load
 // CHECK: %[[SLICEB0:.+]] = ttg.local_load
-// CHECK: rocdl.sched.barrier 0
 // CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
 // CHECK: rocdl.s.setprio 1
 // CHECK: %[[DOT0:.+]] = tt.dot %[[SLICEA0]], %[[SLICEB0]]
 // CHECK: rocdl.s.setprio 0
-// CHECK: rocdl.sched.barrier 0
 // CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
 // CHECK: tt.load
 // CHECK: %[[SLICEA1:.+]] = ttg.local_load
 // CHECK: %[[SLICEB1:.+]] = ttg.local_load
-// CHECK: rocdl.sched.barrier 0
 // CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
 // CHECK: rocdl.s.setprio 1
 // CHECK: %[[DOT1:.+]] = tt.dot %[[SLICEA1]], %[[SLICEB1]], %[[DOT0]]
 // CHECK: rocdl.s.setprio 0
-// CHECK: rocdl.sched.barrier 0
 // CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
 // CHECK: %[[SLICEA2:.+]] = ttg.local_load
 // CHECK: %[[SLICEB2:.+]] = ttg.local_load
 // CHECK: %[[SLICEA3:.+]] = ttg.local_load
 // CHECK: %[[SLICEB3:.+]] = ttg.local_load
-// CHECK: rocdl.sched.barrier 0
 // CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
 // CHECK: rocdl.s.setprio 1
 // CHECK: %[[DOT2:.+]] = tt.dot %[[SLICEA2]], %[[SLICEB2]], %[[DOT1]]
 // CHECK: rocdl.s.setprio 0
-// CHECK: rocdl.sched.barrier 0
 // CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
 // CHECK: ttg.local_store
 // CHECK: ttg.local_store
-// CHECK: rocdl.sched.barrier 0
 // CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
 // CHECK: rocdl.s.setprio 1
 // CHECK: tt.dot %[[SLICEA3]], %[[SLICEB3]], %[[DOT2]]
 // CHECK: rocdl.s.setprio 0
-// CHECK: rocdl.sched.barrier 0
 // CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
 // CHECK: scf.yield
 // CHECK: amdgpu.cond_barrier %[[WARPLOW]]

@@ -169,9 +169,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
 %27 = tt.load %26 : tensor<256x64x!tt.ptr<f16>, #blocked1>
 %28 = tt.addptr %arg13, %cst_0 : tensor<64x256x!tt.ptr<f16>, #blocked>, tensor<64x256xi32, #blocked>
 %29 = tt.load %28 : tensor<64x256x!tt.ptr<f16>, #blocked>
-%30 = ttg.local_load %arg15 : !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
-%31 = ttg.local_load %arg16 : !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>
-%32 = tt.dot %30, %31, %arg11 : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<256x256xf32, #mma>
+%30 = ttg.local_load %arg15 : !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+%31 = ttg.local_load %arg16 : !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
+%32 = tt.dot %30, %31, %arg11 : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x256xf32, #mma>
 %33 = arith.addi %arg14, %c1_i32 : i32
 %34 = arith.cmpi slt, %33, %c1_i32 : i32
 %35 = arith.select %34, %33, %c0_i32 : i32
@@ -189,6 +189,105 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
 
 // -----
 
+// CHECK: gpu.barrier
+// CHECK: %[[IDX:.+]] = rocdl.workitem.id.x
+// CHECK: %[[XDIV:.+]] = arith.divsi %[[IDX]]
+// CHECK: %[[WARPLOW:.+]] = arith.cmpi eq, %[[XDIV]]
+// CHECK: %[[WARPHIGH:.+]] = arith.cmpi ne, %[[XDIV]]
+// CHECK: amdgpu.cond_barrier %[[WARPHIGH]]
+// CHECK: scf.for
+
+// CHECK: %[[SLICEA0:.+]] = ttg.local_load
+// CHECK: %[[SLICEB0:.+]] = ttg.local_load
+// CHECK: rocdl.sched.barrier 0
+// CHECK: tt.load
+// CHECK: rocdl.sched.barrier 0
+// CHECK: %[[SLICEA1:.+]] = ttg.local_load
+// CHECK: %[[SLICEB1:.+]] = ttg.local_load
+// CHECK: rocdl.sched.barrier 0
+// CHECK: tt.load
+// CHECK: rocdl.s.barrier
+// CHECK: rocdl.sched.barrier 0
+// CHECK: rocdl.s.setprio 1
+// CHECK: %[[DOT0:.+]] = tt.dot %[[SLICEA0]], %[[SLICEB0]]
+// CHECK: rocdl.s.setprio 0
+// CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
+// CHECK: ttg.local_store
+// CHECK: ttg.local_store
+// CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
+// CHECK: rocdl.s.setprio 1
+// CHECK: %[[DOT1:.+]] = tt.dot %[[SLICEA1]], %[[SLICEB1]], %[[DOT0]]
+// CHECK: rocdl.s.setprio 0
+// CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
+// CHECK: scf.yield
+// CHECK: amdgpu.cond_barrier %[[WARPLOW]]
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>
+#loc = loc("/home/jung/rocm/triton/python/perf-kernels/tools/tune_gemm/matmul_kernel.py":6:0)
+#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = true}>#shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}>
+#shared1 = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+tt.func public @pingpong_medium(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+%cst = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mma>
+%c1_i32 = arith.constant 1 : i32
+%cst_0 = arith.constant dense<64> : tensor<64x128xi32, #blocked>
+%cst_1 = arith.constant dense<64> : tensor<256x64xi32, #blocked1>
+%c0_i32 = arith.constant 0 : i32
+%c64_i32 = arith.constant 64 : i32
+%0 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x1x!tt.ptr<f16>, #blocked1>
+%1 = tt.get_program_id x : i32
+%2 = tt.splat %1 : i32 -> tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
+%3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
+%4 = arith.addi %2, %3 : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
+%5 = tt.expand_dims %4 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<256x1xi32, #blocked1>
+%6 = tt.splat %arg6 : i32 -> tensor<256x1xi32, #blocked1>
+%7 = arith.muli %5, %6 : tensor<256x1xi32, #blocked1>
+%8 = tt.addptr %0, %7 : tensor<256x1x!tt.ptr<f16>, #blocked1>, tensor<256x1xi32, #blocked1>
+%9 = tt.broadcast %8 : tensor<256x1x!tt.ptr<f16>, #blocked1> -> tensor<256x64x!tt.ptr<f16>, #blocked1>
+%10 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
+%11 = tt.expand_dims %10 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1>
+%12 = tt.broadcast %11 : tensor<1x64xi32, #blocked1> -> tensor<256x64xi32, #blocked1>
+%13 = tt.addptr %9, %12 : tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<256x64xi32, #blocked1>
+%14 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<64x1x!tt.ptr<f16>, #blocked>
+%15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+%16 = tt.expand_dims %15 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked>
+%17 = tt.addptr %14, %16 : tensor<64x1x!tt.ptr<f16>, #blocked>, tensor<64x1xi32, #blocked>
+%18 = tt.broadcast %17 : tensor<64x1x!tt.ptr<f16>, #blocked> -> tensor<64x128x!tt.ptr<f16>, #blocked>
+%19 = tt.splat %arg7 : i32 -> tensor<64x128xi32, #blocked>
+%20 = tt.addptr %18, %19 : tensor<64x128x!tt.ptr<f16>, #blocked>, tensor<64x128xi32, #blocked>
+%21 = ttg.local_alloc : () -> !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable>
+%22 = ttg.local_alloc : () -> !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable>
+%23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>
+%24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>
+%25:6 = scf.for %arg10 = %c0_i32 to %c64_i32 step %c1_i32 iter_args(%arg11 = %cst, %arg12 = %13, %arg13 = %20, %arg14 = %c0_i32, %arg15 = %23, %arg16 = %24) -> (tensor<256x128xf32, #mma>, tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<64x128x!tt.ptr<f16>, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>) : i32 {
+%26 = tt.addptr %arg12, %cst_1 : tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<256x64xi32, #blocked1>
+%27 = tt.load %26 : tensor<256x64x!tt.ptr<f16>, #blocked1>
+%28 = tt.addptr %arg13, %cst_0 : tensor<64x128x!tt.ptr<f16>, #blocked>, tensor<64x128xi32, #blocked>
+%29 = tt.load %28 : tensor<64x128x!tt.ptr<f16>, #blocked>
+%30 = ttg.local_load %arg15 : !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+%31 = ttg.local_load %arg16 : !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
+%32 = tt.dot %30, %31, %arg11 : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x128xf32, #mma>
+%33 = arith.addi %arg14, %c1_i32 : i32
+%34 = arith.cmpi slt, %33, %c1_i32 : i32
+%35 = arith.select %34, %33, %c0_i32 : i32
+%36 = ttg.memdesc_subview %21[%35, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>
+ttg.local_store %27, %36 : tensor<256x64xf16, #blocked1> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>
+%37 = ttg.memdesc_subview %22[%35, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>
+ttg.local_store %29, %37 : tensor<64x128xf16, #blocked> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>
+scf.yield %32, %26, %28, %35, %36, %37 : tensor<256x128xf32, #mma>, tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<64x128x!tt.ptr<f16>, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>
+}
+ttg.local_dealloc %21 : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable>
+ttg.local_dealloc %22 : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable>
+tt.return
+}
+}
+
+// -----
+
 // CHECK-LABEL: pingpong_reject
 // CHECK-COUNT-2: local_load
 // CHECK-NOT: local_load
