Skip to content

Commit d2b8659

Browse files
authored
[LoopUnroll] Do not pipeline epilog loops generated by loop unrolling (#5027)
The epilog loop created by the loop unroller may not run at all if the main unrolled loop covers all of the original loop iterations, so pipelining it non-speculatively may not be beneficial. It can also cause correctness issues when combined with the downstream PTXAS optimizer.
1 parent 3c296ab commit d2b8659

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

lib/Dialect/Triton/Transforms/LoopUnroll.cpp

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,22 @@
2222

2323
namespace mlir::triton {
2424

25-
static const char *loopUnrollFactorAttrName = "tt.loop_unroll_factor";
26-
2725
namespace {
2826

2927
class LoopUnrollPass : public TritonLoopUnrollBase<LoopUnrollPass> {
3028

3129
int getUnrollFactorOrDefault(scf::ForOp forOp) {
3230
// Use the attribute attached to the loop if it exists otherwise set the
3331
// factor to 1 to suppress the unrolling.
34-
if (auto factor = forOp->getAttrOfType<IntegerAttr>(
35-
mlir::triton::loopUnrollFactorAttrName))
32+
if (auto factor =
33+
forOp->getAttrOfType<IntegerAttr>(loopUnrollFactorAttrName))
3634
return factor.getInt();
3735
return 1;
3836
}
3937

38+
const char *loopUnrollFactorAttrName = "tt.loop_unroll_factor";
39+
const char *pipelineStagesAttrName = "tt.num_stages";
40+
4041
public:
4142
LoopUnrollPass() = default;
4243
LoopUnrollPass(const LoopUnrollPass &) {}
@@ -49,11 +50,18 @@ class LoopUnrollPass : public TritonLoopUnrollBase<LoopUnrollPass> {
4950
loops.push_back(forOp);
5051
});
5152

53+
auto ctx = getOperation()->getContext();
5254
for (auto loop : loops) {
5355
auto unrollFactor = getUnrollFactorOrDefault(loop);
54-
loop->removeAttr(mlir::triton::loopUnrollFactorAttrName);
56+
loop->removeAttr(loopUnrollFactorAttrName);
5557
LDBG("Unrolling loop by " << unrollFactor << " times\n" << loop);
56-
(void)loopUnrollByFactor(loop, unrollFactor);
58+
auto resultLoops = loopUnrollByFactor(loop, unrollFactor);
59+
// Do not pipeline the epilog loop.
60+
if (succeeded(resultLoops) && resultLoops->epilogueLoopOp) {
61+
(*resultLoops->epilogueLoopOp)
62+
->setAttr(pipelineStagesAttrName,
63+
mlir::IntegerAttr::get(IntegerType::get(ctx, 32), 1));
64+
}
5765
}
5866
}
5967
};

test/Triton/loop-unroll.mlir

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ tt.func @add_kernel_unroll(%arg0: tensor<256x!tt.ptr<f32>>, %arg1: i32) {
1313
// CHECK: scf.for
1414
// CHECK: tt.load
1515
// CHECK-NOT: tt.load
16+
// CHECK: tt.num_stages = 1 : i32
1617
%2:2 = scf.for %arg3 = %c1_i32 to %arg1 step %c1_i32 iter_args(%arg4 = %1, %arg5 = %arg0) -> (tensor<256xf32>, tensor<256x!tt.ptr<f32>>) : i32 {
1718
%3 = tt.load %arg5 : tensor<256x!tt.ptr<f32>>
1819
%4 = arith.addf %arg4, %3 : tensor<256xf32>

0 commit comments

Comments
 (0)