Skip to content

Commit 0209c69

Browse files
authored
[TritonGPU] Refactor LoadMMASpecialization (#6419)
This PR moves HoistTMEMAlloc before warp specialization and refactors LoadMMASpecialization to handle the IR the more canonical form created by HoistTMEMAlloc. This simplifies parts of the implementation and deletes the rest of the old MMAv5 pipelining code. There are still a few weird hacks in LoadMMASpecialization to connect the old code with the new code (LowerLoops, for example, prefers to read-modify-write the buffer index with `replaceAllUsesDominatedBy` whereas LoadMMASpecialization still needs to precompute the override point), but I plan to rewrite this entirely at some point.
1 parent d0b3ad8 commit 0209c69

File tree

11 files changed

+263
-540
lines changed

11 files changed

+263
-540
lines changed

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,6 @@ createLLVMIntrinsicCallOp(OpBuilder &builder, Location loc, StringRef intrinsic,
5353
TypeRange types, ValueRange args);
5454
} // namespace mlir::LLVM
5555

56-
// Is v an integer or floating-point scalar constant equal to 0?
57-
bool isConstantZero(Value v);
58-
5956
namespace mlir::triton {
6057

6158
struct TritonLLVMOpBuilder {
@@ -348,9 +345,6 @@ LLVM::LLVMFuncOp appendOrGetExternFuncOp(RewriterBase &rewriter, Operation *op,
348345
namespace LLVM {
349346
using namespace mlir::triton;
350347

351-
// Is v an integer or floating-point scalar constant equal to 0?
352-
bool isConstantZero(Value v);
353-
354348
class SharedMemoryObject {
355349
public:
356350
SharedMemoryObject(Value base, Type baseElemType, ArrayRef<Value> offsets)

include/triton/Dialect/TritonGPU/Transforms/MMAv5PipelineUtility.h

Lines changed: 8 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -13,50 +13,6 @@ class ForOp;
1313
} // namespace scf
1414
namespace triton::nvidia_gpu {
1515

16-
//===----------------------------------------------------------------------===//
17-
// MMAInfo
18-
//===----------------------------------------------------------------------===//
19-
20-
// This struct contains analysis information about an MMAv5 operation inside a
21-
// loop used for pipelining MMA ops.
22-
struct MMAInfo {
23-
// This struct contains information about when the MMA's accumulator is
24-
// overridden in the loop, if it is at all.
25-
struct AccOverridePoint {
26-
// The operation which overrides the accumulator.
27-
Operation *op;
28-
// The condition on which the accumulator is reset.
29-
Value condition = nullptr;
30-
// The initial value of the accumulator and the value after a reset.
31-
Value initValue = nullptr;
32-
// The number of loop iterations ago the accumulator was reset.
33-
int distance = 0;
34-
// Whether the accumulator is reset via setting the `useAcc` flag to false
35-
// or by clearing the accumulator tensor value.
36-
bool isFlag = false;
37-
};
38-
39-
// The TMEM allocation of the accumuator, which directly precedes the dot op.
40-
TMEMAllocOp accAlloc;
41-
// The TMEM load of the accumulator value out of TMEM, which directly follows
42-
// the dot op.
43-
TMEMLoadOp accLoad;
44-
// The override point of the accumulator value, if it is overriden in the
45-
// loop. E.g. this is typically present for persistent kernels.
46-
std::optional<AccOverridePoint> accDef;
47-
// If the accumulator is used in future iterations of the loop, this is the
48-
// iter arg number.
49-
std::optional<int> yieldArgNo;
50-
// Whether the accumulator needs to be multibuffered.
51-
bool accIsMultiBuffered;
52-
53-
Value phase = nullptr;
54-
Value barrierIdx = nullptr;
55-
Value accInsertIdx = nullptr;
56-
Value accExtractIdx = nullptr;
57-
Value barrierAlloc = nullptr;
58-
};
59-
6016
//===----------------------------------------------------------------------===//
6117
// MMA Pipeline Analysis
6218
//===----------------------------------------------------------------------===//
@@ -66,12 +22,14 @@ struct MMAInfo {
6622
// be in the same region as the MMA operation.
6723
std::optional<std::pair<TMEMAllocOp, TMEMLoadOp>>
6824
getTMemAllocAndLoad(MMAv5OpInterface mmaOp);
69-
// Get immediate users of the accumulator within the current loop iteration.
70-
SmallVector<Operation *> getDirectAccUses(TMEMLoadOp accDef);
71-
// Analyze an MMA op inside a loop to determine information about how it can be
72-
// pipelined. Returns `std::nullopt` if it cannot be pipelined.
73-
std::optional<MMAInfo> getMMAInfo(scf::ForOp forOp, MMAv5OpInterface mmaOp,
74-
DominanceInfo &domInfo);
25+
// Given an MMAv5 operation in a loop, determine if its accumulator can be
26+
// multibuffered.
27+
bool isAccMultibufferingPossible(MMAv5OpInterface mma, scf::ForOp forOp);
28+
// Only pipeline the loops where the MMA happens before the tmem_load, or is in
29+
// the same stage as the tmem_load. Lowering does not support the case where the
30+
// MMA is in a different stage as the tmem_load and happens after it.
31+
bool mmav5DominatesTmemLoads(
32+
scf::ForOp forOp, function_ref<bool(MMAv5OpInterface)> isMmaPipelineable);
7533

7634
//===----------------------------------------------------------------------===//
7735
// MMA Pipeline Rewriters
@@ -82,11 +40,6 @@ std::optional<MMAInfo> getMMAInfo(scf::ForOp forOp, MMAv5OpInterface mmaOp,
8240
TMEMAllocOp createTMemAlloc(OpBuilder &builder, TMEMAllocOp oldTMemAllocOp,
8341
bool multiBufferred, int numStages);
8442

85-
// Create a store op of the initial value of the accumulator into the
86-
// potentially multi-buffered accumulator.
87-
void createInitStore(OpBuilder &builder, TMEMAllocOp allocOp, Value initVal,
88-
bool multiBufferred);
89-
9043
// Return true if operands of the MMA operation are/are going to be pipelined
9144
// and multibuffered, enabling the MMA operation to be pipelined.
9245
bool mmaHasPipelineableOperands(

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
namespace mlir {
1313
class DominanceInfo;
14+
class PostDominanceInfo;
1415

1516
namespace triton {
1617
class ModuleAxisInfoAnalysis;
@@ -222,6 +223,11 @@ getMMAsWithMultiBufferredOperands(scf::ForOp forOp,
222223
// regions. The result op is not necessarily one of the ops in the list.
223224
Operation *findNearestCommonDominator(ArrayRef<Operation *> ops,
224225
DominanceInfo &domInfo);
226+
// Given a list of ops, find the naerest common postdominator of all ops or
227+
// return null if one could not be found. The ops are allowed to be in different
228+
// regions. The result op is not necessarily one of the ops in the list.
229+
Operation *findNearestCommonPostDominator(ArrayRef<Operation *> ops,
230+
PostDominanceInfo &postDomInfo);
225231

226232
/// Visit the operands of `op` and the operands of any nested ops defined
227233
/// outside of `op`.

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -667,18 +667,6 @@ createLLVMIntrinsicCallOp(OpBuilder &builder, Location loc, StringRef intrinsic,
667667
return op;
668668
}
669669

670-
bool isConstantZero(Value v) {
671-
if (auto constantOp = v.getDefiningOp<arith::ConstantOp>()) {
672-
if (auto attr = dyn_cast<IntegerAttr>(constantOp.getValue())) {
673-
return attr.getValue().isZero();
674-
}
675-
if (auto attr = dyn_cast<FloatAttr>(constantOp.getValue())) {
676-
return attr.getValue().isZero();
677-
}
678-
}
679-
return false;
680-
}
681-
682670
Value getStructFromSharedMemoryObject(Location loc,
683671
const SharedMemoryObject &smemObj,
684672
RewriterBase &rewriter) {

lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp

Lines changed: 2 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -278,8 +278,8 @@ class AssignMMALatencies {
278278
if (auto mma = dyn_cast<ttng::MMAv5OpInterface>(&op)) {
279279
if (ttng::mmaHasPipelineableOperands(mma, forOp, isLoadPipelineable) &&
280280
!ttng::hasAccReadModifyWrite(mma, forOp) &&
281-
!getDisallowAccMultiBuffer(forOp) &&
282-
isAccMultibufferingPossible(mma, forOp)) {
281+
ttng::isAccMultibufferingPossible(mma, forOp) &&
282+
!getDisallowAccMultiBuffer(forOp)) {
283283
opLatency[&op] = 1;
284284
}
285285
}
@@ -289,55 +289,6 @@ class AssignMMALatencies {
289289
private:
290290
scf::ForOp forOp;
291291
DenseMap<Operation *, int> &opLatency;
292-
293-
bool isConstantZero(Value v) {
294-
if (auto constantOp = v.getDefiningOp<arith::ConstantOp>()) {
295-
if (auto attr = dyn_cast<IntegerAttr>(constantOp.getValue())) {
296-
return attr.getValue().isZero();
297-
}
298-
if (auto attr = dyn_cast<FloatAttr>(constantOp.getValue())) {
299-
return attr.getValue().isZero();
300-
}
301-
}
302-
return false;
303-
}
304-
305-
bool accUseFlagSetToFalse(ttng::MMAv5OpInterface mma, scf::ForOp forOp) {
306-
Value accUseFlag = mma.useAccumulator();
307-
if (isConstantZero(accUseFlag)) {
308-
return true;
309-
}
310-
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
311-
while (auto blockArg = dyn_cast<BlockArgument>(accUseFlag)) {
312-
accUseFlag = yieldOp.getOperand(blockArg.getArgNumber() - 1);
313-
}
314-
// If the accUseFlag is overwritten in the loop, we treat it as a 'false'
315-
// with condition being ~accUseFlag.
316-
return accUseFlag.getDefiningOp() &&
317-
forOp->isAncestor(accUseFlag.getDefiningOp());
318-
}
319-
320-
bool accOverwrittenInLoop(ttng::MMAv5OpInterface mma, scf::ForOp forOp) {
321-
auto tmemAlloc = mma.getAccumulator().getDefiningOp<ttng::TMEMAllocOp>();
322-
if (!tmemAlloc || !forOp.isDefinedOutsideOfLoop(tmemAlloc)) {
323-
return false;
324-
}
325-
for (auto user : tmemAlloc->getUsers()) {
326-
if (isa<ttng::TMEMStoreOp>(user) &&
327-
forOp->isAncestor(user->getParentOp())) {
328-
return true;
329-
}
330-
}
331-
return false;
332-
}
333-
334-
bool isAccMultibufferingPossible(ttng::MMAv5OpInterface mma,
335-
scf::ForOp forOp) {
336-
// If the accumulator is never overwritten in the loop, we can't multibuffer
337-
// it, as the overwrite point is the only place where we can swap the
338-
// buffer.
339-
return accUseFlagSetToFalse(mma, forOp) || accOverwrittenInLoop(mma, forOp);
340-
}
341292
};
342293

343294
} // namespace

0 commit comments

Comments
 (0)