Skip to content

Commit 2c57e20

Browse files
authored
[TritonGPU] Add Loop Aware CSE pass (#6809)
I thought LLVM would be able to do this, but apparently not. This recursively analyzes the computations of loop iteration arguments to check if two iter args always have the same value and replaces one with the other. This reduces the register usage of pipelined/warp specialized loops by crushing the number of phase and index arguments. Sometimes up to 10 registers can be saved by this, which can be significant in high-pressure areas. This replaces uses of the normal CSE pass in make_ttgir and adds canonicalize+loop_aware_cse after the pipeliner. TODO: - [x] write tests - [x] manually check GB200 pytests
1 parent 915cc70 commit 2c57e20

File tree

6 files changed

+245
-3
lines changed

6 files changed

+245
-3
lines changed

include/triton/Dialect/Triton/Transforms/Passes.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,4 +79,15 @@ def TritonLoopInvariantCodeMotion : Pass</*cli-arg*/"triton-licm", /*Op*/"mlir::
7979
let dependentDialects = ["mlir::triton::TritonDialect"];
8080
}
8181

82+
def TritonLoopAwareCSE : Pass<"triton-loop-aware-cse", "mlir::ModuleOp"> {
83+
let summary = "CSE within loop bodies";
84+
85+
let description = [{
86+
The `triton-loop-aware-cse` pass performs recursive common subexpression
87+
elimination within loop bodies. Unlike regular CSE, which is a single-pass
88+
greedy algorithm, this pass can recursively eliminate loop iteration
89+
arguments and subcomputations that always have the same value.
90+
}];
91+
}
92+
8293
#endif

lib/Dialect/Triton/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ add_public_tablegen_target(TritonCombineIncGen)
44

