
Commit 2a6bc7e

leonling-ll authored and anmyachev committed
[AMD] Fix lowerLoops: only erase load ops which are converted (#8737)
This change fixes a crash in the AMDGPU pipeline pass: when a LoadOp feeding an AddfOp sits between two dots in a loop, that LoadOp is not streamable, and the pass crashed because it erased a LoadOp that still had uses. The fix iterates over `loadToStreamOps` instead of `loadToInfo`, so only LoadOps that were actually converted to stream ops are erased.
1 parent 9fb66fb · commit 2a6bc7e
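To make the crash scenario concrete, here is a minimal, self-contained C++ sketch (not the actual LowerLoops.cpp code): FakeOp, StreamOpInfo, and the map contents are invented stand-ins for MLIR's Operation and the pass's internal maps; only the shape of the erase loop mirrors the one-line diff below.

#include <cassert>
#include <map>
#include <string>

// Stand-in for an IR op: it only tracks how many other ops still use its result.
struct FakeOp {
  std::string name;
  int numUses = 0;
  void erase() {
    // Mirrors MLIR's invariant: erasing an op that still has uses is an error.
    assert(numUses == 0 && "cannot erase an op that still has uses");
  }
};

struct StreamOpInfo {}; // placeholder for the stream/async ops a load was lowered to

int main() {
  FakeOp streamableLoad{"streamable_load", /*numUses=*/0}; // users already rewired to stream ops
  FakeOp biasLoad{"bias_load", /*numUses=*/1};             // still feeds an arith.addf between the dots

  // Every load the scheduler considered ...
  std::map<FakeOp *, int> loadToInfo = {{&streamableLoad, 0}, {&biasLoad, 0}};
  // ... but only the streamable one was actually converted.
  std::map<FakeOp *, StreamOpInfo> loadToStreamOps = {{&streamableLoad, {}}};
  (void)loadToInfo; // the buggy loop iterated this map and would hit biasLoad

  // The fixed loop erases only converted loads, so the bias load survives untouched.
  for (auto &[load, streamOps] : loadToStreamOps) {
    (void)streamOps;
    load->erase();
  }
  return 0;
}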

2 files changed (+55, -1 lines)

test/TritonGPU/amd/amd-pipeline-chained-dots.mlir

Lines changed: 54 additions & 0 deletions
@@ -160,3 +160,57 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
     tt.return %6 : tensor<128x16xf32, #mma>
   }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>
+#mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [8, 1], instrShape = [16, 16, 16], isTransposed = true}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: tt.func @chained_dots_with_load_bias_in_between
+
+  // Similar to the previous test, but loads a bias tensor between the 2 dots.
+  // We expect the unstreamable load to be kept after pipelining.
+
+  // CHECK: scf.for
+  // CHECK: tt.dot
+  // CHECK: ttg.async_copy_global_to_local
+  // CHECK: tt.dot
+  // CHECK: ttg.async_wait
+  // CHECK: ttg.local_load
+  // CHECK: tt.load
+  // CHECK: scf.yield
+
+  tt.func @chained_dots_with_load_bias_in_between(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg2: i64 {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg4: i32) -> tensor<256x64xf32, #mma> {
+    %c0_i32 = arith.constant 0 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %c64_i32 = arith.constant 64 : i32
+    %cst = arith.constant dense<0.000000e+00> : tensor<256x64xf32, #mma>
+    %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked>
+    %2 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<64x64x!tt.ptr<f16>, #blocked>
+    %3 = tt.broadcast %1 : tensor<64x1xi32, #blocked> -> tensor<64x64xi32, #blocked>
+    %4 = tt.addptr %2, %3 : tensor<64x64x!tt.ptr<f16>, #blocked>, tensor<64x64xi32, #blocked>
+    %5 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+    %6 = tt.splat %arg3 : !tt.ptr<f16> -> tensor<256x64x!tt.ptr<f16>, #blocked>
+    %7 = scf.for %arg5 = %c0_i32 to %arg4 step %c1_i32 iter_args(%arg6 = %cst) -> (tensor<256x64xf32, #mma>) : i32 {
+      %8 = tt.load %4 : tensor<64x64x!tt.ptr<f16>, #blocked>
+      %9 = ttg.convert_layout %8 : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
+      %10 = tt.dot %arg1, %9, %cst : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x64xf32, #mma>
+      %11 = arith.muli %arg5, %c64_i32 : i32
+      %12 = tt.splat %11 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+      %13 = arith.addi %12, %5 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+      %14 = tt.expand_dims %13 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
+      %15 = tt.broadcast %14 : tensor<1x64xi32, #blocked> -> tensor<256x64xi32, #blocked>
+      %bias_ptr = tt.addptr %6, %15 : tensor<256x64x!tt.ptr<f16>, #blocked>, tensor<256x64xi32, #blocked>
+      %bias = tt.load %bias_ptr : tensor<256x64x!tt.ptr<f16>, #blocked>
+      %bias_mma = ttg.convert_layout %bias : tensor<256x64xf16, #blocked> -> tensor<256x64xf16, #mma>
+      %bias_f32 = arith.extf %bias_mma : tensor<256x64xf16, #mma> to tensor<256x64xf32, #mma>
+      %dot_bias = arith.addf %10, %bias_f32 : tensor<256x64xf32, #mma>
+      %21 = arith.truncf %dot_bias : tensor<256x64xf32, #mma> to tensor<256x64xf16, #mma>
+      %22 = ttg.convert_layout %21 : tensor<256x64xf16, #mma> -> tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+      %23 = tt.dot %22, %9, %arg6 : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x64xf32, #mma>
+      scf.yield %23 : tensor<256x64xf32, #mma>
+    }
+    tt.return %7 : tensor<256x64xf32, #mma>
+  }
+}

third_party/amd/lib/TritonAMDGPUTransforms/LowerLoops.cpp

Lines changed: 1 addition & 1 deletion
@@ -718,7 +718,7 @@ void updateSchedule(scf::ForOp &forOp, const LoadToInfoMap &loadToInfo,
   useAsyncCopy, axisInfoAnalysis);
   scheduleStreamOps(loadToStreamOps, schedule, clusters);

-  for (auto [l, _] : loadToInfo) {
+  for (auto [l, _] : loadToStreamOps) {
     schedule.erase(l);
     l->erase();
   }
