Skip to content

Commit 8e5db20

Browse files
authored
[FRONTEND] Add option to disable licm on for and while loops (#7733)
This allows controlling register pressure better. This also adds a tl.condition wrapper to be able to do annotations for while loops. This has a few hacks to get the attribute to propagate all the way to LLVM branch ops. I'm not sure there is a better way to do this downstream. I'll start a conversation upstream to see if this can get fixed.
1 parent f6626cd commit 8e5db20

File tree

15 files changed

+356
-3
lines changed

15 files changed

+356
-3
lines changed

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,9 @@ inline bool isCanonicalIndex(unsigned index, unsigned freeVarMask) {
623623
// group code isolated from above by invoking this function.
624624
void makeAllWarpGroupsIsolatedFromAbove(Operation *op);
625625

626+
// Set the correct loop annotation on LLVM branch ops.
627+
void fixUpLoopAnnotation(ModuleOp mod);
628+
626629
/// Converts ConvertLayoutOp to llvm using padded pattern.
627630
/// This pattern adds unused memory locations after every row of the tensor's fastest
628631
/// changing dimension:

include/triton/Dialect/Triton/Transforms/Passes.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,4 +90,15 @@ def TritonLoopAwareCSE : Pass<"triton-loop-aware-cse", "mlir::ModuleOp"> {
9090
}];
9191
}
9292

93+
def TritonSCFToCF : Pass</*cli-arg*/"triton-scf-to-cf", /*Op*/"mlir::ModuleOp"> {
  let summary = "MLIR's SCF To CF plus some extra attributes propagation.";
  let description = [{
    This pass uses MLIR's SCF To CF pass as base. Additionally, it propagates
    some extra attributes to the converted CFG.
    TODO: upstream the llvm loop attribute propagation and remove this pass.
  }];

  // The lowering patterns of this pass create cf.br / cf.cond_br operations,
  // so the ControlFlow dialect has to be declared as a dependency to make sure
  // it is loaded in the context before the pass runs.
  let dependentDialects = ["mlir::cf::ControlFlowDialect"];
}
103+
93104
#endif

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1517,6 +1517,25 @@ void makeAllWarpGroupsIsolatedFromAbove(Operation *op) {
15171517
});
15181518
}
15191519