55
add_triton_library(TritonTransforms
66
Combine.cpp
7+
LoopAwareCSE.cpp
78
LoopInvariantCodeMotion.cpp
89
LoopUnroll.cpp
910
ReorderBroadcast.cpp
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
#include "mlir/Dialect/SCF/IR/SCF.h"
2+
#include "mlir/IR/Dominance.h"
3+
#include "mlir/Pass/Pass.h"
4+
#include "mlir/Transforms/CSE.h"
5+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
6+
#include "llvm/ADT/EquivalenceClasses.h"
7+
8+
using namespace mlir;
9+
10+
namespace mlir::triton {
11+
#define GEN_PASS_DEF_TRITONLOOPAWARECSE
12+
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"
13+
} // namespace mlir::triton
14+
15+
namespace {
16+
class ValueEquivalence {
17+
public:
18+
std::optional<bool> getKnownEquivalence(Value a, Value b) {
19+
if (auto it = equalValues.find(normalizeKey(a, b)); it != equalValues.end())
20+
return it->second;
21+
return std::nullopt;
22+
}
23+
void setKnownEquivalence(Value a, Value b, bool eq) {
24+
equalValues.insert_or_assign(normalizeKey(a, b), eq);
25+
}
26+
27+
private:
28+
// Commutatively query the equivalence of two values by sorting the key by
29+
// pointer value.
30+
std::pair<Value, Value> normalizeKey(Value a, Value b) {
31+
if ((uintptr_t)a.getAsOpaquePointer() < (uintptr_t)b.getAsOpaquePointer())
32+
return {a, b};
33+
return {b, a};
34+
}
35+
36+
DenseMap<std::pair<Value, Value>, bool> equalValues;
37+
};
38+
39+
struct LoopCSEDriver {
40+
LoopCSEDriver(scf::ForOp loop) : loop(loop) {}
41+
42+
bool areIterArgsEqual(int i, int j);
43+
bool areEqualInLoop(Value a, Value b);
44+
45+
scf::ForOp loop;
46+
ValueEquivalence equalValues;
47+
};
48+
} // namespace
49+
50+
bool LoopCSEDriver::areIterArgsEqual(int i, int j) {
51+
if (i == j)
52+
return true;
53+
if (loop.getInitArgs()[i] != loop.getInitArgs()[j])
54+
return false;
55+
BlockArgument aArg = loop.getRegionIterArg(i);
56+
BlockArgument bArg = loop.getRegionIterArg(j);
57+
// First, assume the arguments are equal. This is how recursion is broken.
58+
equalValues.setKnownEquivalence(aArg, bArg, true);
59+
bool result =
60+
areEqualInLoop(loop.getYieldedValues()[i], loop.getYieldedValues()[j]);
61+
// Now update the equivalence based on the actual result.
62+
equalValues.setKnownEquivalence(aArg, bArg, result);
63+
return result;
64+
}
65+
66+
bool LoopCSEDriver::areEqualInLoop(Value a, Value b) {
67+
// Check trivial case.
68+
if (a == b)
69+
return true;
70+
if (a.getType() != b.getType())
71+
return false;
72+
73+
Block *aBlock = a.getParentBlock();
74+
Block *bBlock = b.getParentBlock();
75+
// Values from outside the loop must have been equal.
76+
if (aBlock != loop.getBody() || bBlock != loop.getBody()) {
77+
return false;
78+
}
79+
// Both must be block arguments or not.
80+
if (isa<BlockArgument>(a) != isa<BlockArgument>(b))
81+
return false;
82+
// Both must be the inductor var or not.
83+
if (a == loop.getInductionVar() || b == loop.getInductionVar())
84+
return false;
85+
86+
if (std::optional<bool> eq = equalValues.getKnownEquivalence(a, b))
87+
return *eq;
88+
89+
if (auto aArg = dyn_cast<BlockArgument>(a)) {
90+
auto bArg = cast<BlockArgument>(b);
91+
bool result =
92+
areIterArgsEqual(aArg.getArgNumber() - 1, bArg.getArgNumber() - 1);
93+
equalValues.setKnownEquivalence(a, b, result);
94+
return result;
95+
}
96+
97+
Operation *aDef = a.getDefiningOp();
98+
Operation *bDef = b.getDefiningOp();
99+
// For it to be known that the operation results have the same value, they
100+
// must be side effect free.
101+
if (!isMemoryEffectFree(aDef) || !isMemoryEffectFree(bDef))
102+
return false;
103+
// Don't bother with operations with regions.
104+
if (aDef->getNumRegions() || bDef->getNumRegions())
105+
return false;
106+
107+
bool result = OperationEquivalence::isEquivalentTo(
108+
aDef, bDef,
109+
[&](Value a, Value b) { return success(areEqualInLoop(a, b)); },
110+
[&](Value a, Value b) { equalValues.setKnownEquivalence(a, b, true); },
111+
OperationEquivalence::IgnoreLocations);
112+
equalValues.setKnownEquivalence(a, b, result);
113+
return result;
114+
}
115+
116+
static void loopCSE(scf::ForOp loop) {
117+
int numIterArgs = loop.getNumRegionIterArgs();
118+
// Group equivalent iter args together.
119+
llvm::EquivalenceClasses<int> equivalentArgs;
120+
LoopCSEDriver driver(loop);
121+
for (int i = 0; i != numIterArgs; ++i) {
122+
for (int j = i + 1; j != numIterArgs; ++j) {
123+
if (driver.areIterArgsEqual(i, j))
124+
equivalentArgs.unionSets(i, j);
125+
}
126+
}
127+
128+
// For each equivalence class, replace all other args in the class with one.
129+
for (auto it = equivalentArgs.begin(), end = equivalentArgs.end(); it != end;
130+
++it) {
131+
if (!(*it)->isLeader())
132+
continue;
133+
SmallVector<int> eqArgs;
134+
for (auto mIt = equivalentArgs.member_begin(**it);
135+
mIt != equivalentArgs.member_end(); ++mIt)
136+
eqArgs.push_back(*mIt);
137+
assert(eqArgs.size() > 1);
138+
// Sort the indices so the pass is deterministic.
139+
llvm::sort(eqArgs);
140+
BlockArgument unique = loop.getRegionIterArg(eqArgs.front());
141+
Value uniqueResult = loop.getResult(eqArgs.front());
142+
for (int j : llvm::drop_begin(eqArgs)) {
143+
BlockArgument other = loop.getRegionIterArg(j);
144+
other.replaceAllUsesWith(unique);
145+
// Short-circuit the value. The canonicalizer will clean this up. Leftover
146+
// subcomputations can now be removed by normal CSE.
147+
(*loop.getYieldedValuesMutable())[j].set(other);
148+
loop.getResult(j).replaceAllUsesWith(uniqueResult);
149+
}
150+
}
151+
}
152+
153+
namespace {
154+
struct LoopAwareCSE
155+
: public triton::impl::TritonLoopAwareCSEBase<LoopAwareCSE> {
156+
using TritonLoopAwareCSEBase::TritonLoopAwareCSEBase;
157+
158+
void runOnOperation() override {
159+
// LoopAwareCSE doesn't recursively CSE ops outside of loops, so run CSE
160+
// first to make sure values from outside loops that are equivalent are made
161+
// pointer equal.
162+
IRRewriter rewriter(&getContext());
163+
auto &domInfo = getAnalysis<DominanceInfo>();
164+
eliminateCommonSubExpressions(rewriter, domInfo, getOperation());
165+
166+
// CSE region iter args within loop bodies.
167+
getOperation().walk(loopCSE);
168+
169+
// Now that equivalent iter args have been made pointer equal, run CSE again
170+
// to clean up the loop body.
171+
eliminateCommonSubExpressions(rewriter, domInfo, getOperation());
172+
173+
// Run the `scf.for` canonicalizer to clean up the loops (short-circuited
174+
// values, unused results, etc.).
175+
RewritePatternSet patterns(&getContext());
176+
scf::ForOp::getCanonicalizationPatterns(patterns, &getContext());
177+
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
178+
return signalPassFailure();
179+
}
180+
};
181+
} // namespace

python/src/passes.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ void init_triton_passes_ttir(py::module &&m) {
4444
createTritonRewriteTensorDescriptorToPointer);
4545
ADD_PASS_WRAPPER_0("add_loop_unroll", createTritonLoopUnroll);
4646
ADD_PASS_WRAPPER_0("add_triton_licm", createTritonLoopInvariantCodeMotion);
47+
ADD_PASS_WRAPPER_0("add_loop_aware_cse", createTritonLoopAwareCSE);
4748
ADD_PASS_OPTION_WRAPPER_4("add_convert_to_ttgpuir",
4849
createConvertTritonToTritonGPU, const std::string &,
4950
int, int, int);

test/Triton/loop_cse.mlir

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// RUN: triton-opt %s -triton-loop-aware-cse -allow-unregistered-dialect | FileCheck %s
2+
3+
// CHECK-LABEL: @loop_buffer_phase_args
4+
tt.func @loop_buffer_phase_args(%arg0: i32) {
5+
%c2_i32 = arith.constant 2 : i32
6+
%c128_i32 = arith.constant 128 : i32
7+
%c0_i32 = arith.constant 0 : i32
8+
%c1_i32 = arith.constant 1 : i32
9+
// CHECK: [[LOOP_RES:%.*]]:3 = scf.for {{.*}} iter_args
10+
// CHECK-SAME: [[M2_INDEX:%arg[0-9]+]] = %c0_i32
11+
// CHECK-SAME: [[M2_PHASE:%arg[0-9]+]] = %c0_i32
12+
// CHECK-SAME: [[M1_PHASE:%arg[0-9]+]] = %c0_i32
13+
%0:10 = scf.for %arg1 = %c0_i32 to %arg0 step %c128_i32 iter_args(%arg2 = %c0_i32, %arg3 = %c0_i32, %arg4 = %c0_i32, %arg5 = %c0_i32, %arg6 = %c0_i32, %arg7 = %c0_i32, %arg8 = %c0_i32, %arg9 = %c0_i32, %arg10 = %c0_i32, %arg11 = %c0_i32) -> (i32, i32, i32, i32, i32, i32, i32, i32, i32, i32) : i32 {
14+
%1 = arith.subi %arg0, %c128_i32 : i32
15+
%2 = arith.cmpi slt, %arg1, %1 : i32
16+
// CHECK: [[M1_PHASE_INCR:%.*]] = arith.xori [[M1_PHASE]], %c1_i32
17+
%3 = arith.xori %arg7, %c1_i32 : i32
18+
// CHECK: "index_phase_use"([[M2_INDEX]], [[M2_PHASE]], [[M1_PHASE_INCR]], [[M1_PHASE]])
19+
"index_phase_use"(%arg4, %arg5, %3, %arg8) : (i32, i32, i32, i32) -> ()
20+
%4 = arith.addi %arg4, %c1_i32 : i32
21+
%5 = arith.xori %arg5, %c1_i32 : i32
22+
%6 = arith.cmpi eq, %4, %c2_i32 : i32
23+
// CHECK: [[M2_INDEX_INCR:%.*]] = arith.select %{{.*}}, %c0_i32
24+
// CHECK-NEXT: [[M2_PHASE_INCR:%.*]] = arith.select %{{.*}}, %{{.*}}, [[M2_PHASE]]
25+
// CHECK-NOT: arith.select
26+
%7 = arith.select %6, %c0_i32, %4 : i32
27+
%8 = arith.select %6, %5, %arg5 : i32
28+
%9 = arith.xori %arg8, %c1_i32 : i32
29+
%10 = arith.xori %arg11, %c1_i32 : i32
30+
%11 = arith.xori %arg6, %c1_i32 : i32
31+
%12 = arith.addi %arg2, %c1_i32 : i32
32+
%13 = arith.xori %arg3, %c1_i32 : i32
33+
%14 = arith.cmpi eq, %12, %c2_i32 : i32
34+
%15 = arith.select %14, %c0_i32, %12 : i32
35+
%16 = arith.select %14, %13, %arg3 : i32
36+
// CHECK: "index_phase_use"([[M2_INDEX_INCR]], [[M2_PHASE_INCR]], [[M1_PHASE_INCR]],
37+
"index_phase_use"(%15, %16, %11, %2) : (i32, i32, i32, i1) -> ()
38+
%17 = arith.xori %arg10, %c1_i32 : i32
39+
// CHECK: "index_phase_use"([[M1_PHASE_INCR]], [[M1_PHASE]])
40+
"index_phase_use"(%17, %arg11) : (i32, i32) -> ()
41+
%18 = arith.xori %arg9, %c1_i32 : i32
42+
// CHECK: "index_phase_use"([[M1_PHASE_INCR]], [[M1_PHASE]])
43+
"index_phase_use"(%17, %arg11) : (i32, i32) -> ()
44+
scf.yield %15, %16, %7, %8, %11, %3, %9, %18, %17, %10 : i32, i32, i32, i32, i32, i32, i32, i32, i32, i32
45+
}
46+
tt.return
47+
}

third_party/nvidia/backend/compiler.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def make_ttgir(mod, metadata, opt, capability):
246246
passes.ttgpuir.add_remove_layout_conversions(pm)
247247
passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
248248
nvidia.passes.ttnvgpuir.add_optimize_descriptor_encoding(pm)
249-
passes.common.add_cse(pm)
249+
passes.ttir.add_loop_aware_cse(pm)
250250
if capability // 10 in [8, 9]:
251251
passes.ttgpuir.add_fuse_nested_loops(pm)
252252
passes.common.add_canonicalizer(pm)
@@ -265,9 +265,10 @@ def make_ttgir(mod, metadata, opt, capability):
265265
passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
266266
passes.ttgpuir.add_combine_tensor_select_and_if(pm)
267267
nvidia.passes.ttnvgpuir.add_remove_tmem_tokens(pm)
268-
passes.common.add_canonicalizer(pm)
269268
else:
270269
passes.ttir.add_triton_licm(pm)
270+
passes.common.add_canonicalizer(pm)
271+
passes.ttir.add_loop_aware_cse(pm)
271272
passes.ttgpuir.add_prefetch(pm)
272273
passes.ttgpuir.add_WGMMAPrefetch(pm)
273274
passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
@@ -277,7 +278,7 @@ def make_ttgir(mod, metadata, opt, capability):
277278
nvidia.passes.ttnvgpuir.add_interleave_tmem(pm)
278279
passes.ttgpuir.add_reduce_data_duplication(pm)
279280
passes.ttgpuir.add_reorder_instructions(pm)
280-
passes.common.add_cse(pm)
281+
passes.ttir.add_loop_aware_cse(pm)
281282
passes.common.add_symbol_dce(pm)
282283
if capability // 10 >= 9:
283284
nvidia.passes.ttnvgpuir.add_tma_lowering(pm)

0 commit comments

Comments
 (0)