Skip to content

Commit 16b25e1

Browse files
authored
[BACKEND] combineRedundantWaitOps should not combine across loops/branches (#7593)
`combineRedundantWaitOps` skipped over branches/loops, so if we ended up with IR like: ```mlir ttg.async_wait scf.for .... scf.yield ttg.async_wait ``` we merged the `async_wait`s in the prologue and epilogue because we did not find a `ttg.commit_group` in between. This PR stops the forward search if we encounter a branch/loop. We could also walk through all successor blocks if we think that is worth the effort. This problem was not triggered before because the `ttg.async_wait` was scheduled in the same stage as its user(s), so we either ended up with no `ttg.async_wait` in the prologue, or there was another prefetch after it in the prologue. Since triton-lang/triton#7458 we might place the `ttg.async_wait` one stage earlier than its user(s), so we can end up with the problematic IR.
1 parent fdd694d commit 16b25e1

File tree

3 files changed

+39
-10
lines changed

3 files changed

+39
-10
lines changed

lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,10 @@ void mlir::triton::combineRedundantWaitOps(
469469
SmallVector<Value> depTokens = waitOp.getOperands();
470470
unsigned minWaitNumber = waitOp.getNum();
471471
Operation *next = waitOp->getNextNode();
472-
while (next && !isa<ttg::AsyncCommitGroupOp>(next)) {
472+
// Stop if we reach the end of the block or if there is another commit group
473+
// or a branching op (forOp, ifOp, whileOp) in between the waits
474+
while (next &&
475+
!isa<ttg::AsyncCommitGroupOp, RegionBranchOpInterface>(next)) {
473476
if (auto nextWait = dyn_cast<ttg::AsyncWaitOp>(next)) {
474477
waitGroup.push_back(nextWait);
475478
minWaitNumber = std::min(minWaitNumber, nextWait.getNum());
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline="num_stages=3 use_async_copy=1 use_pingpong=1" | FileCheck %s
2+
3+
#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}>
4+
#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}>
5+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 32 : i32} {
6+
// CHECK-LABEL: tt.func @simple_pipelined_load
7+
// We expect one ttg.async_wait in the prologue, one in the loop and one in the epilogue
8+
// CHECK: ttg.async_wait
9+
// CHECK-NOT: ttg.async_wait
10+
// CHECK: scf.for
11+
// CHECK: ttg.async_wait
12+
// CHECK-NOT: ttg.async_wait
13+
// CHECK: scf.yield
14+
// CHECK: ttg.async_wait
15+
// CHECK-NOT: ttg.async_wait
16+
tt.func @simple_pipelined_load(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>, %arg3: i32, %arg4: i32) -> tensor<128x16xf32, #mma> {
17+
%c0_i32 = arith.constant 0 : i32
18+
%cst = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma>
19+
%0 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<1x16x!tt.ptr<f16>, #blocked>
20+
%1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
21+
%2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked>
22+
%3 = tt.broadcast %0 : tensor<1x16x!tt.ptr<f16>, #blocked> -> tensor<64x16x!tt.ptr<f16>, #blocked>
23+
%4 = tt.broadcast %2 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked>
24+
%5 = tt.addptr %3, %4 : tensor<64x16x!tt.ptr<f16>, #blocked>, tensor<64x16xi32, #blocked>
25+
%6 = scf.for %arg6 = %c0_i32 to %arg3 step %arg4 iter_args(%arg5 = %cst) -> (tensor<128x16xf32, #mma>) : i32 {
26+
%7 = tt.load %5 : tensor<64x16x!tt.ptr<f16>, #blocked>
27+
%8 = ttg.convert_layout %7 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
28+
%9 = tt.dot %arg2, %8, %cst : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma>
29+
scf.yield %9 : tensor<128x16xf32, #mma>
30+
}
31+
tt.return %6 : tensor<128x16xf32, #mma>
32+
}
33+
}

third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -765,16 +765,9 @@ struct PipelinePass : impl::TritonAMDGPUStreamPipelineBase<PipelinePass> {
765765
useAsyncCopy, waitAtTail);
766766
}
767767

768-
if (useAsyncCopy && numStages != 3) {
768+
if (useAsyncCopy) {
769769
llvm::SmallSetVector<ttg::AsyncWaitOp, 8> waitOps;
770-
moduleOp.walk([&](ttg::AsyncWaitOp waitOp) {
771-
if (auto maybeForOp = dyn_cast<scf::ForOp>(waitOp->getParentOp()))
772-
// FIXME: There's potential bug in combinRedundantWaitOps(), it
773-
// generate incorrect IR order when numStages==3.
774-
if (tt::getNumStagesOrDefault(maybeForOp, numStages) == 3)
775-
return;
776-
waitOps.insert(waitOp);
777-
});
770+
moduleOp.walk([&](ttg::AsyncWaitOp waitOp) { waitOps.insert(waitOp); });
778771
tt::combineRedundantWaitOps(waitOps);
779772
}
780773
}

0 commit comments

Comments
 (0)