
Commit 7af8cad

[TritonGPU] Support persistent matmul in warp specialization (#6239)
This PR extends the "pattern" for load-MMA warp specialization to support persistent kernels with MMAv5. It leverages more of the existing MMAv5 pipelining code in `TC05MMAPipeline.cpp`, primarily the analysis step, which determines whether the op can be pipelined and finds the accumulator override point; the analysis is therefore performed over the flattened loop. However, because warp specialization is asynchronous and cannot rely on execution order, a few cases accepted by the analysis cannot be codegen'd at the moment. (Likewise, some cases that could be codegen'd are not accepted by the analysis; these can be ironed out on an as-needed basis.)

At a high level, the extended "pattern" now looks for users of the accumulator other than the MMA op itself in the next iteration. If it finds any, it places those users in a new partition and adds additional synchronization, multi-buffering the accumulator if needed. This allows the epilogue, which is a conditional user of the accumulator, to be placed in its own partition, overlapping the epilogue with the load<->MMA loop. The accumulator can also be multi-buffered, enabling the next MMA to start running before the TMEM load completes in the user partition.

This PR has lots of code motion due to refactoring utilities to be more widely available:

* Move `MMAInfo` and the analysis that computes it into `MMAv5PipelineUtility.h`
* Move some utilities from `PipeliningUtility.h` to `Utility.h`
* Miscellaneous code cleanup and bugfixes along the way
* Fix lowering of tensordesc ops to insert an addrspacecast when the pointer types actually differ. Grid-constant tensordescs have to be in the generic address space due to an NVPTX backend restriction/bug, but we treat them as addrspace=1 pointers internally.
Performance results for `matmul_kernel_tma_persistent` with `M, N, K = 8192, 8192, 512` in `09-persistent-matmul.py`:

* With SWP, the best config is `BLOCK_{M,N,K} = (128, 256, 64)`, 4 stages and 4 warps, at 1088 TFLOPS
* With WS, the best config is `BLOCK_{M,N,K} = (128, 256, 128)`, 4 stages and 4 warps, at 1140 TFLOPS

That's a ~5% increase in performance.
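For context, the relationship between the reported TFLOPS numbers and kernel runtime follows from the matmul flop count. A quick sketch of the arithmetic (illustrative only; the timings themselves come from the benchmark, not this snippet):

```python
# Rough arithmetic behind the reported TFLOPS numbers.
# A matmul of shape (M, N, K) performs 2 * M * N * K floating-point ops
# (one multiply and one add per inner-product element).
M, N, K = 8192, 8192, 512

flops = 2 * M * N * K  # total floating-point operations

# TFLOPS = flops / (elapsed_seconds * 1e12); inverting the reported
# numbers gives approximate per-kernel runtimes in milliseconds.
ms_swp = flops / 1088e12 * 1e3  # software-pipelined config
ms_ws = flops / 1140e12 * 1e3   # warp-specialized config

speedup = 1140 / 1088  # the ~5% improvement quoted in the message
```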
1 parent fcf33a3 · commit 7af8cad

29 files changed: +1700 −695 lines

include/triton/Dialect/TritonGPU/Transforms/MMAv5PipelineUtility.h

Lines changed: 74 additions & 9 deletions
```diff
@@ -1,40 +1,105 @@
 #ifndef TRITON_TRITONGPU_TRANSFORMS_MMAV5PIPELINEUTILITY_H_
 #define TRITON_TRITONGPU_TRANSFORMS_MMAV5PIPELINEUTILITY_H_

-#include <functional>
-#include <optional>
-#include <tuple>
+#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"

 namespace mlir {
+
 class OpBuilder;
-class Operation;
+class DominanceInfo;

 namespace scf {
 class ForOp;
-}
+} // namespace scf

 namespace triton::nvidia_gpu {
-class MMAv5OpInterface;
-class TMEMAllocOp;
-class TMEMLoadOp;
+
+//===----------------------------------------------------------------------===//
+// MMAInfo
+//===----------------------------------------------------------------------===//
+
+// This struct contains analysis information about an MMAv5 operation inside a
+// loop used for pipelining MMA ops.
+struct MMAInfo {
+  // This struct contains information about when the MMA's accumulator is
+  // overridden in the loop, if it is at all.
+  struct AccOverridePoint {
+    // The operation which overrides the accumulator.
+    Operation *op;
+    // The condition on which the accumulator is reset.
+    Value condition = nullptr;
+    // The initial value of the accumulator and the value after a reset.
+    Value initValue = nullptr;
+    // The number of loop iterations ago the accumulator was reset.
+    int distance = 0;
+    // Whether the accumulator is reset via setting the `useAcc` flag to false
+    // or by clearing the accumulator tensor value.
+    bool isFlag = false;
+  };
+
+  // The TMEM allocation of the accumulator, which directly precedes the dot op.
+  TMEMAllocOp accAlloc;
+  // The TMEM load of the accumulator value out of TMEM, which directly follows
+  // the dot op.
+  TMEMLoadOp accLoad;
+  // The override point of the accumulator value, if it is overridden in the
+  // loop. E.g. this is typically present for persistent kernels.
+  std::optional<AccOverridePoint> accDef;
+  // If the accumulator is used in future iterations of the loop, this is the
+  // iter arg number.
+  std::optional<int> yieldArgNo;
+  // Whether the accumulator needs to be multibuffered.
+  bool accIsMultiBuffered;
+
+  Value phase = nullptr;
+  Value barrierIdx = nullptr;
+  Value accInsertIdx = nullptr;
+  Value accExtractIdx = nullptr;
+  Value barrierAlloc = nullptr;
+};
+
+//===----------------------------------------------------------------------===//
+// MMA Pipeline Analysis
+//===----------------------------------------------------------------------===//

 // Returns the TMEMAllocOp and TMEMLoadOp that are used to allocate and load the
 // accumulator for the given MMA operation. The TMEMAllocOp and TMEMLoadOp must
 // be in the same region as the MMA operation.
 std::optional<std::pair<TMEMAllocOp, TMEMLoadOp>>
 getTMemAllocAndLoad(MMAv5OpInterface mmaOp);
+// Get immediate users of the accumulator within the current loop iteration.
+SmallVector<Operation *> getDirectAccUses(TMEMLoadOp accDef);
+// Analyze an MMA op inside a loop to determine information about how it can be
+// pipelined. Returns `std::nullopt` if it cannot be pipelined.
+std::optional<MMAInfo> getMMAInfo(scf::ForOp forOp, MMAv5OpInterface mmaOp,
+                                  DominanceInfo &domInfo);
+
+//===----------------------------------------------------------------------===//
+// MMA Pipeline Rewriters
+//===----------------------------------------------------------------------===//

 // Create a new TMEMAllocOp to use for the pipelined MMA operation. It is
 // optionally multi-buffered based on the number of stages.
 TMEMAllocOp createTMemAlloc(OpBuilder &builder, TMEMAllocOp oldTMemAllocOp,
                             bool multiBufferred, int numStages);

+// Create a store op of the initial value of the accumulator into the
+// potentially multi-buffered accumulator.
+void createInitStore(OpBuilder &builder, TMEMAllocOp allocOp, Value initVal,
+                     bool multiBufferred);
+
 // Return true if operands of the MMA operation are/are going to be pipelined
 // and multibuffered, enabling the MMA operation to be pipelined.
 bool mmaHasPipelineableOperands(
     MMAv5OpInterface mma, scf::ForOp forOp,
     std::function<bool(Operation *)> isLoadPipelineable);

-// Return true if the loop has a read-modify-write access to the accumulator.
+// Return true if the accumulator of an mma in subsequent iterations is either
+// independent from the previous iteration (overwritten) or completely reused,
+// without read-modify-write.
+// Otherwise, we can not pipeline the MMA, as we need to insert a wait after the
+// mma to read back the accumulator for RMW.
 bool hasAccReadModifyWrite(MMAv5OpInterface mma, scf::ForOp forOp);
+
 } // namespace triton::nvidia_gpu
 } // namespace mlir
```

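The `phase`/`accInsertIdx`/`accExtractIdx` values in `MMAInfo` follow the circular-buffer convention used throughout the pipeliner: an index wraps modulo the number of buffers, and a phase bit flips on each wraparound so barrier waits can distinguish consecutive passes over the same slot. A minimal sketch of that update (Python, illustrative only; the real code materializes this as arith ops on loop-carried values):

```python
def advance(index: int, phase: int, num_buffers: int):
    """Advance a circular-buffer index, flipping the phase on wraparound.

    Illustrative model of the index/phase arithmetic the pipeliner
    carries through the loop; not the actual implementation.
    """
    index += 1
    if index == num_buffers:
        return 0, phase ^ 1  # wrapped: restart at slot 0, flip parity
    return index, phase

# Walking 2 buffers for 5 iterations visits slots 0,1,0,1,0, with the
# phase flipping each time the walk wraps back to slot 0.
idx, ph = 0, 0
trace = []
for _ in range(5):
    trace.append((idx, ph))
    idx, ph = advance(idx, ph, num_buffers=2)
```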
include/triton/Dialect/TritonGPU/Transforms/Partition.h

Lines changed: 5 additions & 0 deletions
```diff
@@ -47,6 +47,9 @@ class WarpSchedule {
   void insert(Operation *op) { ops.push_back(op); }

 private:
+  void setIndex(int idx) { this->idx = idx; }
+  friend class WarpSchedule;
+
   // The partition number.
   int idx;
   // The stage of the partition.
@@ -57,6 +60,8 @@ class WarpSchedule {

   // Create a new partition with a stage.
   Partition *addPartition(unsigned stage);
+  // Give each partition a new index and order. The indices must be unique.
+  void reorderPartitions(ArrayRef<unsigned> order);

   // Get the partition the op belongs to.
   Partition *getPartition(Operation *op);
```

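The documented contract of `reorderPartitions` — each partition gets a fresh, unique index according to a permutation — can be modeled in a few lines. A Python sketch with a hypothetical dict stand-in for `Partition` (not the actual class):

```python
def reorder_partitions(partitions, order):
    """Give each partition a new index and position per `order`.

    `order[i]` is the new index of partition i; indices must be unique,
    mirroring the documented contract of WarpSchedule::reorderPartitions.
    Hypothetical model: partitions are plain dicts here.
    """
    assert sorted(order) == list(range(len(partitions))), "indices must be unique"
    reordered = [None] * len(partitions)
    for old_idx, new_idx in enumerate(order):
        partitions[old_idx]["idx"] = new_idx
        reordered[new_idx] = partitions[old_idx]
    return reordered

parts = [{"name": "load", "idx": 0},
         {"name": "mma", "idx": 1},
         {"name": "epilogue", "idx": 2}]
new_parts = reorder_partitions(parts, [2, 0, 1])
```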
include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 1 addition & 15 deletions
```diff
@@ -14,6 +14,7 @@ namespace triton {
 static const char *kNumStagesAttrName = "tt.num_stages";
 static const char *kDisallowAccMultiBufferAttrName =
     "tt.disallow_acc_multi_buffer";
+static const char *kWarpSpecializeAttrName = "tt.warp_specialize";
 static const char *kLoopStageAttrName = "loop.stage";
 static const char *kLoopClusterAttrName = "loop.cluster";
 static const char *kScheduledMaxStageAttrName = "tt.scheduled_max_stage";
@@ -38,17 +39,6 @@ void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse,
 // `tt.disallow_acc_multi_buffer` set to true.
 bool getDisallowAccMultiBuffer(scf::ForOp forOp);

-/// Visit the operands of `op` and the operands of any nested ops defined
-/// outside of `op`.
-void visitNestedOperands(Operation *op,
-                         function_ref<void(OpOperand &)> visitor);
-/// Visit the operands of `op` and the operands of any nested ops defined
-/// outside of `op`.
-void visitNestedOperands(Operation *op, function_ref<void(Value)> visitor);
-/// Get the operands of `op` and the operands of any nested ops defined outside
-/// of `op`.
-SetVector<Value> getNestedOperands(Operation *op);
-
 // Return the definition of the given value. If the value is a loop-carried
 // dependency, return the definition and the distance to it.
 std::pair<OpResult, int64_t> getDefinitionAndDistance(scf::ForOp forOp,
@@ -90,10 +80,6 @@ gpu::SharedEncodingTrait getSharedEncoding(RankedTensorType ty);
 // Get a shared encoding for a tensor based on its uses.
 gpu::SharedEncodingTrait getSharedEncoding(Operation *loadOp);

-// Erase the given loop carried values from the loop, where `loop` is replaced
-// with a new loop.
-void eraseLoopCarriedValues(scf::ForOp &loop, llvm::BitVector indices);
-
 // Get the number of stages to pipeline the loop with, if it is explicitly
 // specified.
 int getNumStagesOrDefault(scf::ForOp forOp, int defaultNumStages);
```

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 24 additions & 0 deletions
```diff
@@ -10,6 +10,7 @@
 #include <numeric>

 namespace mlir {
+class DominanceInfo;

 namespace triton {
 class ModuleAxisInfoAnalysis;
@@ -135,6 +136,8 @@ scf::ForOp replaceForOpWithNewSignature(
     SmallVectorImpl<std::tuple<Value, Value>> &replacements);
 scf::ForOp replaceForOpWithNewSignature(OpBuilder &rewriter, scf::ForOp loop,
                                         ValueRange newIterOperands);
+Block::BlockArgListType addIterArgsToLoop(OpBuilder &rewriter, scf::ForOp &loop,
+                                          ValueRange newIterOperands);

 // Replace WhileOp with a new WhileOp with extra operands. The YieldOp is not
 // updated and needs to be updated separately for the loop to be correct.
@@ -213,6 +216,27 @@ triton::gpu::LocalAllocOp findShmemAlloc(Value operand);
 SmallVector<Operation *>
 getMMAsWithMultiBufferredOperands(scf::ForOp forOp,
                                   SmallVector<Operation *> &mmaOps);
+
+// Given a list of ops, find the nearest common dominator of all ops or return
+// null if one could not be found. The ops are allowed to be in different
+// regions. The result op is not necessarily one of the ops in the list.
+Operation *findNearestCommonDominator(ArrayRef<Operation *> ops,
+                                      DominanceInfo &domInfo);
+
+/// Visit the operands of `op` and the operands of any nested ops defined
+/// outside of `op`.
+void visitNestedOperands(Operation *op,
+                         function_ref<void(OpOperand &)> visitor);
+/// Visit the operands of `op` and the operands of any nested ops defined
+/// outside of `op`.
+void visitNestedOperands(Operation *op, function_ref<void(Value)> visitor);
+/// Get the operands of `op` and the operands of any nested ops defined outside
+/// of `op`.
+SetVector<Value> getNestedOperands(Operation *op);
+
+// Erase the given loop carried values from the loop, where `loop` is replaced
+// with a new loop.
+void eraseLoopCarriedValues(scf::ForOp &loop, llvm::BitVector indices);
 } // namespace mlir

 #endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
```

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 1 addition & 1 deletion
```diff
@@ -848,7 +848,7 @@ LogicalResult WarpYieldOp::verify() {
 static size_t getSharedMemorySize(Type type) {
   if (isa<IntegerType, FloatType>(type))
     return llvm::divideCeil(type.getIntOrFloatBitWidth(), 8);
-  if (isa<PointerType>(type))
+  if (isa<PointerType, TensorDescType>(type))
     return 8;
   if (auto desc = dyn_cast<MemDescType>(type)) {
     if (!isa<SharedMemorySpaceAttr>(desc.getMemorySpace()))
```

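The sizing rule above rounds a scalar's bit width up to whole bytes and treats pointers (and, after this change, tensor descriptors) as 8-byte handles. A Python rendition of that rule, under the assumption that descriptors are passed as pointer-like 8-byte values:

```python
def shared_memory_size(bit_width=None, is_pointer_like=False):
    """Byte size of a value when staged through shared memory.

    Sketch of the scalar/pointer cases of getSharedMemorySize:
    scalars round up to whole bytes (llvm::divideCeil(bitWidth, 8));
    pointers and tensor descriptors occupy 8 bytes. Illustrative only.
    """
    if is_pointer_like:
        return 8
    return (bit_width + 7) // 8  # ceiling division by 8

# An i1 still occupies a full byte; an i33 needs five bytes.
```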
lib/Dialect/TritonGPU/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
```diff
@@ -12,6 +12,7 @@ add_triton_library(TritonGPUTransforms
   OptimizeThreadLocality.cpp
   Pipeliner/AssignLatencies.cpp
   Pipeliner/LowerLoops.cpp
+  Pipeliner/MMAv5PipelineUtility.cpp
   Pipeliner/ScheduleLoops.cpp
   Pipeliner/WGMMAPipeline.cpp
   Pipeliner/PipelineExpander.cpp
```

lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp

Lines changed: 7 additions & 0 deletions
```diff
@@ -900,6 +900,13 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) {
     epilogueIf.erase();
   }

+  // Propagate warp specialization flags.
+  if (outer->hasAttr(kWarpSpecializeAttrName) ||
+      llvm::any_of(innerLoops, [](scf::ForOp loop) {
+        return loop->hasAttr(kWarpSpecializeAttrName);
+      }))
+    fused->setAttr(kWarpSpecializeAttrName, b.getUnitAttr());
+
   // Propagate the `tt.disallow_acc_multi_buffer` attribute to the parent loop.
   bool disallowAccMultiBuffer = getDisallowAccMultiBuffer(outer);
   for (scf::ForOp loop : innerLoops) {
```

lib/Dialect/TritonGPU/Transforms/HoistTMEMAlloc.cpp

Lines changed: 0 additions & 29 deletions
```diff
@@ -46,35 +46,6 @@ bool aliasingStoresBetween(Operation *op, ttng::TMEMStoreOp store) {
   return false;
 }

-Operation *findNearestCommonDominator(ArrayRef<Operation *> ops,
-                                      DominanceInfo &domInfo) {
-  if (ops.size() == 0) {
-    return nullptr;
-  }
-  if (ops.size() == 1) {
-    return ops[0];
-  }
-  llvm::SmallPtrSet<Block *, 16> blocks;
-  for (auto op : ops) {
-    blocks.insert(op->getBlock());
-  }
-  Block *domBlock = domInfo.findNearestCommonDominator(blocks);
-  if (domBlock == nullptr) {
-    return nullptr;
-  }
-  SmallVector<Operation *> ancestorOps;
-  for (auto op : ops) {
-    ancestorOps.push_back(domBlock->findAncestorOpInBlock(*op));
-  }
-  Operation *dom = ancestorOps[0];
-  for (unsigned i = 1; i < ops.size(); i++) {
-    if (ancestorOps[i]->isBeforeInBlock(dom)) {
-      dom = ancestorOps[i];
-    }
-  }
-  return dom;
-}
-
 class CombineTMEMStoreAndSelect : public OpRewritePattern<ttng::TMEMStoreOp> {
 using OpRewritePattern::OpRewritePattern;
```

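The relocated `findNearestCommonDominator` is essentially a lowest-common-ancestor query on the dominator tree (followed by picking the earliest ancestor op inside the dominating block). The tree part can be sketched as a standard walk-up-to-equal-depth LCA; Python, with a hypothetical `idom` map in place of MLIR's `DominanceInfo`:

```python
def nearest_common_dominator(nodes, idom):
    """LCA over a dominator tree given an immediate-dominator map.

    `idom[n]` is n's immediate dominator; the root maps to itself.
    Illustrative of the block-level step of findNearestCommonDominator,
    not the MLIR implementation itself.
    """
    def depth(n):
        d = 0
        while idom[n] != n:
            n, d = idom[n], d + 1
        return d

    def lca(a, b):
        da, db = depth(a), depth(b)
        while da > db:            # lift the deeper node first
            a, da = idom[a], da - 1
        while db > da:
            b, db = idom[b], db - 1
        while a != b:             # then lift both until they meet
            a, b = idom[a], idom[b]
        return a

    if not nodes:
        return None
    result = nodes[0]
    for n in nodes[1:]:
        result = lca(result, n)
    return result

# entry dominates A and B; A dominates C.
idom = {"entry": "entry", "A": "entry", "B": "entry", "C": "A"}
```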
lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp

Lines changed: 1 addition & 4 deletions
```diff
@@ -211,10 +211,7 @@ class OptimizeAccumulatorInitPass
     }

     Value loopArgFlagValue = loopArgIsZero ? vFalse : vTrue;
-    scf::ForOp newForOp =
-        replaceForOpWithNewSignature(rewriter, forOp, {loopArgFlagValue});
-    forOp.erase();
-    forOp = newForOp;
+    (void)addIterArgsToLoop(rewriter, forOp, {loopArgFlagValue});
     loopArgFlagValue =
         forOp.getRegionIterArg(forOp.getNumRegionIterArgs() - 1);
```

lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp

Lines changed: 10 additions & 16 deletions
```diff
@@ -590,13 +590,10 @@ scf::ForOp lowerLoads(scf::ForOp forOp, CoarseSchedule &schedule) {
   }

   // Patch the loop to add the new loop carried dependencies.
-  scf::ForOp newForOp =
-      replaceForOpWithNewSignature(builder, forOp, newOperands);
-  forOp.erase();
-  forOp = newForOp;
+  (void)addIterArgsToLoop(builder, forOp, newOperands);

   // Update yield op with temporary yield values
-  auto forYield = cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
+  auto forYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
   for (unsigned i = 0; i < newOperands.size(); ++i) {
     forYield.getResultsMutable().append(newOperands[i]);
   }
@@ -605,13 +602,13 @@ scf::ForOp lowerLoads(scf::ForOp forOp, CoarseSchedule &schedule) {
   loc = forOp.getLoc();
   int argIdx = newOperandIndex;
   for (auto &[numBuffers, loadGroup] : loadGroups) {
-    Value insertIdx = newForOp.getBody()->getArgument(argIdx);
+    Value insertIdx = forOp.getBody()->getArgument(argIdx);
     argIdx++;
-    Value extractIdx = newForOp.getBody()->getArgument(argIdx);
+    Value extractIdx = forOp.getBody()->getArgument(argIdx);
     argIdx++;
     Value phase = nullptr;
     if (loadGroup.hasTMALoad) {
-      phase = newForOp.getBody()->getArgument(argIdx);
+      phase = forOp.getBody()->getArgument(argIdx);
      argIdx++;
    }

@@ -821,25 +818,22 @@ scf::ForOp lowerTMADescriptors(scf::ForOp forOp, CoarseSchedule &schedule) {
     newOperands.push_back(zero);
   }

-  scf::ForOp newForOp =
-      replaceForOpWithNewSignature(builder, forOp, newOperands);
-  forOp.erase();
-  forOp = newForOp;
+  (void)addIterArgsToLoop(builder, forOp, newOperands);

-  auto tmaCounters = ArrayRef<BlockArgument>(newForOp.getBody()->getArguments())
+  auto tmaCounters = ArrayRef<BlockArgument>(forOp.getBody()->getArguments())
                          .slice(tmaCounterArgsStartIdx);

   // Update yield op with temporary yield values
-  auto forYield = cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
+  auto forYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
   for (unsigned i = 0; i < newOperands.size(); ++i) {
     forYield.getResultsMutable().append(newOperands[i]);
   }

-  if (failed(rewriteTMABufferUpdates(newForOp, tmaBufferMapping, tmaCounters,
+  if (failed(rewriteTMABufferUpdates(forOp, tmaBufferMapping, tmaCounters,
                                      maxStage, one, zero, schedule))) {
     llvm_unreachable("Failed to rewrite TMA ops");
   }
-  return newForOp;
+  return forOp;
 }

 /////////////////////////////
```
