Commit c8a31a0

[AMD] Prevent wrong reordering of scf operations (#5203)
The pass was reordering scf.if operations without checking the extra dependencies coming from their regions. For now, just prevent this case, although this part of the code might still be fragile.
1 parent 133109d commit c8a31a0
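
The hazard described above comes from ops that own regions: a block-level use-def walk sees the scf.if itself but not the values defined inside its regions, so hoisting another op above it can silently break dominance. Below is a minimal sketch of the kind of conservative check this implies, assuming only core MLIR and the SCF dialect; the helper name is illustrative and not part of this commit.

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Operation.h"

// Illustrative helper: treat any scf control-flow producer as a barrier for
// reordering, because a block-level backward walk does not descend into its
// regions and therefore cannot see the dependencies defined there.
static bool isRegionControlFlowBarrier(mlir::Operation *defOp) {
  return mlir::isa<mlir::scf::IfOp, mlir::scf::ForOp, mlir::scf::WhileOp>(defOp);
  // An even more conservative variant would bail out for any op that owns
  // regions: return defOp->getNumRegions() != 0;
}

The commit uses the explicit isa list, as shown in the second diff below.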

File tree: 2 files changed (+48, -0 lines)


test/TritonGPU/amd/amd-reorder-instructions.mlir

Lines changed: 37 additions & 0 deletions
@@ -499,3 +499,40 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
     tt.return
   }
 }
+
+
+// -----
+
+#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} {
+// CHECK-LABEL: dont_hoist_scf_ops
+// Make sure we don't hoist scf ops above their dependencies.
+  tt.func public @dont_hoist_scf_ops(%init: tensor<256x128xf32, #mfma>,
+    %base: tensor<256x128x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>,
+    %p1: tensor<128x128x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>, %i1: i1) -> (tensor<256x128xf32, #mfma>) {
+    %c0_i32 = arith.constant 0 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %c4_i32 = arith.constant 4 : i32
+    %cst = arith.constant 1.44269502 : f32
+    %c128_i32 = arith.constant 128 : i32
+    // CHECK: scf.for
+    %54 = scf.for %arg21 = %c0_i32 to %c4_i32 step %c1_i32 iter_args(%arg = %init) -> (tensor<256x128xf32, #mfma>) : i32 {
+      // CHECK: arith.addi
+      %f = arith.addi %arg21, %c128_i32 : i32
+      // CHECK: scf.if
+      // CHECK: tt.load
+      %p0 = scf.if %i1 -> tensor<256x128x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> {
+        %t = tt.splat %f : i32 -> tensor<256x128xi32>
+        %padd = tt.addptr %base, %t : tensor<256x128x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>, tensor<256x128xi32>
+        scf.yield %padd : tensor<256x128x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
+      } else {
+        scf.yield %base : tensor<256x128x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
+      }
+      %l = tt.load %p0 : tensor<256x128x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
+      %r = tt.load %p1 : tensor<128x128x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
+      %acc = tt.dot %l, %r, %arg : tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma>
+      scf.yield %acc : tensor<256x128xf32, #mfma>
+    }
+    tt.return %54 : tensor<256x128xf32, #mfma>
+  }
+}

third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp

Lines changed: 11 additions & 0 deletions
@@ -227,6 +227,7 @@ static void scheduleGlobalLoadLocalStore(triton::FuncOp funcOp) {
     // Gather use-def chain in block.
     Block *block = op->getBlock();
    bool leadsToLoad = false;
+    bool dontReorder = false;
     SetVector<Operation *> backwardSet;
 
     BackwardSliceOptions options;
@@ -236,6 +237,13 @@ static void scheduleGlobalLoadLocalStore(triton::FuncOp funcOp) {
       Block *defBlock = defOp->getBlock();
       if (!block->findAncestorOpInBlock(*defOp))
         return false;
+      // Don't hoist control flow as we don't track backtraces of ops within
+      // their regions.
+      if (isa<scf::IfOp, scf::ForOp, scf::WhileOp>(defOp)) {
+        dontReorder = true;
+        return false;
+      }
+
       // Check for a `load` dependent path.
       leadsToLoad |= isa<triton::LoadOp>(defOp);
       // Only move ops residing in the same block.
@@ -244,6 +252,9 @@ static void scheduleGlobalLoadLocalStore(triton::FuncOp funcOp) {
     mlir::getBackwardSlice(op, &backwardSet, options);
     backwardSet.insert(op);
 
+    // If we found ops in the slice we don't want to hoist.
+    if (dontReorder)
+      continue;
     // Don't move a local_store if its source is a load from
     // the same iteration.
     if (isa<ttg::LocalStoreOp>(op) && leadsToLoad)
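
For context, here is a self-contained sketch of how a backward-slice filter can both prune the traversal and record that an unsafe producer was seen, which is the mechanism the hunks above rely on. This is not the pass itself: collectBackwardSlice and the option settings are illustrative, and getBackwardSlice's exact signature varies across MLIR versions.

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/SetVector.h"

// Illustrative only: gather the in-block backward slice of `op`, and flag it
// as non-reorderable if it reaches an scf control-flow producer, since the
// walk does not look inside that producer's regions.
static llvm::SetVector<mlir::Operation *>
collectBackwardSlice(mlir::Operation *op, bool &dontReorder) {
  mlir::Block *block = op->getBlock();
  llvm::SetVector<mlir::Operation *> slice;

  mlir::BackwardSliceOptions options;
  options.filter = [&](mlir::Operation *defOp) {
    // Only follow producers that live (possibly nested) in the same block.
    if (!block->findAncestorOpInBlock(*defOp))
      return false;
    // Region-carrying control flow hides dependencies; remember that and
    // stop expanding through it.
    if (mlir::isa<mlir::scf::IfOp, mlir::scf::ForOp, mlir::scf::WhileOp>(defOp)) {
      dontReorder = true;
      return false;
    }
    return true;
  };
  mlir::getBackwardSlice(op, &slice, options);
  slice.insert(op);
  return slice;
}

A caller would then simply skip moving op whenever dontReorder comes back true, mirroring the continue added in the last hunk.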