1520+
// TODO: Is there a better way to do this? This needs to be fixed upstream.
1521+
// Walks every operation nested under `mod` and, for each LLVM::BrOp or
// LLVM::CondBrOp that carries an attribute named "llvm.loop_annotation" holding
// an LLVM::LoopAnnotationAttr, stores that value through the op's own
// setLoopAnnotationAttr setter. All other operations are left untouched.
void fixUpLoopAnnotation(ModuleOp mod) {
  static constexpr const char *kLoopAnnotationName = "llvm.loop_annotation";

  mod->walk([](Operation *candidate) {
    // Only LLVM branch ops that actually carry the attribute are of interest.
    if (!isa<LLVM::BrOp, LLVM::CondBrOp>(candidate))
      return;
    if (!candidate->hasAttr(kLoopAnnotationName))
      return;

    // Ignore attributes of an unexpected kind that merely share the name.
    auto annotation = dyn_cast<LLVM::LoopAnnotationAttr>(
        candidate->getAttr(kLoopAnnotationName));
    if (!annotation)
      return;

    if (auto branch = dyn_cast<LLVM::BrOp>(candidate)) {
      branch.setLoopAnnotationAttr(annotation);
      return;
    }
    if (auto condBranch = dyn_cast<LLVM::CondBrOp>(candidate))
      condBranch.setLoopAnnotationAttr(annotation);
  });
}
1538+
15201539
namespace {
15211540

15221541
// Determine which registers are read/written in which iteration of the shmem

lib/Dialect/Triton/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ add_triton_library(TritonTransforms
1313
RewriteTensorDescriptorToPointer.cpp
1414
ArithTypeConversion.cpp
1515
FunctionTypeConversion.cpp
16+
SCFToCF.cpp
1617

1718
DEPENDS
1819
TritonTransformsIncGen
Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,233 @@
1+
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
2+
#include "mlir/Dialect/Arith/IR/Arith.h"
3+
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
4+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
5+
#include "mlir/Dialect/SCF/IR/SCF.h"
6+
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
7+
#include "mlir/IR/Builders.h"
8+
#include "mlir/IR/BuiltinOps.h"
9+
#include "mlir/IR/IRMapping.h"
10+
#include "mlir/IR/MLIRContext.h"
11+
#include "mlir/IR/PatternMatch.h"
12+
#include "mlir/Transforms/DialectConversion.h"
13+
#include "mlir/Transforms/Passes.h"
14+
15+
#define GEN_PASS_DEF_TRITONSCFTOCF
16+
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"
17+
18+
using namespace mlir;
19+
using namespace mlir::scf;
20+
21+
// While loop lowering patterns forked from MLIR lowering. ForOp already has the
22+
// propagation.
23+
// TODO: Upstream llvm loop attribute propagation and remove this pass.
24+
namespace {
25+
/// Pass implementation behind the generated TritonSCFToCFBase; all the work
/// happens in runOnOperation, which is defined out of line further down.
struct SCFToCFPass : public ::impl::TritonSCFToCFBase<SCFToCFPass> {
26+
// Entry point invoked by the pass manager.
void runOnOperation() override;
27+
};
28+
29+
/// Create a CFG subgraph for this loop construct. The regions of the loop need
30+
/// not be a single block anymore (for example, if other SCF constructs that
31+
/// they contain have been already converted to CFG), but need to be single-exit
32+
/// from the last block of each region. The operations following the original
33+
/// WhileOp are split into a new continuation block. Both regions of the WhileOp
34+
/// are inlined, and their terminators are rewritten to organize the control
35+
/// flow implementing the loop as follows.
36+
///
37+
/// +---------------------------------+
38+
/// | <code before the WhileOp> |
39+
/// | cf.br ^before(%operands...) |
40+
/// +---------------------------------+
41+
/// |
42+
/// -------| |
43+
/// | v v
44+
/// | +--------------------------------+
45+
/// | | ^before(%bargs...): |
46+
/// | | %vals... = <some payload> |
47+
/// | +--------------------------------+
48+
/// | |
49+
/// | ...
50+
/// | |
51+
/// | +--------------------------------+
52+
/// | | ^before-last:
53+
/// | | %cond = <compute condition> |
54+
/// | | cf.cond_br %cond, |
55+
/// | | ^after(%vals...), ^cont |
56+
/// | +--------------------------------+
57+
/// | | |
58+
/// | | -------------|
59+
/// | v |
60+
/// | +--------------------------------+ |
61+
/// | | ^after(%aargs...): | |
62+
/// | | <body contents> | |
63+
/// | +--------------------------------+ |
64+
/// | | |
65+
/// | ... |
66+
/// | | |
67+
/// | +--------------------------------+ |
68+
/// | | ^after-last: | |
69+
/// | | %yields... = <some payload> | |
70+
/// | | cf.br ^before(%yields...) | |
71+
/// | +--------------------------------+ |
72+
/// | | |
73+
/// |----------- |--------------------
74+
/// v
75+
/// +--------------------------------+
76+
/// | ^cont: |
77+
/// | <code after the WhileOp> |
78+
/// | <%vals from 'before' region |
79+
/// | visible by dominance> |
80+
/// +--------------------------------+
81+
///
82+
/// Values are communicated between ex-regions (the groups of blocks that used
83+
/// to form a region before inlining) through block arguments of their
84+
/// entry blocks, which are visible in all other dominated blocks. Similarly,
85+
/// the results of the WhileOp are defined in the 'before' region, which is
86+
/// required to have a single exiting block, and are therefore accessible in
87+
/// the continuation block due to dominance.
88+
struct WhileLowering : public OpRewritePattern<WhileOp> {
89+
using OpRewritePattern<WhileOp>::OpRewritePattern;
90+
91+
LogicalResult matchAndRewrite(WhileOp whileOp,
92+
PatternRewriter &rewriter) const override;
93+
};
94+
95+
/// Optimized version of the above for the case of the "after" region merely
96+
/// forwarding its arguments back to the "before" region (i.e., a "do-while"
97+
/// loop). This avoids inlining the "after" region completely and branches back
98+
/// to the "before" entry instead.
99+
struct DoWhileLowering : public OpRewritePattern<WhileOp> {
100+
using OpRewritePattern<WhileOp>::OpRewritePattern;
101+
102+
LogicalResult matchAndRewrite(WhileOp whileOp,
103+
PatternRewriter &rewriter) const override;
104+
};
105+
} // namespace
106+
107+
/// Lowers a general scf.while to a CFG.
///
/// The block containing the loop is split at the rewriter's insertion point;
/// the "before" and "after" regions are inlined between the two halves and
/// their terminators are replaced by branches:
///   - scf.condition becomes a cf.cond_br to the "after" entry block or to the
///     continuation block;
///   - scf.yield becomes the back-edge cf.br to the "before" entry block.
/// Attributes of the WhileOp that belong to the LLVM dialect (for example
/// llvm.loop_annotation) are copied onto that back-edge branch. The WhileOp is
/// finally replaced by the values forwarded by scf.condition.
///
/// Always returns success().
LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
                                             PatternRewriter &rewriter) const {
  OpBuilder::InsertionGuard guard(rewriter);
  Location loc = whileOp.getLoc();

  // Split the current block at the insertion point to create the inlining
  // point; everything after it ends up in the continuation block.
  Block *currentBlock = rewriter.getInsertionBlock();
  Block *continuation =
      rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());

