Skip to content

Commit cc25374

Browse files
authored
[AMD][Pipeliner] Improve clustering and add prefetch (#4881)
This commit improves pipeliner op clustering so that we can avoid relying on a complicated and fragile reordering step later. In order to do this, we formalized stages a bit and improved the documentation accordingly. This commit also adds an extra experimental stage to buffer in registers before compute, which is part of a series of commits to improve scheduling performance.
1 parent 462de12 commit cc25374

File tree

10 files changed

+548
-646
lines changed

10 files changed

+548
-646
lines changed

test/TritonGPU/amd/amd-reorder-instructions.mlir

Lines changed: 0 additions & 345 deletions
Large diffs are not rendered by default.

test/TritonGPU/amd/amd-sched-2nd-load.mlir

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,11 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
3535
%c1 = arith.constant 1 : i32
3636
%cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
3737
%0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 {
38+
%4 = tt.load %A_ptr : tensor<256x128x!tt.ptr<f16>, #blocked>
3839
%1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x128xf16, #dotOp0>
40+
%5 = tt.load %B_ptr : tensor<128x256x!tt.ptr<f16>, #blocked1>
3941
%2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x256xf16, #dotOp1>
4042
%3 = tt.dot %1, %2, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x256xf16, #dotOp1> -> tensor<256x256xf32, #mma>
41-
%4 = tt.load %A_ptr : tensor<256x128x!tt.ptr<f16>, #blocked>
42-
%5 = tt.load %B_ptr : tensor<128x256x!tt.ptr<f16>, #blocked1>
4343
triton_gpu.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable>
4444
triton_gpu.local_store %5, %B_LDS : tensor<128x256xf16, #blocked1> -> !tt.memdesc<128x256xf16, #shared1, #triton_gpu.shared_memory, mutable>
4545
scf.yield %3 : tensor<256x256xf32, #mma>
@@ -64,11 +64,11 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
6464
%c1 = arith.constant 1 : i32
6565
%cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
6666
%0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 {
67+
%4 = tt.load %A_ptr : tensor<256x64x!tt.ptr<f16>, #blocked>
6768
%1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x64xf16, #dotOp0>
69+
%5 = tt.load %B_ptr : tensor<64x256x!tt.ptr<f16>, #blocked1>
6870
%2 = triton_gpu.local_load %B_LDS : !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x256xf16, #dotOp1>
6971
%3 = tt.dot %1, %2, %arg1 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma>
70-
%4 = tt.load %A_ptr : tensor<256x64x!tt.ptr<f16>, #blocked>
71-
%5 = tt.load %B_ptr : tensor<64x256x!tt.ptr<f16>, #blocked1>
7272
triton_gpu.local_store %4, %A_LDS : tensor<256x64xf16, #blocked> -> !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable>
7373
triton_gpu.local_store %5, %B_LDS : tensor<64x256xf16, #blocked1> -> !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable>
7474
scf.yield %3 : tensor<256x256xf32, #mma>
@@ -81,8 +81,8 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
8181
// Should NOT apply: tile size 256x64x128 with single dot
8282
// CHECK-LABEL: sink_2nd_load_256x64x128
8383
// CHECK: %[[tileA:.*]] = tt.load
84-
// CHECK-NEXT: %[[tileB:.*]] = tt.load
8584
// CHECK-NEXT: local_load
85+
// CHECK-NEXT: %[[tileB:.*]] = tt.load
8686
// CHECK-NEXT: local_load
8787
// CHECK-NEXT: tt.dot
8888
// CHECK-NEXT: triton_gpu.local_store %[[tileA]]
@@ -93,11 +93,11 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
9393
%c1 = arith.constant 1 : i32
9494
%cst = arith.constant dense<0.000000e+00> : tensor<256x64xf32, #mma>
9595
%0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x64xf32, #mma>) : i32 {
96+
%4 = tt.load %A_ptr : tensor<256x128x!tt.ptr<f16>, #blocked>
9697
%1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x128xf16, #dotOp0>
98+
%5 = tt.load %B_ptr : tensor<128x64x!tt.ptr<f16>, #blocked1>
9799
%2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x64xf16, #dotOp1>
98100
%3 = tt.dot %1, %2, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x64xf16, #dotOp1> -> tensor<256x64xf32, #mma>
99-
%4 = tt.load %A_ptr : tensor<256x128x!tt.ptr<f16>, #blocked>
100-
%5 = tt.load %B_ptr : tensor<128x64x!tt.ptr<f16>, #blocked1>
101101
triton_gpu.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable>
102102
triton_gpu.local_store %5, %B_LDS : tensor<128x64xf16, #blocked1> -> !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory, mutable>
103103
scf.yield %3 : tensor<256x64xf32, #mma>
@@ -110,8 +110,8 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
110110
// Should NOT apply: tile size 256x256x32 with single dot
111111
// CHECK-LABEL: sink_2nd_load_256x256x32
112112
// CHECK: %[[tileA:.*]] = tt.load
113-
// CHECK-NEXT: %[[tileB:.*]] = tt.load
114113
// CHECK-NEXT: local_load
114+
// CHECK-NEXT: %[[tileB:.*]] = tt.load
115115
// CHECK-NEXT: local_load
116116
// CHECK-NEXT: tt.dot
117117
// CHECK-NEXT: triton_gpu.local_store %[[tileA]]
@@ -122,11 +122,11 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
122122
%c1 = arith.constant 1 : i32
123123
%cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
124124
%0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 {
125+
%4 = tt.load %A_ptr : tensor<256x32x!tt.ptr<f16>, #blocked>
125126
%1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x32xf16, #dotOp0>
127+
%5 = tt.load %B_ptr : tensor<32x256x!tt.ptr<f16>, #blocked1>
126128
%2 = triton_gpu.local_load %B_LDS : !tt.memdesc<32x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<32x256xf16, #dotOp1>
127129
%3 = tt.dot %1, %2, %arg1 : tensor<256x32xf16, #dotOp0> * tensor<32x256xf16, #dotOp1> -> tensor<256x256xf32, #mma>
128-
%4 = tt.load %A_ptr : tensor<256x32x!tt.ptr<f16>, #blocked>
129-
%5 = tt.load %B_ptr : tensor<32x256x!tt.ptr<f16>, #blocked1>
130130
triton_gpu.local_store %4, %A_LDS : tensor<256x32xf16, #blocked> -> !tt.memdesc<256x32xf16, #shared, #triton_gpu.shared_memory, mutable>
131131
triton_gpu.local_store %5, %B_LDS : tensor<32x256xf16, #blocked1> -> !tt.memdesc<32x256xf16, #shared1, #triton_gpu.shared_memory, mutable>
132132
scf.yield %3 : tensor<256x256xf32, #mma>
@@ -142,8 +142,8 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
142142
// Should NOT apply: the 2nd load has a user before the dot
143143
// CHECK-LABEL: sink_2nd_load_128x128x128_user_before_dot
144144
// CHECK: %[[tileA:.*]] = tt.load
145-
// CHECK-NEXT: %[[tileB:.*]] = tt.load
146145
// CHECK-NEXT: local_load
146+
// CHECK-NEXT: %[[tileB:.*]] = tt.load
147147
// CHECK-NEXT: local_load
148148
// CHECK-NEXT: tt.store
149149
// CHECK-NEXT: tt.dot
@@ -154,10 +154,10 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
154154
%c1 = arith.constant 1 : i32
155155
%cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
156156
%0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<128x128xf32, #mma>) : i32 {
157-
%1 = triton_gpu.local_load %A_LDS : !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<128x128xf16, #dotOp0>
158-
%2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x128xf16, #dotOp1>
159157
%4 = tt.load %A_ptr : tensor<128x128x!tt.ptr<f16>, #blocked>
158+
%1 = triton_gpu.local_load %A_LDS : !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<128x128xf16, #dotOp0>
160159
%5 = tt.load %B_ptr : tensor<128x128x!tt.ptr<i64>, #blocked>
160+
%2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x128xf16, #dotOp1>
161161
tt.store %B_ptr, %5 : tensor<128x128x!tt.ptr<i64>, #blocked>
162162
%3 = tt.dot %1, %2, %arg1 : tensor<128x128xf16, #dotOp0> * tensor<128x128xf16, #dotOp1> -> tensor<128x128xf32, #mma>
163163
triton_gpu.local_store %4, %A_LDS : tensor<128x128xf16, #blocked> -> !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable>
@@ -174,12 +174,12 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
174174
// Category 3: two dots in the for loop. Make sure the optimization is not applied
175175
// should NOT apply: two dots
176176
// CHECK-LABEL: sink_2nd_load_256x256x64_two_dot
177-
// CHECK: triton_gpu.local_load
177+
// CHECK: tt.load
178+
// CHECK-NEXT: tt.load
179+
// CHECK-NEXT: triton_gpu.local_load
178180
// CHECK-NEXT: triton_gpu.local_load
179181
// CHECK-NEXT: tt.dot
180182
// CHECK-NEXT: tt.dot
181-
// CHECK-NEXT: tt.load
182-
// CHECK-NEXT: tt.load
183183
// CHECK-NEXT: triton_gpu.local_store
184184
// CHECK-NEXT: triton_gpu.local_store
185185
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}>

test/TritonGPU/loop-pipeline-hip.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1
3535
%16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr<f16>, #blocked>, tensor<64x16xi32, #blocked>
3636
// CHECK: triton_gpu.local_store
3737
// CHECK: scf.for
38+
// CHECK: tt.load
3839
// CHECK: tt.dot
3940
// CHECK: tt.dot
40-
// CHECK: tt.load
4141
// CHECK: triton_gpu.local_store
4242
// CHECK: scf.yield
4343
%17:2 = scf.for %arg2 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg3 = %cst_1, %arg4 = %cst_2) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) : i32 {
@@ -165,9 +165,9 @@ module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1
165165
// CHECK-LABEL: tt.func public @add_barrier_kernel
166166
// CHECK: tt.load
167167
// CHECK: scf.for
168+
// CHECK: tt.load
168169
// CHECK: gpu.barrier
169170
// CHECK: tt.store
170-
// CHECK: tt.load
171171
// CHECK: scf.yield
172172
// CHECK: gpu.barrier
173173
// CHECK: tt.store

0 commit comments

Comments
 (0)