Skip to content

Commit be04dd2

Browse files
Mogballmakslevental
authored andcommitted
[Pipeliner] Fix crash in rewriting TMA descriptor updates (triton-lang#5843)
Lots of our code assumes that `scf.if` has a non-empty else region, but sometimes it can be empty, which typically happens due to one of the `scf.if` canonicalizers. Just make sure to create `scf.if` with non-empty regions. This was split off from triton-lang#5726 since others were hitting the crash.
1 parent 03ff5d2 commit be04dd2

File tree

2 files changed

+36
-6
lines changed

2 files changed

+36
-6
lines changed

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -703,14 +703,14 @@ scf::IfOp replaceIfOpWithNewSignature(
703703
// Create a new loop before the existing one, with the extra operands.
704704
auto resultTypes = llvm::to_vector<4>(ifOp.getResults().getTypes());
705705
resultTypes.append(newResultTypes.begin(), newResultTypes.end());
706-
scf::IfOp newIf = rewriter.create<scf::IfOp>(
707-
ifOp.getLoc(), resultTypes, ifOp.getCondition(), /*withElse=*/true);
706+
scf::IfOp newIf = rewriter.create<scf::IfOp>(ifOp.getLoc(), resultTypes,
707+
ifOp.getCondition());
708708
newIf->setAttrs(ifOp->getAttrs());
709709

710-
rewriter.inlineBlockBefore(ifOp.thenBlock(), newIf.thenBlock(),
711-
newIf.thenBlock()->begin());
712-
rewriter.inlineBlockBefore(ifOp.elseBlock(), newIf.elseBlock(),
713-
newIf.elseBlock()->begin());
710+
newIf.getThenRegion().takeBody(ifOp.getThenRegion());
711+
newIf.getElseRegion().takeBody(ifOp.getElseRegion());
712+
scf::IfOp::ensureTerminator(newIf.getThenRegion(), rewriter, ifOp.getLoc());
713+
scf::IfOp::ensureTerminator(newIf.getElseRegion(), rewriter, ifOp.getLoc());
714714

715715
for (auto it : llvm::zip(ifOp.getResults(),
716716
newIf.getResults().take_front(ifOp.getNumResults())))

test/TritonGPU/matmul-loop-pipeline.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,33 @@ tt.func public @scalar_load(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: i32, %arg3:
4747
}
4848

4949
}
50+
51+
// -----
52+
53+
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>
54+
55+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90"} {
56+
57+
// CHECK-LABEL: @make_tensor_desc_epilogue
58+
tt.func public @make_tensor_desc_epilogue(%arg0: i32, %arg1: !tt.ptr<f32>, %arg2: i32) {
59+
%c0_i32 = arith.constant 0 : i32
60+
%c1_i32 = arith.constant 1 : i32
61+
%c1_i64 = arith.constant 1 : i64
62+
// CHECK: scf.for
63+
scf.for %arg3 = %c0_i32 to %arg0 step %c1_i32 : i32 {
64+
%1 = tt.splat %arg1 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : !tt.ptr<f32> -> tensor<128x256x!tt.ptr<f32>, #blocked>
65+
%2 = tt.load %1 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : tensor<128x256x!tt.ptr<f32>, #blocked>
66+
%3 = arith.addf %2, %2 {loop.cluster = 5 : i32, loop.stage = 2 : i32} : tensor<128x256xf32, #blocked>
67+
%4 = arith.cmpi eq, %arg3, %c1_i32 {loop.cluster = 5 : i32, loop.stage = 2 : i32} : i32
68+
// CHECK: scf.if
69+
scf.if %4 {
70+
// CHECK-NOT: tt.make_tensor_descriptor
71+
// CHECK: tt.experimental_tensormap_create
72+
// CHECK-NEXT: tt.experimental_tensormap_fenceproxy_acquire
73+
%5 = tt.make_tensor_descriptor %arg1, [%arg2, %arg2], [%c1_i64, %c1_i64] : <f32>, <tensor<128x256xf32>>
74+
} {loop.cluster = 5 : i32, loop.stage = 2 : i32}
75+
} {tt.num_stages = 3 : i32}
76+
tt.return
77+
}
78+
79+
}

0 commit comments

Comments
 (0)