  // Remember the entry and the last block of each region before inlining
  // empties the regions. The terminators live in the last blocks, which differ
  // from the entry blocks as soon as a region has more than one block.
  Block *after = whileOp.getAfterBody();
  Block *afterLast = &whileOp.getAfter().back();
  Block *before = whileOp.getBeforeBody();
  Block *beforeLast = &whileOp.getBefore().back();
  rewriter.inlineRegionBefore(whileOp.getAfter(), continuation);
  rewriter.inlineRegionBefore(whileOp.getBefore(), after);

  // Branch to the "before" region.
  rewriter.setInsertionPointToEnd(currentBlock);
  rewriter.create<cf::BranchOp>(loc, before, whileOp.getInits());

  // Replace terminators with branches. Assuming bodies are SESE, which holds
  // given only the patterns from this file, we only need to look at the last
  // block. This should be reconsidered if we allow break/continue in SCF.
  rewriter.setInsertionPointToEnd(beforeLast);
  auto condOp = cast<ConditionOp>(beforeLast->getTerminator());
  // Copy the forwarded values now: condOp is replaced right below, and these
  // values are what the WhileOp results are replaced with at the end.
  SmallVector<Value> args = llvm::to_vector(condOp.getArgs());
  rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(),
                                                after, condOp.getArgs(),
                                                continuation, ValueRange());

  rewriter.setInsertionPointToEnd(afterLast);
  auto yieldOp = cast<scf::YieldOp>(afterLast->getTerminator());
  auto latch = rewriter.replaceOpWithNewOp<cf::BranchOp>(yieldOp, before,
                                                         yieldOp.getResults());

  // Let the back-edge branch (the loop latch) carry the LLVM attributes from
  // the WhileOp, such as the llvm.loop_annotation attribute.
  SmallVector<NamedAttribute> llvmAttrs;
  llvm::copy_if(whileOp->getAttrs(), std::back_inserter(llvmAttrs),
                [](auto attr) {
                  return isa<LLVM::LLVMDialect>(attr.getValue().getDialect());
                });
  latch->setDiscardableAttrs(llvmAttrs);

  // Replace the op with values "yielded" from the "before" region, which are
  // visible by dominance.
  rewriter.replaceOp(whileOp, args);

  return success();
}
156+
157+
/// Lowers an scf.while whose "after" region only forwards its arguments (a
/// "do-while" loop) to a CFG without inlining the "after" region.
///
/// Fails to match when the "after" block contains anything besides a single
/// scf.yield that returns exactly the block arguments. Otherwise the "before"
/// region is inlined in front of a new continuation block, and scf.condition
/// becomes a cf.cond_br that either loops back to the "before" entry block or
/// leaves to the continuation block. Attributes of the WhileOp that belong to
/// the LLVM dialect (for example llvm.loop_annotation) are copied onto that
/// conditional branch, which is the back edge of the loop.
LogicalResult
DoWhileLowering::matchAndRewrite(WhileOp whileOp,
                                 PatternRewriter &rewriter) const {
  Block &afterBlock = *whileOp.getAfterBody();
  if (!llvm::hasSingleElement(afterBlock))
    return rewriter.notifyMatchFailure(whileOp,
                                       "do-while simplification applicable "
                                       "only if 'after' region has no payload");

  auto yield = dyn_cast<scf::YieldOp>(&afterBlock.front());
  if (!yield || yield.getResults() != afterBlock.getArguments())
    return rewriter.notifyMatchFailure(whileOp,
                                       "do-while simplification applicable "
                                       "only to forwarding 'after' regions");

  // Split the current block at the insertion point to create the inlining
  // point; everything after it ends up in the continuation block.
  OpBuilder::InsertionGuard guard(rewriter);
  Block *currentBlock = rewriter.getInsertionBlock();
  Block *continuation =
      rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());

  // Only the "before" region should be inlined. Remember its entry and last
  // block first: inlining empties the region, and the scf.condition terminator
  // lives in the last block, which differs from the entry block when the
  // region has more than one block.
  Block *before = whileOp.getBeforeBody();
  Block *beforeLast = &whileOp.getBefore().back();
  rewriter.inlineRegionBefore(whileOp.getBefore(), continuation);

  // Branch to the "before" region. This is the loop entry edge, not the latch,
  // so it must not carry the loop attributes.
  rewriter.setInsertionPointToEnd(currentBlock);
  rewriter.create<cf::BranchOp>(whileOp.getLoc(), before, whileOp.getInits());

  // Loop around the "before" region based on condition.
  rewriter.setInsertionPointToEnd(beforeLast);
  auto condOp = cast<ConditionOp>(beforeLast->getTerminator());
  // Copy the forwarded values now: condOp is replaced right below, and these
  // values are what the WhileOp results are replaced with at the end.
  SmallVector<Value> args = llvm::to_vector(condOp.getArgs());
  auto latch = rewriter.replaceOpWithNewOp<cf::CondBranchOp>(
      condOp, condOp.getCondition(), before, condOp.getArgs(), continuation,
      ValueRange());

  // Let the CondBranchOp, which forms the back edge of the loop, carry the
  // LLVM attributes from the WhileOp, such as the llvm.loop_annotation
  // attribute. LLVM associates loop metadata with the branch terminating the
  // latch block, so the annotation has to sit on the back edge.
  SmallVector<NamedAttribute> llvmAttrs;
  llvm::copy_if(whileOp->getAttrs(), std::back_inserter(llvmAttrs),
                [](auto attr) {
                  return isa<LLVM::LLVMDialect>(attr.getValue().getDialect());
                });
  latch->setDiscardableAttrs(llvmAttrs);

  // Replace the op with values "yielded" from the "before" region, which are
  // visible by dominance.
  rewriter.replaceOp(whileOp, args);

  return success();
}
210+
211+
void SCFToCFPass::runOnOperation() {
212+
RewritePatternSet patterns(&getContext());
213+
// Give our patterns higher benefits so that they get picked up instead of the
214+
// MLIR one.
215+
patterns.add<WhileLowering>(&getContext(), /*benefit=*/3);
216+
patterns.add<DoWhileLowering>(&getContext(), /*benefit=*/4);
217+
mlir::populateSCFToControlFlowConversionPatterns(patterns);
218+
219+
// Configure conversion to lower out SCF operations.
220+
ConversionTarget target(getContext());
221+
target.addIllegalOp<scf::ForallOp, scf::ForOp, scf::IfOp, scf::IndexSwitchOp,
222+
scf::ParallelOp, scf::WhileOp, scf::ExecuteRegionOp>();
223+
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
224+
if (failed(
225+
applyPartialConversion(getOperation(), target, std::move(patterns))))
226+
signalPassFailure();
227+
}
228+
229+
namespace mlir::triton {
230+
// Factory for the pass: returns a newly allocated SCFToCFPass instance owned
// by the returned unique_ptr.
std::unique_ptr<mlir::Pass> createTritonSCFToCF() {
231+
return std::make_unique<SCFToCFPass>();
232+
}
233+
} // namespace mlir::triton

