Commit 664ac51

[AMD] Sink the 2nd tt.load after local_load's (#4823)
This helps the backend interleave global_load and mfma instructions, which can reduce global_load issue latency.
1 parent 037728b commit 664ac51
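
Condensed sketch of the post-pass ordering inside the `scf.for` body (value and buffer names are illustrative, not taken from the patch; the pass comment in ReorderInstructions.cpp below shows the full before/after IR):

```mlir
// After the pass, the second global load is issued between the shared-memory
// loads and the dot, so its issue latency can overlap with the mfma sequence.
%a   = tt.load %A_ptr                  // 1st global load stays first
%opA = triton_gpu.local_load %A_LDS
%opB = triton_gpu.local_load %B_LDS
%b   = tt.load %B_ptr                  // 2nd global load sunk here
%res = tt.dot %opA, %opB, %acc
triton_gpu.local_store %a, %A_LDS
triton_gpu.local_store %b, %B_LDS
```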

File tree

2 files changed: +237 lines, -0 lines
Lines changed: 165 additions & 0 deletions
@@ -0,0 +1,165 @@
// RUN: triton-opt %s -split-input-file -tritonamdgpu-reorder-instructions | FileCheck %s

// Check the logic of sched-2nd-load optimizations
// The following tile sizes should apply the optimization
// 256x256x128
// 256x256x64
// The following tile sizes should NOT apply the optimization
// 256x64x128
// 256x256x32
// scf.for loop with two dots should not apply the optimization

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}>
#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}>
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}>
#dotOp0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>
#dotOp1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>

// Should apply: tile size 256x256x128 with single dot
// CHECK-LABEL: sink_2nd_load_256x256x128
// CHECK: %[[tileA:.*]] = tt.load
// CHECK-NEXT: local_load
// CHECK-NEXT: local_load
// CHECK-NEXT: %[[tileB:.*]] = tt.load
// CHECK-NEXT: tt.dot
// CHECK-NEXT: triton_gpu.local_store %[[tileA]]
// CHECK-NEXT: triton_gpu.local_store %[[tileB]]
module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
  tt.func public @sink_2nd_load_256x256x128(%A_ptr: tensor<256x128x!tt.ptr<f16>, #blocked>, %B_ptr: tensor<128x256x!tt.ptr<f16>, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr<f32>, #mma>, %A_LDS: !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<128x256xf16, #shared1, #triton_gpu.shared_memory, mutable>) {
    %c0 = arith.constant 0 : i32
    %c1 = arith.constant 1 : i32
    %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
    %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 {
      %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x128xf16, #dotOp0>
      %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x256xf16, #dotOp1>
      %3 = tt.dot %1, %2, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x256xf16, #dotOp1> -> tensor<256x256xf32, #mma>
      %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr<f16>, #blocked>
      %5 = tt.load %B_ptr : tensor<128x256x!tt.ptr<f16>, #blocked1>
      triton_gpu.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable>
      triton_gpu.local_store %5, %B_LDS : tensor<128x256xf16, #blocked1> -> !tt.memdesc<128x256xf16, #shared1, #triton_gpu.shared_memory, mutable>
      scf.yield %3 : tensor<256x256xf32, #mma>
    }
    tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr<f32>, #mma>
    tt.return
  }
}

// Should apply: tile size 256x256x64 with single dot
// CHECK-LABEL: sink_2nd_load_256x256x64
// CHECK: %[[tileA:.*]] = tt.load
// CHECK-NEXT: local_load
// CHECK-NEXT: local_load
// CHECK-NEXT: %[[tileB:.*]] = tt.load
// CHECK-NEXT: tt.dot
// CHECK-NEXT: triton_gpu.local_store %[[tileA]]
// CHECK-NEXT: triton_gpu.local_store %[[tileB]]
module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
  tt.func public @sink_2nd_load_256x256x64(%A_ptr: tensor<256x64x!tt.ptr<f16>, #blocked>, %B_ptr: tensor<64x256x!tt.ptr<f16>, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr<f32>, #mma>, %A_LDS: !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable>) {
    %c0 = arith.constant 0 : i32
    %c1 = arith.constant 1 : i32
    %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
    %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 {
      %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x64xf16, #dotOp0>
      %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x256xf16, #dotOp1>
      %3 = tt.dot %1, %2, %arg1 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma>
      %4 = tt.load %A_ptr : tensor<256x64x!tt.ptr<f16>, #blocked>
      %5 = tt.load %B_ptr : tensor<64x256x!tt.ptr<f16>, #blocked1>
      triton_gpu.local_store %4, %A_LDS : tensor<256x64xf16, #blocked> -> !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable>
      triton_gpu.local_store %5, %B_LDS : tensor<64x256xf16, #blocked1> -> !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable>
      scf.yield %3 : tensor<256x256xf32, #mma>
    }
    tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr<f32>, #mma>
    tt.return
  }
}

// Should NOT apply: tile size 256x64x128 with single dot
// CHECK-LABEL: sink_2nd_load_256x64x128
// CHECK: %[[tileA:.*]] = tt.load
// CHECK-NEXT: %[[tileB:.*]] = tt.load
// CHECK-NEXT: local_load
// CHECK-NEXT: local_load
// CHECK-NEXT: tt.dot
// CHECK-NEXT: triton_gpu.local_store %[[tileA]]
// CHECK-NEXT: triton_gpu.local_store %[[tileB]]
module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
  tt.func public @sink_2nd_load_256x64x128(%A_ptr: tensor<256x128x!tt.ptr<f16>, #blocked>, %B_ptr: tensor<128x64x!tt.ptr<f16>, #blocked1>, %C_ptr: tensor<256x64x!tt.ptr<f32>, #mma>, %A_LDS: !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory, mutable>) {
    %c0 = arith.constant 0 : i32
    %c1 = arith.constant 1 : i32
    %cst = arith.constant dense<0.000000e+00> : tensor<256x64xf32, #mma>
    %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x64xf32, #mma>) : i32 {
      %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x128xf16, #dotOp0>
      %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x64xf16, #dotOp1>
      %3 = tt.dot %1, %2, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x64xf16, #dotOp1> -> tensor<256x64xf32, #mma>
      %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr<f16>, #blocked>
      %5 = tt.load %B_ptr : tensor<128x64x!tt.ptr<f16>, #blocked1>
      triton_gpu.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable>
      triton_gpu.local_store %5, %B_LDS : tensor<128x64xf16, #blocked1> -> !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory, mutable>
      scf.yield %3 : tensor<256x64xf32, #mma>
    }
    tt.store %C_ptr, %0#0: tensor<256x64x!tt.ptr<f32>, #mma>
    tt.return
  }
}

// Should NOT apply: tile size 256x256x32 with single dot
// CHECK-LABEL: sink_2nd_load_256x256x32
// CHECK: %[[tileA:.*]] = tt.load
// CHECK-NEXT: %[[tileB:.*]] = tt.load
// CHECK-NEXT: local_load
// CHECK-NEXT: local_load
// CHECK-NEXT: tt.dot
// CHECK-NEXT: triton_gpu.local_store %[[tileA]]
// CHECK-NEXT: triton_gpu.local_store %[[tileB]]
module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
  tt.func public @sink_2nd_load_256x256x32(%A_ptr: tensor<256x32x!tt.ptr<f16>, #blocked>, %B_ptr: tensor<32x256x!tt.ptr<f16>, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr<f32>, #mma>, %A_LDS: !tt.memdesc<256x32xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<32x256xf16, #shared1, #triton_gpu.shared_memory, mutable>) {
    %c0 = arith.constant 0 : i32
    %c1 = arith.constant 1 : i32
    %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
    %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 {
      %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x32xf16, #dotOp0>
      %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<32x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<32x256xf16, #dotOp1>
      %3 = tt.dot %1, %2, %arg1 : tensor<256x32xf16, #dotOp0> * tensor<32x256xf16, #dotOp1> -> tensor<256x256xf32, #mma>
      %4 = tt.load %A_ptr : tensor<256x32x!tt.ptr<f16>, #blocked>
      %5 = tt.load %B_ptr : tensor<32x256x!tt.ptr<f16>, #blocked1>
      triton_gpu.local_store %4, %A_LDS : tensor<256x32xf16, #blocked> -> !tt.memdesc<256x32xf16, #shared, #triton_gpu.shared_memory, mutable>
      triton_gpu.local_store %5, %B_LDS : tensor<32x256xf16, #blocked1> -> !tt.memdesc<32x256xf16, #shared1, #triton_gpu.shared_memory, mutable>
      scf.yield %3 : tensor<256x256xf32, #mma>
    }
    tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr<f32>, #mma>
    tt.return
  }
}

// Should NOT apply: tile size 128x128x128 with two dots
// CHECK-LABEL: sink_2nd_load_128x128x128_two_dot
// CHECK: %[[tileA:.*]] = tt.load
// CHECK-NEXT: %[[tileB:.*]] = tt.load
// CHECK-NEXT: local_load
// CHECK-NEXT: local_load
// CHECK-NEXT: tt.dot
// CHECK-NEXT: tt.dot
// CHECK-NEXT: triton_gpu.local_store %[[tileA]]
// CHECK-NEXT: triton_gpu.local_store %[[tileB]]
module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
  tt.func public @sink_2nd_load_128x128x128_two_dot(%A_ptr: tensor<128x128x!tt.ptr<f16>, #blocked>, %B_ptr: tensor<128x128x!tt.ptr<f16>, #blocked1>, %C_ptr: tensor<128x128x!tt.ptr<f32>, #mma>, %A_LDS: !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<128x128xf16, #shared1, #triton_gpu.shared_memory, mutable>) {
    %c0 = arith.constant 0 : i32
    %c1 = arith.constant 1 : i32
    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
    %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<128x128xf32, #mma>) : i32 {
      %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<128x128xf16, #dotOp0>
      %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x128xf16, #dotOp1>
      %3 = tt.dot %1, %2, %arg1 : tensor<128x128xf16, #dotOp0> * tensor<128x128xf16, #dotOp1> -> tensor<128x128xf32, #mma>
      %6 = tt.dot %1, %2, %3 : tensor<128x128xf16, #dotOp0> * tensor<128x128xf16, #dotOp1> -> tensor<128x128xf32, #mma>
      %4 = tt.load %A_ptr : tensor<128x128x!tt.ptr<f16>, #blocked>
      %5 = tt.load %B_ptr : tensor<128x128x!tt.ptr<f16>, #blocked1>
      triton_gpu.local_store %4, %A_LDS : tensor<128x128xf16, #blocked> -> !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable>
      triton_gpu.local_store %5, %B_LDS : tensor<128x128xf16, #blocked1> -> !tt.memdesc<128x128xf16, #shared1, #triton_gpu.shared_memory, mutable>
      scf.yield %6 : tensor<128x128xf32, #mma>
    }
    tt.store %C_ptr, %0#0: tensor<128x128x!tt.ptr<f32>, #mma>
    tt.return
  }
}

third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp

Lines changed: 72 additions & 0 deletions
@@ -247,6 +247,78 @@ class TritonAMDGPUReorderInstructionsPass
        dfgop->moveBefore(block, block->begin());
      }
    }

    /**
     * Sched-load optimization for matmul kernels with large tile sizes
     * The basic idea of sched-load optimization is to sink the 2nd tt.load
     * after local_load's so that global_load instructions can be interleaved
     * with mfma's. This can help hide the issue latency of global_load
     * instructions and improve performance on MI300X.
     *
     * It's assumed that the IR before this optimization has the following
     * structure:
     * ```mlir
     * scf.for ..
     * {
     *   tileA = tt.load a_ptr
     *   tileB = tt.load b_ptr
     *   opA = local_load bufferA
     *   opB = local_load bufferB
     *   res = tt.dot opA, opB
     *   local_store tileA, bufferA
     *   local_store tileB, bufferB
     * }
     * ```
     * After this optimization, the IR is transformed to
     * ```mlir
     * scf.for ..
     * {
     *   tileA = tt.load a_ptr
     *   opA = local_load bufferA
     *   opB = local_load bufferB
     *   tileB = tt.load b_ptr  <-- 2nd tt.load is sunk here
     *   res = tt.dot opA, opB
     *   local_store tileA, bufferA
     *   local_store tileB, bufferB
     * }
     * ```
     * For now, we don't have a perfect heuristic for when this optimization
     * should be applied. Therefore, we implement a simple heuristic: it is
     * applied when the tile sizes of A and B are large enough, i.e.
     * nonKDim >= 128 and kDim >= 64, and only for typical matmul kernels,
     * i.e. with exactly two tt.load's and one dotOp inside the loop. We are
     * experimenting with how to better control instruction scheduling and
     * enable such optimizations.
     */
    m.walk([&](scf::ForOp forOp) -> void {
      SetVector<Operation *> loadOps;
      triton::DotOp dotOp;
      int nDotOps = 0;
      for (Operation &op : forOp) {
        if (auto loadOp = dyn_cast<triton::LoadOp>(&op))
          loadOps.insert(loadOp);
        if (auto curOp = dyn_cast<triton::DotOp>(&op)) {
          nDotOps++;
          dotOp = curOp;
        }
      }
      // Only apply the optimization when there are exactly two loads and one
      // dot in the loop.
      if (loadOps.size() != 2 || nDotOps != 1)
        return;
      // Only apply the optimization when the tile size is large enough:
      // 1. nonKDim >= 128
      // 2. kDim >= 64
      auto ldAOp = dyn_cast<triton::LoadOp>(loadOps[0]);
      auto tileAShape = cast<RankedTensorType>(ldAOp.getType()).getShape();
      auto ldBOp = dyn_cast<triton::LoadOp>(loadOps[1]);
      auto tileBShape = cast<RankedTensorType>(ldBOp.getType()).getShape();
      if (!(tileAShape[0] >= 128 && tileAShape[1] >= 64 &&
            tileBShape[1] >= 128))
        return;
      // Move ldBOp (the 2nd tt.load) right before tt.dot.
      loadOps[1]->moveBefore(dotOp);
    });
  }
};
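
For reference, a minimal standalone sketch of the shape heuristic above, checked against the tile sizes exercised by the new lit test. The shouldSinkSecondLoad helper is hypothetical and not part of the commit:

```cpp
#include <cassert>
#include <cstdint>

// aRows x aCols is the shape of the A-operand tt.load (M x K); bCols is N from
// the B-operand tt.load (K x N). Mirrors the check in the pass:
// tileAShape[0] >= 128 && tileAShape[1] >= 64 && tileBShape[1] >= 128.
static bool shouldSinkSecondLoad(int64_t aRows, int64_t aCols, int64_t bCols) {
  return aRows >= 128 && aCols >= 64 && bCols >= 128;
}

int main() {
  assert(shouldSinkSecondLoad(256, 128, 256));  // 256x256x128: applies
  assert(shouldSinkSecondLoad(256, 64, 256));   // 256x256x64:  applies
  assert(!shouldSinkSecondLoad(256, 128, 64));  // 256x64x128:  N < 128, skipped
  assert(!shouldSinkSecondLoad(256, 32, 256));  // 256x256x32:  K < 64, skipped
  // The two-dot 128x128x128 case passes this shape check but is still skipped,
  // because the pass also requires exactly two tt.load's and one tt.dot.
  return 0;
}
```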

0 commit comments