Skip to content

Commit 70d83a4

Browse files
authored
[AMD] Add ChainedDotSchedule to StreamPipeliner (#7601)
Adds a new scheduling variant which kicks in for loops that have 2 chained dots and `num_stages==4`. It places the two dots in consecutive stages so we can interleave operations using the result of the first dot with both dots in the loop, a pseudo example IR: ``` %1 = tt.dot ... %2 = arith.addf %1, %arg1 %3 = arith.subf %2, %arg2 %4 = tt.dot %X, %Y, %3 ``` Which could result in the following pseudo schedule (ignoring mem ops) to interleave with both dots: ``` stage N, Cluster0: [%1 = tt.dot, %3 = arith.subf] stage N+1, Cluster1: [%4 = tt.dot, %2 = arith.addf] ``` As a first step the schedule splits the op chain between dot1 and dot2 when it encounters an operation which has more than 2 users. This aims to avoid adding too many loop carried dependencies but does not guarantee a good work balance between the two clusters. In future PRs we might make this more sophisticated.
1 parent 24b1096 commit 70d83a4

File tree

2 files changed

+475
-3
lines changed

2 files changed

+475
-3
lines changed
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline="num_stages=4 use_async_copy=1" -canonicalize | FileCheck %s

#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}>
#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
  // CHECK-LABEL: tt.func @direct_chained_dots

  // We have no ops between the dots so we just check that dot and memory ops
  // are in the correct order and check if basic pipelining (prologue,
  // epilogue) is working correctly.
  // CHECK-COUNT-2: ttg.local_load
  // CHECK: scf.for
  // CHECK: tt.dot
  // CHECK: ttg.async_copy_global_to_local
  // CHECK: tt.dot
  // CHECK: ttg.async_wait
  // CHECK: ttg.local_load
  // CHECK: scf.yield
  // CHECK: ttg.async_wait
  // CHECK: ttg.local_load

  tt.func @direct_chained_dots(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>, %arg3: i32, %arg4: i32) -> tensor<128x16xf32, #mma> {
    %c0_i32 = arith.constant 0 : i32
    %cst = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma>
    %0 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<1x16x!tt.ptr<f16>, #blocked>
    %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
    %2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked>
    %3 = tt.broadcast %0 : tensor<1x16x!tt.ptr<f16>, #blocked> -> tensor<64x16x!tt.ptr<f16>, #blocked>
    %4 = tt.broadcast %2 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked>
    %5 = tt.addptr %3, %4 : tensor<64x16x!tt.ptr<f16>, #blocked>, tensor<64x16xi32, #blocked>
    %6 = scf.for %arg6 = %c0_i32 to %arg3 step %arg4 iter_args(%arg5 = %cst) -> (tensor<128x16xf32, #mma>) : i32 {
      %7 = tt.load %5 : tensor<64x16x!tt.ptr<f16>, #blocked>
      %8 = ttg.convert_layout %7 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
      %9 = tt.dot %arg2, %8, %cst : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma>
      %10 = tt.dot %arg2, %8, %9 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma>
      scf.yield %10 : tensor<128x16xf32, #mma>
    }
    tt.return %6 : tensor<128x16xf32, #mma>
  }
}
39+
40+
// -----
#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}>
#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
  // CHECK-LABEL: tt.func @chained_dots_with_ops_in_between

  // Ops between dots:
  // dot1 -> reduce -> addf %dot1, %reduce1 -> add -> exp2 -> add -> dot2
  // We expect to split after the reduce because the result is used twice.

  // CHECK: scf.for

  // CHECK: tt.dot
  // CHECK: arith.addf
  // CHECK: math.exp2
  // CHECK: arith.addf

  // CHECK: ttg.async_wait
  // CHECK: ttg.local_load
  // CHECK: ttg.async_copy_global_to_local

  // CHECK: tt.dot
  // CHECK: tt.reduce

  // CHECK: ttg.async_wait
  // CHECK: ttg.local_load
  // CHECK: ttg.async_copy_global_to_local

  // CHECK: scf.yield

  tt.func @chained_dots_with_ops_in_between(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>, %arg2: i32, %arg3: i32) -> tensor<128x16xf32, #mma> {
    %c0_i32 = arith.constant 0 : i32
    %cst = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma>
    %0 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<1x16x!tt.ptr<f16>, #blocked>
    %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
    %2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked>
    %3 = tt.broadcast %0 : tensor<1x16x!tt.ptr<f16>, #blocked> -> tensor<64x16x!tt.ptr<f16>, #blocked>
    %4 = tt.broadcast %2 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked>
    %5 = tt.addptr %3, %4 : tensor<64x16x!tt.ptr<f16>, #blocked>, tensor<64x16xi32, #blocked>
    %6 = scf.for %arg5 = %c0_i32 to %arg2 step %arg3 iter_args(%arg6 = %cst) -> (tensor<128x16xf32, #mma>) : i32 {
      %7 = tt.load %5 : tensor<64x16x!tt.ptr<f16>, #blocked>
      %8 = ttg.convert_layout %7 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
      %9 = tt.load %5 : tensor<64x16x!tt.ptr<f16>, #blocked>
      %10 = ttg.convert_layout %9 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
      %11 = tt.dot %arg1, %8, %cst : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma>
      %12 = "tt.reduce"(%11) <{axis = 1 : i32}> ({
      ^bb0(%arg8: f32, %arg9: f32):
        %20 = arith.maxnumf %arg8, %arg9 : f32
        tt.reduce.return %20 : f32
      }) : (tensor<128x16xf32, #mma>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
      %14 = tt.expand_dims %12 {axis = 1 : i32} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma>
      %15 = tt.broadcast %14 : tensor<128x1xf32, #mma> -> tensor<128x16xf32, #mma>
      // Split here since %15 is used twice
      %16 = arith.addf %11, %15 : tensor<128x16xf32, #mma>
      %17 = math.exp2 %15 : tensor<128x16xf32, #mma>
      %18 = arith.addf %16, %17 : tensor<128x16xf32, #mma>
      %19 = tt.dot %arg1, %10, %18 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma>
      scf.yield %19 : tensor<128x16xf32, #mma>
    }
    // %6 is the single result of the scf.for, so reference it by plain name
    // (the original `%6#0` result-group syntax is only conventional for
    // multi-result definitions).
    tt.return %6 : tensor<128x16xf32, #mma>
  }
}
103+
104+
// -----
#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}>
#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
  // CHECK-LABEL: tt.func @chained_dots_with_loop_carried_partial_result

  // Similar to the previous test but we take the max of the reduce over all
  // iterations (loop carried) so expect a split after the maximum.

  // CHECK: scf.for

  // CHECK: tt.dot
  // CHECK: arith.mulf

  // CHECK: ttg.async_wait
  // CHECK: ttg.local_load
  // CHECK: ttg.async_copy_global_to_local

  // CHECK: tt.dot
  // CHECK: tt.reduce
  // CHECK: arith.maxnumf

  // CHECK: ttg.async_wait
  // CHECK: ttg.local_load
  // CHECK: ttg.async_copy_global_to_local

  // CHECK: scf.yield

  tt.func @chained_dots_with_loop_carried_partial_result(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>, %arg2: i32, %arg3: i32, %arg101: tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>) -> tensor<128x16xf32, #mma> {
    %c0_i32 = arith.constant 0 : i32
    %cst = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma>
    %0 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<1x16x!tt.ptr<f16>, #blocked>
    %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
    %2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked>
    %3 = tt.broadcast %0 : tensor<1x16x!tt.ptr<f16>, #blocked> -> tensor<64x16x!tt.ptr<f16>, #blocked>
    %4 = tt.broadcast %2 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked>
    %5 = tt.addptr %3, %4 : tensor<64x16x!tt.ptr<f16>, #blocked>, tensor<64x16xi32, #blocked>
    %6:2 = scf.for %arg4 = %c0_i32 to %arg2 step %arg3 iter_args(%arg5 = %cst, %arg100 = %arg101) -> (tensor<128x16xf32, #mma>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>) : i32 {
      %7 = tt.load %5 : tensor<64x16x!tt.ptr<f16>, #blocked>
      %8 = ttg.convert_layout %7 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
      %9 = tt.load %5 : tensor<64x16x!tt.ptr<f16>, #blocked>
      %10 = ttg.convert_layout %9 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
      %11 = tt.dot %arg1, %8, %cst : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma>
      %12 = "tt.reduce"(%11) <{axis = 1 : i32}> ({
      ^bb0(%arg6: f32, %arg7: f32):
        %21 = arith.maxnumf %arg6, %arg7 : f32
        tt.reduce.return %21 : f32
      }) : (tensor<128x16xf32, #mma>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
      %24 = arith.maxnumf %12, %arg100 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
      // Split here since %24 is used twice
      %13 = tt.expand_dims %24 {axis = 1 : i32} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma>
      %14 = tt.broadcast %13 : tensor<128x1xf32, #mma> -> tensor<128x16xf32, #mma>
      %15 = arith.mulf %14, %11 : tensor<128x16xf32, #mma>
      %18 = tt.dot %arg1, %10, %15 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma>
      scf.yield %18, %24 : tensor<128x16xf32, #mma>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
    }
    // %6 defines TWO results (%6:2); the function returns only the
    // accumulator, so the first result must be selected with `#0` —
    // `tt.return %6` would be invalid MLIR here.
    tt.return %6#0 : tensor<128x16xf32, #mma>
  }
}

0 commit comments

Comments
 (0)