python/src/ir.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "mlir/Bytecode/BytecodeWriter.h"
1010
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
1111
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
12+
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
1213
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1314
#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
1415
#include "mlir/Dialect/UB/IR/UBOps.h"
@@ -772,6 +773,18 @@ void init_triton_ir(py::module &&m) {
772773
[](TritonOpBuilder &self, std::string value) -> Attribute {
773774
return self.getBuilder().getStringAttr(value);
774775
})
776+
.def("get_disable_loop_licm_attr",
777+
[](TritonOpBuilder &self) -> Attribute {
778+
auto licmAttr =
779+
LLVM::LoopLICMAttr::get(self.getBuilder().getContext(),
780+
self.getBuilder().getBoolAttr(true),
781+
self.getBuilder().getBoolAttr(true));
782+
mlir::LLVM::LoopAnnotationAttr la =
783+
mlir::LLVM::LoopAnnotationAttr::get(
784+
self.getBuilder().getContext(), {}, {}, {}, {}, {},
785+
licmAttr, {}, {}, {}, {}, {}, {}, {}, {}, {});
786+
return la;
787+
})
775788
// Use arith.ConstantOp to create constants
776789
// Constants
777790
.def("get_int1",

python/src/passes.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ void init_triton_passes_ttgpuir(py::module &&m) {
9696

9797
void init_triton_passes_convert(py::module &&m) {
9898
using namespace mlir;
99+
ADD_PASS_WRAPPER_0("add_triton_scf_to_cf", mlir::triton::createTritonSCFToCF);
99100
ADD_PASS_WRAPPER_0("add_scf_to_cf", createSCFToControlFlowPass);
100101
ADD_PASS_WRAPPER_0("add_cf_to_llvmir", createConvertControlFlowToLLVMPass);
101102
ADD_PASS_WRAPPER_0("add_index_to_llvmir", createConvertIndexToLLVMPass);

python/test/unit/language/test_core.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7379,6 +7379,37 @@ def kernel(ub):
73797379
assert "loop_unroll_factor" not in compiled_kernel.asm["ttir"]
73807380

73817381

7382+
# Checks that the string "llvm.licm.disable" shows up in the generated LLVM IR
# ("llir") for kernels whose loop is annotated with disable_licm=True (both the
# tl.condition while-loop and the tl.range for-loop variant), and that it does
# not show up for a while-loop using tl.condition without that option.
def test_disable_licm():
7383+
7384+
# While-loop whose condition is wrapped in tl.condition with LICM disabled.
@triton.jit
7385+
def while_no_licm(n):
7386+
i = 0
7387+
while tl.condition(i < n, disable_licm=True):
7388+
i = i + 1
7389+
print("i", i)
7390+
7391+
# Same loop with tl.condition but without the disable_licm option.
@triton.jit
7392+
def while_default(n):
7393+
i = 0
7394+
while tl.condition(i < n):
7395+
i = i + 1
7396+
print("i", i)
7397+
7398+
# For-loop variant using tl.range with LICM disabled.
@triton.jit
7399+
def for_no_licm(n):
7400+
for i in tl.range(0, n, disable_licm=True):
7401+
print("i", i)
7402+
7403+
# Each kernel is compiled via warmup and its "llir" text is inspected.
compiled_kernel1 = while_no_licm.warmup(10, grid=(1, ))
7404+
assert "llvm.licm.disable" in compiled_kernel1.asm["llir"]
7405+
7406+
compiled_kernel2 = while_default.warmup(10, grid=(1, ))
7407+
assert "llvm.licm.disable" not in compiled_kernel2.asm["llir"]
7408+
7409+
compiled_kernel3 = for_no_licm.warmup(10, grid=(1, ))
7410+
assert "llvm.licm.disable" in compiled_kernel3.asm["llir"]
7411+
7412+
73827413
@triton.jit(noinline=True)
73837414
def maxnreg_noinline1(X):
73847415
tl.store(X, 0)

0 commit comments

Comments
 (0)