Commit 2ec711b
[Blackwell] Subtile TMEM stores and improve TMEM interleaving (#6808)

This PR introduces a trick similar to `split(tmem_load)` for `tmem_store(join)`, propagating the store through the join. It also splits the TMEM load sinking out into a separate pass and makes it more powerful:

* It has more nuanced alias analysis, which allows interleaving RMW of values in TMEM.
* It will try to iteratively sink multiple pure user ops.

Having a separate pass also makes it compose better without remove-layout-conversions. This reduces register pressure a lot in certain cases, because ptxas is quite ineffective at sinking LDTM ops.
1 parent 99b5e29 commit 2ec711b

12 files changed (+553, −151 lines)


include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 7 additions & 0 deletions
@@ -29,6 +29,7 @@ include "triton/Dialect/Triton/IR/TritonTypes.td"
 include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
 include "triton/Dialect/Triton/IR/TritonInterfaces.td"
 include "triton/Dialect/Triton/IR/TritonOpInterfaces.td"
+include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td"
 include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td"
 include "triton/Dialect/TritonGPU/IR/TritonGPUTypeInterfaces.td"
 include "mlir/IR/OpBase.td"
@@ -584,6 +585,12 @@ def TTNG_TMEMStoreOp : TTNG_Op<"tmem_store"> {
   );
   let results = (outs Optional<TTG_AsyncToken>:$token);

+  let builders = [
+    OpBuilder<(ins "Value":$dst, "Value":$src, "Value":$pred), [{
+      build($_builder, $_state, Type(), dst, Value(), src, pred);
+    }]>
+  ];
+
   let assemblyFormat = [{
     $src `,` $dst `` custom<Token>($dep, type($token)) `,` $pred
     attr-dict `:` type($src) `->` qualified(type($dst))
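
The new builder is a convenience overload for creating a store with no token result and no input dependency. A minimal usage sketch (not part of the diff; it assumes an OpBuilder `b`, a Location `loc`, and `dst`/`src`/`pred` Values in scope, plus the `ttng` namespace alias used elsewhere in this commit):

  // Creates a predicated ttng.tmem_store from src into dst, with the token
  // result and dep operand left empty, as in the builder body above.
  b.create<ttng::TMEMStoreOp>(loc, /*dst=*/dst, /*src=*/src, /*pred=*/pred);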

include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h

Lines changed: 2 additions & 0 deletions
@@ -64,6 +64,8 @@ std::unique_ptr<Pass> createTritonNvidiaGPUOptimizeDescriptorEncodingPass();

 std::unique_ptr<Pass> createTritonNvidiaGPUOptimizeTMemLayoutsPass();

+std::unique_ptr<Pass> createTritonNvidiaGPUInterleaveTMemPass();
+
 /// Generate the code for registering passes.
 #define GEN_PASS_REGISTRATION
 #define GEN_PASS_DECL_TRITONNVIDIAGPULEGALIZETMALAYOUTS

include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td

Lines changed: 10 additions & 0 deletions
@@ -143,6 +143,16 @@ def TritonNvidiaGPUOptimizeTMemLayoutsPass : Pass<"triton-nvidia-optimize-tmem-l
                          "mlir::triton::TritonDialect"];
 }

+def TritonNvidiaGPUInterleaveTMemPass : Pass<"triton-nvidia-interleave-tmem", "mlir::ModuleOp"> {
+  let summary = "Interleave TMEM loads/stores.";
+
+  let description = [{
+    The `triton-nvidia-interleave-tmem` pass attempts to sink TMEM loads and
+    hoist TMEM stores, and potentially interleave them, to reduce register
+    pressure.
+  }];
+}
+
 def TritonNvidiaGPURemoveTMEMTokensPass : Pass<"triton-nvidia-gpu-remove-tmem-tokens", "mlir::ModuleOp"> {
   let summary = "remove TMEM tokens";

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 9 additions & 1 deletion
@@ -445,9 +445,17 @@ MemDescTransOp::inferReturnTypes(MLIRContext *context,
 // MemDescReshapeOp

 LogicalResult MemDescReshapeOp::verify() {
-  // Infer the dst layout from the source and verify that it is equivalent.
   MemDescType dstType = getResult().getType();
   MemDescType srcType = getSrc().getType();
+  if (product(dstType.getShape()) != product(srcType.getShape())) {
+    return emitError(
+        "number of src and dst elements of reshape must be the same");
+  }
+  if (dstType.getElementType() != srcType.getElementType()) {
+    return emitError("result element type must match src element type");
+  }
+
+  // Infer the dst layout from the source and verify that it is equivalent.
   auto srcEncoding = srcType.getEncoding();
   Attribute inferedDstEncoding;

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp

Lines changed: 19 additions & 3 deletions
@@ -244,8 +244,15 @@ static void printToken(OpAsmPrinter &p, Operation *op, Value dep, Type token) {
 void TCGen5MMAOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
+  // The op reads the accumulator if `useD` is not known to be false.
+  APInt useD;
+  if (!matchPattern(getUseD(), m_ConstantInt(&useD)) || !useD.isZero()) {
+    effects.emplace_back(MemoryEffects::Read::get(), &getDMutable(),
+                         TensorMemory::get());
+  }
   effects.emplace_back(MemoryEffects::Write::get(), &getDMutable(),
                        TensorMemory::get());
+
   if (isa<SharedMemorySpaceAttr>(getA().getType().getMemorySpace())) {
     effects.emplace_back(MemoryEffects::Read::get(), &getAMutable(),
                          SharedMemory::get());
@@ -296,8 +303,15 @@ void TCGen5MMAOp::build(OpBuilder &builder, OperationState &state, Type token,
 void TCGen5MMAScaledOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
+  // The op reads the accumulator if `useD` is not known to be false.
+  APInt useD;
+  if (!matchPattern(getUseD(), m_ConstantInt(&useD)) || !useD.isZero()) {
+    effects.emplace_back(MemoryEffects::Read::get(), &getDMutable(),
+                         TensorMemory::get());
+  }
   effects.emplace_back(MemoryEffects::Write::get(), &getDMutable(),
                        TensorMemory::get());
+
   if (isa<SharedMemorySpaceAttr>(getA().getType().getMemorySpace())) {
     effects.emplace_back(MemoryEffects::Read::get(), &getAMutable(),
                          SharedMemory::get());
@@ -488,10 +502,12 @@ void TMEMAllocOp::getEffects(
   // op.
   if (!getType().getMutableMemory() && !op->hasAttr("tensor_memory_col_offset"))
     return;
-  effects.emplace_back(MemoryEffects::Allocate::get(), TensorMemory::get());
+  OpResult alloc = getOperation()->getOpResult(0);
+  effects.emplace_back(MemoryEffects::Allocate::get(), alloc,
+                       TensorMemory::get());
   if (getSrc())
-    effects.emplace_back(MemoryEffects::Write::get(),
-                         getOperation()->getOpResult(0), TensorMemory::get());
+    effects.emplace_back(MemoryEffects::Write::get(), alloc,
+                         TensorMemory::get());
 }

 // -- TMEMCopyOp --
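
The conditional read effect above uses a standard MLIR matcher idiom: `matchPattern` with `m_ConstantInt` succeeds only when the predicate folds to a constant, so an unknown `useD` is conservatively treated as "may read". A standalone sketch of the same check (illustrative only; the helper name is hypothetical and not part of the diff):

  #include "mlir/IR/Matchers.h"

  // Returns false only when `useD` is statically known to be false; any other
  // value, including a non-constant, conservatively counts as a read of D.
  static bool mayReadAccumulator(mlir::Value useD) {
    llvm::APInt value;
    if (!mlir::matchPattern(useD, mlir::m_ConstantInt(&value)))
      return true;
    return !value.isZero();
  }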

lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -1,5 +1,6 @@
 add_triton_library(TritonNvidiaGPUTransforms
   FenceInsertion.cpp
+  InterleaveTMem.cpp
   MMALowering.cpp
   OptimizeDescriptorEncoding.cpp
   OptimizeTMemLayouts.cpp

lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp

Lines changed: 260 additions & 0 deletions

@@ -0,0 +1,260 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
#include "llvm/ADT/AddressRanges.h"

namespace {

using namespace mlir;

namespace ttng = triton::nvidia_gpu;
namespace ttg = triton::gpu;
namespace tt = triton;

#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"

// If we don't know the effects of the op, we add all possible effects.
void addAllValuelessEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  effects.emplace_back(MemoryEffects::Effect::get<MemoryEffects::Read>());
  effects.emplace_back(MemoryEffects::Effect::get<MemoryEffects::Write>());
  effects.emplace_back(MemoryEffects::Effect::get<MemoryEffects::Allocate>());
  effects.emplace_back(MemoryEffects::Effect::get<MemoryEffects::Free>());
}

bool collectEffects(Operation *op,
                    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // Collect effect instances of the operation. Note that the implementation of
  // getEffects erases all effect instances that have a type other than the
  // template parameter, so we collect them first in a local buffer and then
  // copy.
  if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
    SmallVector<MemoryEffects::EffectInstance> localEffects;
    iface.getEffects(localEffects);
    llvm::append_range(effects, localEffects);
    return true;
  }
  if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>()) {
    for (auto &region : op->getRegions()) {
      for (auto &block : region) {
        for (auto &innerOp : block)
          if (!collectEffects(&innerOp, effects))
            return false;
      }
    }
    return true;
  }

  // We need to be conservative here in case the op doesn't have the interface
  // and assume it can have any possible effect.
  addAllValuelessEffects(effects);
  return false;
}

struct AccessRange {
  SmallVector<std::optional<llvm::AddressRange>> ranges;
  unsigned rankOffset = 0;
};

// Simple local alias analysis that looks for a single underlying allocation
// and an access subrange.
std::pair<Value, AccessRange> findBufferAccess(Value a) {
  // Handle block arguments.
  if (auto arg = dyn_cast<BlockArgument>(a)) {
    Operation *parentOp = arg.getOwner()->getParentOp();

    // Look through `ttg.warp_specialize` explicit captures.
    if (auto wsOp = dyn_cast<ttg::WarpSpecializePartitionsOp>(parentOp)) {
      return findBufferAccess(
          wsOp.getParentOp().getExplicitCaptures()[arg.getArgNumber()]);
    }

    // Unknown block argument.
    return {};
  }

  Operation *defOp = a.getDefiningOp();
  // Accessing the alloc accesses the whole buffer.
  if (auto alloc = dyn_cast<ttng::TMEMAllocOp>(defOp)) {
    AccessRange access;
    for (uint64_t dim : alloc.getType().getShape())
      access.ranges.push_back({{0, dim}});
    return {a, std::move(access)};
  }

  // Trans and Reshape views don't change the access size.
  if (isa<ttg::MemDescTransOp, ttg::MemDescReshapeOp>(defOp)) {
    return findBufferAccess(defOp->getOperand(0));
  }

  // Subviews can reduce the access sizes.
  if (auto subview = dyn_cast<ttg::MemDescSubviewOp>(defOp)) {
    auto [alloc, parentAccess] = findBufferAccess(subview.getSrc());
    if (!alloc)
      return {};
    // Handle subview of a subview. The first `rankOffset` access sizes are
    // the same as in the parent access.
    AccessRange childAccess;
    for (auto i : llvm::seq(parentAccess.rankOffset))
      childAccess.ranges.push_back(parentAccess.ranges[i]);

    // The subview may have a smaller rank, in which case its access size is
    // just 1 for the higher dims.
    childAccess.rankOffset =
        subview.getSrc().getType().getRank() - subview.getType().getRank();
    for (auto [i, offset] : llvm::enumerate(subview.getOffsets())) {
      auto parentRange = parentAccess.ranges[i + parentAccess.rankOffset];
      if (!parentRange) {
        childAccess.ranges.push_back({});
        continue;
      }

      // If the offset is not known, then the entire dim may be accessed.
      APInt value;
      if (!matchPattern(offset, m_ConstantInt(&value))) {
        childAccess.ranges.push_back({});
        continue;
      }

      uint64_t accessStart = parentRange->start() + value.getSExtValue();
      uint64_t accessSize = 1;
      if (i >= childAccess.rankOffset)
        accessSize = subview.getType().getShape()[i - childAccess.rankOffset];
      childAccess.ranges.push_back({{accessStart, accessStart + accessSize}});
    }
    return {alloc, std::move(childAccess)};
  }

  // Subslice is a subview only on the N dimension.
  if (auto subslice = dyn_cast<ttng::TMEMSubSliceOp>(defOp)) {
    auto [alloc, parentAccess] = findBufferAccess(subslice.getSrc());
    if (!alloc)
      return {};
    if (!parentAccess.ranges[1])
      return {alloc, parentAccess};
    uint64_t mStart = parentAccess.ranges[1]->start() + subslice.getN();
    uint64_t mSize = subslice.getType().getShape()[1];
    AccessRange childAccess = parentAccess;
    childAccess.ranges[1] = {{mStart, mStart + mSize}};
    return {alloc, std::move(childAccess)};
  }

  // Unknown defining op.
  return {};
}

bool tmemMayAlias(Value a, Value b) {
  auto [aAlloc, aRanges] = findBufferAccess(a);
  auto [bAlloc, bRanges] = findBufferAccess(b);
  // If the underlying buffer was not identified, assume they may alias.
  if (!aAlloc || !bAlloc)
    return true;
  // If the buffers are different, they don't alias.
  if (aAlloc != bAlloc)
    return false;
  // If the access ranges along any dimension are known to not overlap, then
  // the accesses don't alias.
  for (auto [aRange, bRange] : llvm::zip(aRanges.ranges, bRanges.ranges)) {
    // If either access range at this dim is unknown, we can't determine if
    // they don't overlap.
    if (!aRange || !bRange)
      continue;
    // The access ranges are known and don't overlap.
    if (!aRange->intersects(*bRange))
      return false;
  }
  return true;
}

// Sink tmem_loads as close to their use as possible to reduce register
// pressure.
bool sinkOps(Value buffer, ArrayRef<Operation *> useChain) {
  Operation *insertBefore = nullptr;
  Operation *next = useChain.back()->getNextNode();
  while (next && !next->hasTrait<OpTrait::IsTerminator>()) {
    insertBefore = next;
    bool dep = false;
    for (auto operand : getNestedOperands(next)) {
      if (llvm::any_of(useChain, [&](Operation *op) {
            return llvm::is_contained(op->getResults(), operand);
          })) {
        dep = true;
        break;
      }
    }
    // Don't sink past barrier signals, since they may guard the liverange
    // of the buffer.
    if (isa<ttng::ArriveBarrierOp>(next))
      break;
    if (!isMemoryEffectFree(next)) {
      SmallVector<MemoryEffects::EffectInstance> effects;
      collectEffects(next, effects);
      for (auto effect : effects) {
        // Look for potentially aliasing write or free effects.
        if (!isa<MemoryEffects::Write, MemoryEffects::Free>(effect.getEffect()))
          continue;
        if (isa<SideEffects::DefaultResource>(effect.getResource())) {
          dep = true;
          break;
        }
        if (isa<ttng::TensorMemory>(effect.getResource()) &&
            (!effect.getValue() || tmemMayAlias(effect.getValue(), buffer))) {
          dep = true;
          break;
        }
      }
    }
    if (dep)
      break;
    next = next->getNextNode();
  }
  if (insertBefore && insertBefore != useChain.back()->getNextNode()) {
    for (Operation *op : useChain)
      op->moveBefore(insertBefore);
    return true;
  }
  return false;
}

// Try to sink a load and a collection of its users.
bool trySinkOp(Operation *op, Value buffer) {
  SmallVector<Operation *> useChain{op};
  while (useChain.back()->hasOneUse() &&
         isPure(*useChain.back()->user_begin()) &&
         useChain.back()->getNextNode() == *useChain.back()->user_begin()) {
    useChain.push_back(*useChain.back()->user_begin());
  }
  return sinkOps(buffer, useChain);
}

struct TritonNvidiaGPUInterleaveTMemPass
    : public TritonNvidiaGPUInterleaveTMemPassBase<
          TritonNvidiaGPUInterleaveTMemPass> {
  using TritonNvidiaGPUInterleaveTMemPassBase::
      TritonNvidiaGPUInterleaveTMemPassBase;

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ModuleOp m = getOperation();
    SmallVector<std::pair<Operation *, Value>> opsToSink;
    m.walk([&](Operation *op) {
      if (auto load = dyn_cast<ttng::TMEMLoadOp>(op))
        opsToSink.emplace_back(load, load.getSrc());
      else if (auto alloc = dyn_cast<ttng::TMEMAllocOp>(op))
        opsToSink.emplace_back(alloc, alloc.getResult());
    });
    for (auto [op, buffer] : opsToSink) {
      while (trySinkOp(op, buffer)) {
        // Keep trying to sink loads and their users.
      }
    }
  }
};

} // namespace

std::unique_ptr<Pass> mlir::createTritonNvidiaGPUInterleaveTMemPass() {
  return std::make_unique<TritonNvidiaGPUInterleaveTMemPass>();
}
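
As a rough usage sketch (not part of the diff), the pass can be added to a pipeline through the factory declared in Passes.h above, or run standalone via the registered `triton-nvidia-interleave-tmem` flag; where it is scheduled in Triton's default compilation pipeline is not shown in this commit, and the wrapper function name below is hypothetical:

  #include "mlir/Pass/PassManager.h"
  #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"

  // Schedules the interleaving pass, which sinks tmem_load ops (and chains of
  // pure users) toward their uses to reduce register pressure.
  void addInterleaveTMem(mlir::PassManager &pm) {
    pm.addPass(mlir::createTritonNvidiaGPUInterleaveTMemPass());
  }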
