Skip to content

Commit b316d88

Browse files
[PIPELINER] Reintroduce epilogue peeling (#6962)
Reintroduce "[PIPELINE] Peel single epilogue iteration after loop expansion #6893" Fixed the issue that was exposing IMA by using poison values as buffer indices used by speculatively executed async copy.
1 parent c00f747 commit b316d88

File tree

20 files changed

+433
-32
lines changed

20 files changed

+433
-32
lines changed

bin/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ target_link_libraries(triton-opt PRIVATE
1212
${triton_libs}
1313
# tests
1414
TritonTestAnalysis
15+
TritonTestDialect
1516
TritonAMDGPUTestAnalysis
1617
# MLIR core
1718
MLIROptLib
@@ -31,6 +32,7 @@ target_link_libraries(triton-reduce PRIVATE
3132
${triton_libs}
3233
# tests
3334
TritonTestAnalysis
35+
TritonTestDialect
3436
TritonAMDGPUTestAnalysis
3537
# MLIR core
3638
MLIRReduceLib
@@ -49,6 +51,7 @@ target_link_libraries(triton-lsp PRIVATE
4951
${triton_libs}
5052
# tests
5153
TritonTestAnalysis
54+
TritonTestDialect
5255
TritonAMDGPUTestAnalysis
5356
# MLIR core
5457
MLIRLspServerLib
@@ -85,5 +88,6 @@ target_link_libraries(triton-tensor-layout PRIVATE
8588
${conversion_libs}
8689
${dialect_libs}
8790
TritonTestAnalysis
91+
TritonTestDialect
8892
TritonAMDGPUTestAnalysis
8993
)

bin/RegisterTritonDialects.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ void registerTestAllocationPass();
3737
void registerTestMembarPass();
3838
void registerTestAMDGPUMembarPass();
3939
void registerTestTritonAMDGPURangeAnalysis();
40+
void registerTestLoopPeelingPass();
4041
} // namespace test
4142
} // namespace mlir
4243

@@ -49,6 +50,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
4950
mlir::test::registerTestAlignmentPass();
5051
mlir::test::registerTestAllocationPass();
5152
mlir::test::registerTestMembarPass();
53+
mlir::test::registerTestLoopPeelingPass();
5254
mlir::test::registerTestAMDGPUMembarPass();
5355
mlir::test::registerTestTritonAMDGPURangeAnalysis();
5456
mlir::triton::registerConvertTritonToTritonGPUPass();

include/triton/Dialect/Triton/IR/Utility.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,9 @@ template <typename T> auto seq(T start, T end, T step) {
177177
Value getPredMask(RewriterBase &rewriter, Type typeLike, Value currentMask,
178178
Value pred);
179179

180+
// Get the value of the induction variable at the end of the loop.
181+
Value getLastInductionValue(OpBuilder &b, scf::ForOp loop);
182+
180183
MakeTensorPtrOp getMakeTensorPtrOp(Value v);
181184

182185
} // namespace triton
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#ifndef TRITON_DIALECT_TRITON_TRANSFORMS_LOOP_PEELING_H_
2+
#define TRITON_DIALECT_TRITON_TRANSFORMS_LOOP_PEELING_H_
3+
4+
#include "mlir/Dialect/SCF/IR/SCF.h"
5+
6+
namespace mlir {
7+
namespace triton {
8+
9+
// Peel the single last iteration of the loop.
10+
void peelLoopEpilogue(
11+
scf::ForOp forOp,
12+
function_ref<Operation *(RewriterBase &, Operation *, bool)>
13+
processPeeledOp = nullptr);
14+
15+
} // namespace triton
16+
} // namespace mlir
17+
18+
#endif // TRITON_DIALECT_TRITON_TRANSFORMS_LOOP_PEELING_H_

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,24 @@ def TTG_PredicateStageOp: TTG_Op<"predicate_stage",
337337
let assemblyFormat = "$iv `,` $ub `,` $step `maxStage` $maxStage `stage` $stage attr-dict `:` type($iv) `->` type($result)";
338338
}
339339

340+
def TTG_MaskOp: TTG_Op<"mask",
341+
[SingleBlock]> {
342+
let summary = "mask op for pipelining";
343+
let arguments = (ins I1:$pred);
344+
let results = (outs Variadic<AnyType>:$result);
345+
let regions = (region SizedRegion<1>:$region);
346+
let builders = [
347+
OpBuilder<(ins "Value":$pred)>,
348+
];
349+
}
350+
351+
def TTG_MaskReturnOp: TTG_Op<"mask.return",
352+
[HasParent<"MaskOp">, Pure, Terminator, ReturnLike]> {
353+
let summary = "terminator for mask operator";
354+
let arguments = (ins Variadic<AnyType>:$result);
355+
let assemblyFormat = "$result attr-dict `:` type($result)";
356+
}
357+
340358
def TTG_Fp4ToFpOp : TTG_Op<"fp4_to_fp", [Pure]> {
341359
let summary = "Upcast fp4 (e2m1) to fp";
342360

lib/Dialect/Triton/IR/Utility.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "triton/Dialect/Triton/IR/Utility.h"
22
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
3+
#include "mlir/Dialect/SCF/IR/SCF.h"
34
#include "triton/Dialect/Triton/IR/Dialect.h"
45

56
using namespace mlir;
@@ -90,3 +91,15 @@ tt::MakeTensorPtrOp tt::getMakeTensorPtrOp(Value v) {
9091
}
9192
llvm_unreachable("Unable to getMakeTensorPtr()");
9293
}
94+
95+
Value tt::getLastInductionValue(OpBuilder &b, scf::ForOp loop) {
96+
Location loc = loop.getLoc();
97+
// (ub - lb -1) // step * step + lb
98+
Value diff =
99+
b.create<arith::SubIOp>(loc, loop.getUpperBound(), loop.getLowerBound());
100+
diff = b.create<arith::SubIOp>(
101+
loc, diff, b.create<arith::ConstantOp>(loc, b.getI32IntegerAttr(1)));
102+
Value ceilStep = b.create<arith::MulIOp>(
103+
loc, b.create<arith::DivSIOp>(loc, diff, loop.getStep()), loop.getStep());
104+
return b.create<arith::AddIOp>(loc, ceilStep, loop.getLowerBound());
105+
}

lib/Dialect/Triton/Transforms/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ add_triton_library(TritonTransforms
66
Combine.cpp
77
LoopAwareCSE.cpp
88
LoopInvariantCodeMotion.cpp
9+
LoopPeeling.cpp
910
LoopUnroll.cpp
1011
ReorderBroadcast.cpp
1112
RewriteTensorPointer.cpp
@@ -20,5 +21,7 @@ add_triton_library(TritonTransforms
2021
LINK_LIBS PUBLIC
2122
MLIRPass
2223
MLIRTransformUtils
24+
MLIRTransforms
25+
MLIRSCFToControlFlow
2326
TritonIR
2427
)
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
#include "triton/Dialect/Triton/Transforms/LoopPeeling.h"
2+
#include "mlir/Dialect/SCF/IR/SCF.h"
3+
#include "mlir/Pass/Pass.h"
4+
#include "triton/Dialect/Triton/IR/Utility.h"
5+
6+
using namespace mlir;
7+
8+
namespace mlir {
9+
namespace triton {
10+
11+
void peelLoopEpilogue(
12+
scf::ForOp forOp,
13+
function_ref<Operation *(RewriterBase &, Operation *, bool)>
14+
processPeeledOp) {
15+
SmallVector<Operation *> loopBodyOps;
16+
IRRewriter rewriter(forOp);
17+
Location loc = forOp.getLoc();
18+
Type type = forOp.getStep().getType();
19+
20+
// Fetch loop bounds and step
21+
Value lowerBound = forOp.getLowerBound();
22+
Value upperBound = forOp.getUpperBound();
23+
Value step = forOp.getStep();
24+
Value newUpperBound = rewriter.create<arith::SubIOp>(loc, upperBound, step);
25+
26+
rewriter.setInsertionPointAfter(forOp);
27+
Value lastIV = getLastInductionValue(rewriter, forOp);
28+
29+
auto cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
30+
lowerBound, upperBound);
31+
32+
// Create an if op to execute the peeled iteration
33+
IRMapping map;
34+
map.map(forOp.getRegionIterArgs(), forOp.getResults());
35+
map.map(forOp.getInductionVar(), lastIV);
36+
auto ifOp = rewriter.create<scf::IfOp>(loc, forOp.getResultTypes(), cond,
37+
/*hasElse=*/true);
38+
ifOp.getThenRegion().front().erase();
39+
forOp.getBodyRegion().cloneInto(&ifOp.getThenRegion(), map);
40+
rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
41+
rewriter.create<scf::YieldOp>(loc, forOp.getResults());
42+
43+
forOp->replaceUsesWithIf(ifOp, [&](OpOperand &operand) {
44+
return !ifOp->isAncestor(operand.getOwner());
45+
});
46+
47+
forOp.getUpperBoundMutable().assign(newUpperBound);
48+
49+
if (processPeeledOp) {
50+
for (auto &op :
51+
llvm::make_early_inc_range(forOp.getBody()->without_terminator())) {
52+
Operation *newOp = processPeeledOp(rewriter, &op, /*isEpilogue=*/false);
53+
if (newOp && newOp != &op) {
54+
op.replaceAllUsesWith(newOp);
55+
}
56+
}
57+
for (auto &op : llvm::make_early_inc_range(
58+
ifOp.getThenRegion().front().without_terminator())) {
59+
Operation *newOp = processPeeledOp(rewriter, &op, /*isEpilogue=*/true);
60+
if (newOp && newOp != &op) {
61+
op.replaceAllUsesWith(newOp);
62+
}
63+
}
64+
}
65+
}
66+
67+
} // namespace triton
68+
} // namespace mlir

lib/Dialect/TritonGPU/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ add_triton_library(TritonGPUTransforms
4444
MLIRTransformUtils
4545
TritonAnalysis
4646
TritonIR
47+
TritonTransforms
4748
TritonGPUIR
4849
TritonNvidiaGPUIR
4950
TritonToTritonGPU

lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -275,10 +275,10 @@ class AssignMMALatencies {
275275
// place the wait right before the loads.
276276

277277
if (hasSyncDots(forOp)) {
278-
// Skip pipelining MMA in the loops where sync dots are used. This is
279-
// dirty heuristic for performance drops in kernels where we would
280-
// rather want to have last iteration peeled instead of having a full
281-
// iteration of masked operations only to execute single wait.
278+
// Skip pipelining MMA in the loops where sync dots are used. This
279+
// is a dirty heuristic for performance drops in kernels where we
280+
// would rather want to have last iteration peeled instead of having a
281+
// full iteration of masked operations only to execute single wait.
282282
continue;
283283
}
284284
auto pipeHelper = ttng::MMAv5PipelineableOperandsHelper(

0 commit comments

Comments
 (0)