Commit 062e38e

[AMD] Add pingpong transformation for chained dot schedule (#7638)
Adds support for enabling pingpong on loops scheduled with the new `ChainedDotSchedule` introduced by triton-lang/triton#7601. That schedule already places the ops in the correct order, so we only have to insert the sync ops to ensure proper pingponging.
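
For reference, a simplified sketch of the loop shape this transformation produces in the async-load case, abbreviated from the lit-test expectations added below (the trailing comments are editorial annotations, not part of the IR):

    scf.for ... {
      rocdl.s.setprio 0                // compute cluster 1 runs at low priority
      tt.dot                           // first dot of the chain
      rocdl.s.setprio 1                // memory cluster 1 runs at high priority
      ttg.async_wait                   // wait for the prefetched tile
      rocdl.sched.barrier 0            // fence the cluster boundary
      ttg.local_load
      ttg.async_copy_global_to_local
      ttg.async_commit_group
      rocdl.sched.barrier 0
      rocdl.s.barrier                  // resynchronize the pingpong'ing warp groups
      // ... the same pattern repeats for compute/memory cluster 2 ...
      scf.yield
    }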
1 parent bae5ff9 commit 062e38e

2 files changed: +250 -2 lines changed
Lines changed: 166 additions & 0 deletions
@@ -0,0 +1,166 @@
// RUN: triton-opt %s -split-input-file --tritonamdgpu-block-pingpong="num-stages=4" | FileCheck %s

#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}>
#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
#shared = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 8, order = [0, 1]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {

// CHECK-LABEL: chained_dots_async_loads

// CHECK: scf.for
// CHECK: rocdl.s.setprio 0
// Compute Cluster1
// CHECK: tt.dot
// CHECK: rocdl.s.setprio 1
// CHECK: ttg.async_wait
// CHECK: rocdl.sched.barrier 0
// Memory Cluster1
// CHECK: ttg.local_load
// CHECK: ttg.async_copy_global_to_local
// CHECK: ttg.async_commit_group
// CHECK: rocdl.sched.barrier 0
// CHECK: rocdl.s.barrier
// CHECK: rocdl.s.setprio 0
// Compute Cluster2
// CHECK: tt.dot
// CHECK: rocdl.s.setprio 1
// CHECK: ttg.async_wait
// CHECK: rocdl.sched.barrier 0
// Memory Cluster2
// CHECK: ttg.local_load
// CHECK: ttg.async_copy_global_to_local
// CHECK: ttg.async_commit_group
// CHECK: rocdl.sched.barrier 0
// CHECK: rocdl.s.barrier
// CHECK-NEXT: scf.yield

tt.func @chained_dots_async_loads(%arg0: tensor<64x16x!tt.ptr<f16>, #blocked>, %arg1: i32, %arg2: i32, %arg3: !ttg.async.token, %arg4: tensor<128x16xf32, #mma>, %arg5: tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>, %arg6: i32, %arg7: tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, %arg8: tensor<128x16xf32, #mma>, %arg9: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg10: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>, %arg11: i32, %arg12: i32, %arg13: tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>) -> tensor<128x16xf32, #mma> {
  %c1_i32 = arith.constant 1 : i32
  %c0_i32 = arith.constant 0 : i32
  %0 = ttg.local_alloc : () -> !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable>
  %1 = ttg.local_alloc : () -> !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable>
  %2 = ttg.memdesc_index %1, %c0_i32 : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>
  %3 = ttg.memdesc_index %0, %c0_i32 : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>
  %4 = ttg.memdesc_index %1, %c1_i32 : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>
  %5:9 = scf.for %arg14 = %c0_i32 to %arg1 step %arg2 iter_args(%arg15 = %arg4, %arg16 = %arg4, %arg17 = %arg7, %arg18 = %arg3, %arg19 = %arg3, %arg20 = %2, %arg21 = %4, %arg22 = %arg3, %arg23 = %3) -> (tensor<128x16xf32, #mma>, tensor<128x16xf32, #mma>, tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, !ttg.async.token, !ttg.async.token, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, !ttg.async.token, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>) : i32 {
    %6 = tt.dot %arg10, %arg17, %arg15 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma>
    %7 = ttg.async_wait %arg18 {num = 0 : i32}
    %8 = ttg.local_load %arg20 token %7 : !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
    %9 = ttg.memdesc_index %0, %arg6 : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>
    %10 = ttg.async_copy_global_to_local %arg0, %9 : tensor<64x16x!tt.ptr<f16>, #blocked> -> <64x16xf16, #shared, #smem, mutable, 2x64x16>
    %11 = ttg.async_commit_group %10
    %12 = tt.dot %arg10, %8, %arg16 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma>
    %13 = ttg.async_wait %arg22 {num = 0 : i32}
    %14 = ttg.local_load %arg23 token %13 : !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
    %15 = ttg.memdesc_index %1, %arg6 : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>
    %16 = ttg.async_copy_global_to_local %arg0, %15 : tensor<64x16x!tt.ptr<f16>, #blocked> -> <64x16xf16, #shared, #smem, mutable, 2x64x16>
    %17 = ttg.async_commit_group %16
    scf.yield %12, %6, %14, %arg19, %17, %arg21, %15, %11, %9 : tensor<128x16xf32, #mma>, tensor<128x16xf32, #mma>, tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, !ttg.async.token, !ttg.async.token, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, !ttg.async.token, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>
  }
  ttg.local_dealloc %1 : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable>
  ttg.local_dealloc %0 : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable>
  tt.return %5#0 : tensor<128x16xf32, #mma>
}
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}>
#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
#shared = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 8, order = [0, 1]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {

// CHECK-LABEL: chained_dots_tt_loads

// CHECK-NOT: rocdl.s
// CHECK: scf.for
// CHECK: rocdl.s.setprio 0
// Compute Cluster1
// CHECK: tt.dot
// CHECK: rocdl.s.setprio 1
// CHECK: gpu.barrier
// CHECK: rocdl.sched.barrier 0
// Memory Cluster1
// CHECK: ttg.local_store
// CHECK: ttg.local_load
// CHECK: tt.load
// CHECK: rocdl.sched.barrier 0
// CHECK: rocdl.s.barrier
// CHECK: rocdl.s.setprio 0
// Compute Cluster2
// CHECK: tt.dot
// CHECK: rocdl.s.setprio 1
// CHECK: gpu.barrier
// CHECK: rocdl.sched.barrier 0
// Memory Cluster2
// CHECK: ttg.local_store
// CHECK: ttg.local_load
// CHECK: tt.load
// CHECK: rocdl.sched.barrier 0
// CHECK: rocdl.s.barrier
// CHECK-NEXT: scf.yield

tt.func @chained_dots_tt_loads(%arg0: tensor<64x16xf16, #blocked>, %arg1: tensor<64x16x!tt.ptr<f16>, #blocked>, %arg2: i32, %arg3: i32, %arg4: tensor<128x16xf32, #mma>, %arg5: tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>, %arg6: i32, %arg7: tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, %arg8: tensor<128x16xf32, #mma>, %arg9: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg10: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>, %arg11: i32, %arg12: i32, %arg13: tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>) -> tensor<128x16xf32, #mma> {
  %c1_i32 = arith.constant 1 : i32
  %c0_i32 = arith.constant 0 : i32
  %0 = ttg.local_alloc : () -> !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable>
  %1 = ttg.local_alloc : () -> !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable>
  %2 = ttg.memdesc_index %1, %c0_i32 : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>
  %3 = ttg.memdesc_index %0, %c0_i32 : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>
  %4 = ttg.memdesc_index %1, %c1_i32 : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>
  %5:8 = scf.for %arg14 = %c0_i32 to %arg2 step %arg3 iter_args(%arg15 = %arg4, %arg16 = %arg4, %arg17 = %arg7, %arg18 = %2, %arg19 = %4, %arg20 = %3, %arg21 = %arg0, %arg22 = %arg0) -> (tensor<128x16xf32, #mma>, tensor<128x16xf32, #mma>, tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, tensor<64x16xf16, #blocked>, tensor<64x16xf16, #blocked>) : i32 {
    %6 = tt.dot %arg10, %arg17, %arg15 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma>
    ttg.local_store %arg21, %arg18 : tensor<64x16xf16, #blocked> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>
    %7 = ttg.local_load %arg18 : !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
    %8 = ttg.memdesc_index %0, %arg6 : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>
    %9 = tt.load %arg1 : tensor<64x16x!tt.ptr<f16>, #blocked>
    %10 = tt.dot %arg10, %7, %arg16 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma>
    ttg.local_store %arg22, %arg20 : tensor<64x16xf16, #blocked> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>
    %11 = ttg.local_load %arg20 : !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
    %12 = ttg.memdesc_index %1, %arg6 : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>
    %13 = tt.load %arg1 : tensor<64x16x!tt.ptr<f16>, #blocked>
    scf.yield %10, %6, %11, %arg19, %12, %8, %9, %13 : tensor<128x16xf32, #mma>, tensor<128x16xf32, #mma>, tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, tensor<64x16xf16, #blocked>, tensor<64x16xf16, #blocked>
  }
  ttg.local_dealloc %1 : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable>
  ttg.local_dealloc %0 : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable>
  tt.return %5#0 : tensor<128x16xf32, #mma>
}
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}>
#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
#shared = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 8, order = [0, 1]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {

// CHECK-LABEL: reject_chained_dots_empty_mem_cluster

// CHECK-NOT: setprio
// CHECK-NOT: barrier

tt.func @reject_chained_dots_empty_mem_cluster(%arg0: tensor<64x16xf16, #blocked>, %arg1: tensor<64x16x!tt.ptr<f16>, #blocked>, %arg2: i32, %arg3: i32, %arg4: tensor<128x16xf32, #mma>, %arg5: tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>, %arg6: i32, %arg7: tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, %arg8: tensor<128x16xf32, #mma>, %arg9: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg10: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>, %arg11: i32, %arg12: i32, %arg13: tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>) -> tensor<128x16xf32, #mma> {
  %c1_i32 = arith.constant 1 : i32
  %c0_i32 = arith.constant 0 : i32
  %0 = ttg.local_alloc : () -> !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable>
  %1 = ttg.local_alloc : () -> !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable>
  %2 = ttg.memdesc_index %1, %c0_i32 : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>
  %3 = ttg.memdesc_index %0, %c0_i32 : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>
  %4 = ttg.memdesc_index %1, %c1_i32 : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>
  %5:8 = scf.for %arg14 = %c0_i32 to %arg2 step %arg3 iter_args(%arg15 = %arg4, %arg16 = %arg4, %arg17 = %arg7, %arg18 = %2, %arg19 = %4, %arg20 = %3, %arg21 = %arg0, %arg22 = %arg0) -> (tensor<128x16xf32, #mma>, tensor<128x16xf32, #mma>, tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, tensor<64x16xf16, #blocked>, tensor<64x16xf16, #blocked>) : i32 {
    %6 = tt.dot %arg10, %arg17, %arg15 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma>
    %10 = tt.dot %arg10, %arg17, %arg16 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma>
    ttg.local_store %arg22, %arg20 : tensor<64x16xf16, #blocked> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>
    %11 = ttg.local_load %arg20 : !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
    %12 = ttg.memdesc_index %1, %arg6 : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>
    %13 = tt.load %arg1 : tensor<64x16x!tt.ptr<f16>, #blocked>
    scf.yield %10, %6, %11, %arg19, %12, %12, %13, %13 : tensor<128x16xf32, #mma>, tensor<128x16xf32, #mma>, tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, tensor<64x16xf16, #blocked>, tensor<64x16xf16, #blocked>
  }
  ttg.local_dealloc %1 : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable>
  ttg.local_dealloc %0 : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable>
  tt.return %5#0 : tensor<128x16xf32, #mma>
}
}

third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp

Lines changed: 84 additions & 2 deletions
@@ -85,6 +85,7 @@ class Pingponger {
                                                    Location loc);
   LogicalResult transformTwoClusterWithAsyncAndAll(OpBuilder &builder,
                                                    Location loc);
+  LogicalResult transformChainedDotSchedule(OpBuilder &builder, Location loc);
   void addAsymmetricSyncToLoop(OpBuilder &builder, Location loc);
   void updateOpInsertion(Operation *Op);
   void appendOp(Operation *Op);
@@ -666,6 +667,73 @@ LogicalResult Pingponger::transformTwoClusterWithAsyncAndAll(OpBuilder &builder,
   return success();
 }
 
+// For ChainedDots with num_stages==4 the pipeliner already places ops in the
+// correct order to allow for efficient pingpong. The loop contains 2 pairs of
+// compute and memory clusters, so we only have to place barriers/sched.barriers
+// at the boundaries and give higher priority to the memory clusters.
+// See StreamPipeliner.cpp:ChainedDotSchedule for details about the schedule.
+LogicalResult Pingponger::transformChainedDotSchedule(OpBuilder &builder,
+                                                      Location loc) {
+  assert(dotOps.size() == 2);
+
+  // Memory clusters start with either ttg.async_wait or ttg.local_store;
+  // returns nullptr if no such op follows in the block
+  auto findNextMemoryCluster = [](Operation *op) -> Operation * {
+    while (op && !llvm::isa<ttg::AsyncWaitOp, ttg::LocalStoreOp>(op)) {
+      op = op->getNextNode();
+    }
+    return op;
+  };
+
+  std::array memoryClusterStartOps = {findNextMemoryCluster(dotOps[0]),
+                                      findNextMemoryCluster(dotOps[1])};
+
+  if (llvm::is_contained(memoryClusterStartOps, nullptr) ||
+      memoryClusterStartOps[0] == memoryClusterStartOps[1]) {
+    LDBG("ChainedDot pingpong requires memory operations in both memory "
+         "clusters");
+    return failure();
+  }
+
+  builder.setInsertionPointToStart(forOp.getBody());
+  // ComputeCluster 1
+  updateOpInsertion(dotOps[0]);
+  prependOp(builder.create<ROCDL::SetPrioOp>(loc, lowPriority), false);
+
+  // MemoryCluster 1
+  updateOpInsertion(memoryClusterStartOps[0]);
+  prependOp(builder.create<ROCDL::SetPrioOp>(loc, highPriority), false);
+  if (llvm::isa<ttg::AsyncWaitOp>(memoryClusterStartOps[0])) {
+    // Only append a sched.barrier because membar adds a barrier after async_wait
+    appendOp(builder.create<ROCDL::SchedBarrier>(loc, 0));
+  } else {
+    prependOp(builder.create<gpu::BarrierOp>(loc), false);
+    prependOp(builder.create<ROCDL::SchedBarrier>(loc, 0), false);
+  }
+
+  // ComputeCluster 2
+  updateOpInsertion(dotOps[1]);
+  prependOp(builder.create<ROCDL::SchedBarrier>(loc, 0), false);
+  prependOp(builder.create<ROCDL::SBarrierOp>(loc), false);
+  prependOp(builder.create<ROCDL::SetPrioOp>(loc, lowPriority), false);
+
+  // MemoryCluster 2
+  updateOpInsertion(memoryClusterStartOps[1]);
+  prependOp(builder.create<ROCDL::SetPrioOp>(loc, highPriority), false);
+  if (llvm::isa<ttg::AsyncWaitOp>(memoryClusterStartOps[1])) {
+    // Only append a sched.barrier because membar adds a barrier after async_wait
+    appendOp(builder.create<ROCDL::SchedBarrier>(loc, 0));
+  } else {
+    prependOp(builder.create<gpu::BarrierOp>(loc), false);
+    prependOp(builder.create<ROCDL::SchedBarrier>(loc, 0), false);
+  }
+
+  updateOpInsertion(lastInsertedOp->getBlock()->getTerminator());
+  prependOp(builder.create<ROCDL::SchedBarrier>(loc, 0), false);
+  prependOp(builder.create<ROCDL::SBarrierOp>(loc), false);
+
+  return success();
+}
+
 // This pingpong variant tries to construct one memory cluster and one
 // dot cluster. Instead of slice the tile, it is supposed to use half
 // sized tile_K and use num_stages=3 to prefetch and hide the buffer
@@ -809,10 +877,24 @@ void Pingponger::getDotPingponged() {
   // tightly scheduling the latencies.
 
   int64_t numOfDotLikeOps = scaledDotOps.size() + dotOps.size();
-  if (numOfDotLikeOps != 1) {
-    LDBG("Only handle a single of either dot or dot_scaled op");
+
+  if (numOfDotLikeOps < 1 || numOfDotLikeOps > 2) {
+    LDBG("Only handle one or two dotlike ops");
     return;
   }
+
+  if (numOfDotLikeOps == 2) {
+    if (numStages != 4)
+      return;
+
+    if (transformChainedDotSchedule(builder, loc).failed()) {
+      LDBG("Encountered failure when trying the ChainedDot ping pong "
+           "cluster transformation");
+      return;
+    }
+    addAsymmetricSyncToLoop(builder, loc);
+  }
+
   useAsyncCopy = (asyncCopyOps.size() > 0);
   int64_t gloadSize = useAsyncCopy ? asyncCopyOps.size() : gLoadOps.size();
   int64_t dotSize =