
Commit 8de17d2

[PIPELINE] Add support for pipelining mma loops with modified accumulator (#6393)
This change improves MMAv5 pipelining so that MMAs with a read-modify-write (RMW) accumulator are pipelined, by placing the tmem load before the mma. The approach is not entirely clean: to avoid dist-1 dependencies on the wait_barrier, we only pipeline cases where the load and the mma are in the same stage, and we skip adding buffer dependencies to the wait. Second-dot pipelining remains disabled by default, as even with this optimization there are still performance problems when it is enabled. With second-dot pipelining disabled, however, this change improves the performance of some internal attention kernels by up to 30% by reducing spilling.
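
The rough iteration structure this aims for can be pictured with a small, self-contained sketch (plain C++ with scalar stand-ins, not Triton IR; all names here are illustrative, not from the codebase): the accumulator is loaded from tensor memory before the mma, and the barrier wait guarding that load depends on the previous iteration's mma, so it is predicated off on the first iteration.

#include <cstdio>

int main() {
  float tmem_acc = 0.0f; // stand-in for the TMEM-resident accumulator
  bool waitPred = false; // loop-carried predicate; true once the mma has run

  for (int i = 0; i < 4; ++i) {
    if (waitPred) {
      // wait_barrier stand-in: the previous iteration's mma must have
      // finished writing the accumulator before we read it (dist-1 dep).
      std::printf("iter %d: wait for mma of iter %d\n", i, i - 1);
    }
    float acc = tmem_acc;      // tmem load placed *before* the mma
    acc = acc * 0.5f;          // accumulator modification (the RMW part)
    tmem_acc = acc + float(i); // mma accumulating on top of the modified value
    waitPred = true;           // predicate flips after the first mma
  }
  std::printf("final acc = %f\n", tmem_acc);
  return 0;
}
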
1 parent 05b500c commit 8de17d2

File tree

13 files changed: +367 -287 lines changed

include/triton/Dialect/TritonGPU/Transforms/Schedule.h

Lines changed: 2 additions & 3 deletions
@@ -28,9 +28,6 @@ void lowerLoops(ModuleOp moduleOp);
 /// Pipeline the TMA stores in the loop.
 bool pipelineTMAStores(scf::ForOp forOp);
 
-/// Simple pipelining for the MMA ops which accumulator is modified in the loop.
-scf::ForOp pipelineMMAWithScaledAcc(scf::ForOp forOp);
-
 /// This does post-processing on the pipelined loop to try to pipeline wgmma
 /// ops.
 // TODO: this should be included as part of the pipeline but currently the wgmma
@@ -75,6 +72,8 @@ class CoarseSchedule {
   }
 
   bool isBefore(iterator a, iterator b) const {
+    if (a == b)
+      return false;
     for (auto it = begin(); it != end(); ++it) {
       if (it == a)
         return true;
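
For intuition, here is a small standalone analogue (not the Triton class; the container and names are made up for illustration) of the check above. The added a == b early-out makes isBefore a strict ordering, so an element is never reported as being before itself:

#include <cassert>
#include <vector>

// Returns true iff `a` appears strictly before `b` in `order`.
bool isBefore(const std::vector<int> &order, int a, int b) {
  if (a == b)
    return false; // without this, the scan below would claim a is before a
  for (int x : order) {
    if (x == a)
      return true;
    if (x == b)
      return false;
  }
  return false;
}

int main() {
  std::vector<int> clusters = {3, 1, 2};
  assert(isBefore(clusters, 3, 2));
  assert(!isBefore(clusters, 2, 2)); // strict: never before itself
  return 0;
}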

lib/Dialect/TritonGPU/Transforms/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
@@ -22,7 +22,6 @@ add_triton_library(TritonGPUTransforms
   Pipeliner/SoftwarePipeliner.cpp
   Pipeliner/TMAStoresPipeline.cpp
   Pipeliner/MMAv5PipelineUtility.cpp
-  Pipeliner/ModifiedAccMMAPipeline.cpp
   Pipeliner/Partition.cpp
   Pipeliner/PipeliningUtility.cpp
   Pipeliner/Schedule.cpp

lib/Dialect/TritonGPU/Transforms/HoistTMEMAlloc.cpp

Lines changed: 51 additions & 2 deletions
@@ -229,6 +229,53 @@ class RotateTMEMStoreInLoop : public OpRewritePattern<ttng::TMEMStoreOp> {
   }
 };
 
+class RotateTMEMLoadInLoop : public OpRewritePattern<ttng::TMEMLoadOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ttng::TMEMLoadOp load,
+                                PatternRewriter &rewriter) const override {
+    scf::ForOp forOp = dyn_cast<scf::ForOp>(load->getParentOp());
+    if (!forOp) {
+      return failure();
+    }
+    if (!forOp.isDefinedOutsideOfLoop(load.getSrc())) {
+      return failure();
+    }
+    if (!load.getResult().hasOneUse()) {
+      return failure();
+    }
+    OpOperand &use = *load.getResult().getUses().begin();
+    auto yield = dyn_cast<scf::YieldOp>(use.getOwner());
+    if (!yield) {
+      return failure();
+    }
+    // Create two copies of the store: one before the loop, storing the initial
+    // value, and one before the yield, storing the value carried by the loop
+    // arg.
+    int argNo = use.getOperandNumber();
+    Value initVal = forOp.getInitArgs()[argNo];
+    rewriter.setInsertionPoint(forOp);
+    auto vTrue = rewriter.create<arith::ConstantIntOp>(load.getLoc(), 1, 1);
+    rewriter.create<ttng::TMEMStoreOp>(load.getLoc(), load.getSrc(), initVal,
+                                       vTrue);
+    auto attributes = load->getAttrDictionary();
+    rewriter.moveOpBefore(load, &forOp.getBody()->front());
+    forOp.getRegionIterArg(argNo).replaceAllUsesWith(load.getResult());
+    load->setAttrs(attributes);
+
+    // Load from the tmem after the loop, and use it instead of the loop carried
+    // value.
+    rewriter.setInsertionPointAfter(forOp);
+    auto loadAfterLoop = rewriter.create<ttng::TMEMLoadOp>(
+        load.getLoc(), load.getResult().getType(), load.getSrc());
+    forOp->getResult(argNo).replaceAllUsesWith(loadAfterLoop.getResult());
+    // Loop carried value is no longer used, short-circuit it.
+    yield.setOperand(argNo, forOp.getRegionIterArg(argNo));
+    return success();
+  }
+};
+
 ttng::TMEMAllocOp hoistTMEMAlloc(ttng::TMEMAllocOp alloc, scf::ForOp forOp) {
   OpBuilder builder(alloc);
   builder.setInsertionPoint(forOp);
@@ -300,8 +347,10 @@ struct HoistTMEMAlloc
     }
 
     mlir::RewritePatternSet patterns(&getContext());
-    patterns.add<RotateTMEMStoreInLoop, CombineTMEMLoadAndStore,
-                 CombineTMEMStoreAndSelect, SinkTMEMLoad>(&getContext());
+    patterns
+        .add<RotateTMEMStoreInLoop, RotateTMEMLoadInLoop,
+             CombineTMEMLoadAndStore, CombineTMEMStoreAndSelect, SinkTMEMLoad>(
+            &getContext());
     scf::ForOp::getCanonicalizationPatterns(patterns, &getContext());
     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
       llvm_unreachable("Failed to hoist tmem_store");
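
The RotateTMEMLoadInLoop pattern added above can be pictured with a plain C++ analogue (not MLIR; an ordinary variable stands in for the TMEM buffer and update() is a made-up placeholder for the loop body): rather than carrying the loaded value through the loop arguments, the initial value is stored to the buffer before the loop, the load is moved to the top of the loop body, and a final load after the loop replaces the loop result.

#include <cstdio>

float update(float x) { return x * 2.0f + 1.0f; } // placeholder loop body

int main() {
  const float init = 3.0f;

  // Before the rewrite: the value is loop-carried and the load feeds the yield.
  float carried = init;
  for (int i = 0; i < 4; ++i) {
    float stored = update(carried); // tmem_store of the updated value
    carried = stored;               // tmem_load feeding the yield
  }

  // After the rewrite: store the init value into the buffer before the loop,
  // load at the top of every iteration, and load once more after the loop.
  float tmem = init;     // tmem_store of the init value before the loop
  for (int i = 0; i < 4; ++i) {
    float loaded = tmem; // tmem_load moved to the top of the body
    tmem = update(loaded);
  }
  float result = tmem;   // tmem_load after the loop replaces the loop result

  std::printf("carried=%f result=%f\n", carried, result); // identical values
  return 0;
}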

lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp

Lines changed: 132 additions & 47 deletions
@@ -118,7 +118,9 @@ Operation *getFirstUseOfPipelinedOp(SmallVector<Operation *> ops,
       auto [_firstUserStage, _firstUserCluster] = schedule[firstUser];
       if (_useStage < _firstUserStage ||
           (_useStage == _firstUserStage &&
-           schedule.clusters.isBefore(_useCluster, _firstUserCluster))) {
+           schedule.clusters.isBefore(_useCluster, _firstUserCluster)) ||
+          (_useStage == _firstUserStage && _useCluster == _firstUserCluster &&
+           topLevelUser->isBeforeInBlock(firstUser))) {
         firstUser = topLevelUser;
       }
     }
@@ -214,6 +216,8 @@ Value createIncrementModulo(BuilderT &builder, Location loc, Value counter,
 
 void replaceAllUsesDominatedBy(Operation *domOp, Value newValue, Value oldValue,
                                DominanceInfo &domInfo) {
+  if (newValue == oldValue)
+    return;
   oldValue.replaceUsesWithIf(newValue, [&](OpOperand &use) {
     return domInfo.properlyDominates(domOp, use.getOwner());
   });
@@ -232,15 +236,6 @@ static Value createAlloc(scf::ForOp &forOp, Operation *loadOp,
                                    loadOp->getLoc(), sharedEnc, distance);
 }
 
-template <typename BuilderT, typename... Args>
-Operation *createWithStage(BuilderT &builder, Location loc, int stage,
-                           CoarseSchedule::Cluster cluster, Args &&...args) {
-  Operation *op = builder.template create<ttg::AsyncCopyGlobalToLocalOp>(
-      loc, std::forward<Args>(args)...);
-
-  return op;
-}
-
 void createAsyncCopy(scf::ForOp forOp, tt::LoadOp loadOp, Value alloc,
                      Value insertIdx, Value extractIdx,
                      CoarseSchedule &schedule) {
@@ -874,12 +869,36 @@ std::pair<int, int> getTmemUseStageBounds(ttng::TMEMAllocOp alloc,
   return bounds;
 }
 
-void createBarrierAndWaitOps(scf::ForOp forOp, CoarseSchedule &schedule,
-                             ttng::MMAv5OpInterface mma,
-                             ttng::TMEMAllocOp alloc, Value &phase,
-                             Value &barrierIdx, int numStages) {
+// Create a predicate argument for the dist-1wait
+scf::ForOp prepLoopForDist1Wait(scf::ForOp forOp, CoarseSchedule &schedule,
+                                ttng::MMAv5OpInterface mma) {
+  OpBuilderForStage builder(forOp, schedule);
+  Location loc = mma.getLoc();
+  Value vFalse = builder.create<arith::ConstantIntOp>(loc, 0, 1);
+
+  // Create a predicate for the wait (start with false and change to true on the
+  // first mma execution)
+  scf::ForOp newForOp = replaceForOpWithNewSignature(builder, forOp, {vFalse});
+  forOp.erase();
+  forOp = newForOp;
+
+  builder.setInsertionPointAfter(mma);
+  builder.setStageCluster(schedule[mma]);
+  Value vTrue = builder.create<arith::ConstantIntOp>(loc, 1, 1);
+
+  auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+  yieldOp.getResultsMutable().append(vTrue); // predicate
+  return forOp;
+}
+
+scf::ForOp createBarrierAndWaitOps(scf::ForOp forOp, CoarseSchedule &schedule,
+                                   ttng::MMAv5OpInterface mma,
+                                   ttng::TMEMAllocOp alloc, int phaseArgIdx,
+                                   int barrierIdxArgIdx, int numStages) {
   OpBuilderForStage builder(forOp, schedule);
   DominanceInfo domInfo(forOp);
+  Value phase = forOp.getRegionIterArg(phaseArgIdx);
+  Value barrierIdx = forOp.getRegionIterArg(barrierIdxArgIdx);
   Value zero = builder.create<arith::ConstantIntOp>(forOp.getLoc(), 0, 32);
   Value one = builder.create<arith::ConstantIntOp>(forOp.getLoc(), 1, 32);
   Value numStagesVal =
@@ -889,8 +908,11 @@ void createBarrierAndWaitOps(scf::ForOp forOp, CoarseSchedule &schedule,
   builder.setStageCluster(schedule[mma]);
   builder.setInsertionPoint(mma);
   Location loc = mma->getLoc();
-  Value barrierSlice =
-      triton::createSingleBufferView(builder, barrierAlloc, barrierIdx);
+  Value barrierSlice = barrierAlloc;
+  if (numStages > 1) {
+    barrierSlice =
+        triton::createSingleBufferView(builder, barrierAlloc, barrierIdx);
+  }
   mma.setBarrier(barrierSlice);
 
   // List of buffers that may be used until wait completes
@@ -953,27 +975,85 @@ void createBarrierAndWaitOps(scf::ForOp forOp, CoarseSchedule &schedule,
     }
   }
 
+  std::optional<int> predicateArgIdx;
   for (auto wp : waitPoints) {
     builder.setStageCluster({wp.stage, wp.cluster});
     builder.setInsertionPoint(wp.op);
-    builder.create<ttng::WaitBarrierOp>(loc, barrierSlice, phase, waitBuffers);
+    SmallVector<Value> currWaitBuffers = waitBuffers;
+    Value pred = nullptr;
+    if (!domInfo.properlyDominates(mma, wp.op)) {
+      // Waits before the mma are special: they are dist-1 to the mma which
+      // means two things:
+      // 1. The wait should not be executed before the mma executes for the
+      // first time
+      // 2. The waitBuffers do not dominate this wait. We either need to hoist
+      // the local allocs out of the loop, or have dist-1 buffer references.
+      if (!predicateArgIdx) {
+        forOp = prepLoopForDist1Wait(forOp, schedule, mma);
+        predicateArgIdx = forOp.getInitArgs().size() - 1;
+        // HACK: Clear the wait buffers. This is meant to avoid having to have
+        // dist-1 buffer references, or hoisting the local allocs out of the
+        // loop. This works as long as the wait is in the same stage as the mma,
+        // and is the reason for which we prevent pipelining the mma if there
+        // are loads from the accumulator before the mma in a different stage.
+        currWaitBuffers.clear();
+      }
+      pred = forOp.getRegionIterArg(*predicateArgIdx);
+    }
+    Value barrierSlice = barrierAlloc;
+    if (numStages > 1) {
+      barrierSlice =
+          triton::createSingleBufferView(builder, barrierAlloc, barrierIdx);
+    }
+    builder.create<ttng::WaitBarrierOp>(loc, barrierSlice, phase, pred,
+                                        currWaitBuffers);
   }
 
   builder.setStageCluster(schedule[mma]);
   builder.setInsertionPoint(forOp.getBody()->getTerminator());
-  Value barWrap;
-  barrierIdx = createIncrementModulo(builder, loc, barrierIdx, numStagesVal,
-                                     zero, one, &barWrap);
-  phase = builder.create<arith::SelectOp>(
-      loc, phase.getType(), barWrap,
-      builder.create<arith::XOrIOp>(loc, phase, one), phase);
+  Value newPhase = builder.create<arith::XOrIOp>(loc, phase, one);
+  Value newBarrierIdx = barrierIdx;
+  if (numStages > 1) {
+    Value barWrap;
+    newBarrierIdx = createIncrementModulo(builder, loc, barrierIdx,
+                                          numStagesVal, zero, one, &barWrap);
+    newPhase = builder.create<arith::SelectOp>(loc, phase.getType(), barWrap,
+                                               newPhase, phase);
+  }
+  if (predicateArgIdx) {
+    // If there is a wait predicate, we need to select the phase and barrierIdx
+    // based on the predicate
+    Value pred = forOp.getRegionIterArg(*predicateArgIdx);
+    newPhase = builder.create<arith::SelectOp>(loc, phase.getType(), pred,
+                                               newPhase, phase);
+    newBarrierIdx = builder.create<arith::SelectOp>(loc, phase.getType(), pred,
+                                                    newBarrierIdx, barrierIdx);
+  }
+  replaceAllUsesDominatedBy(newPhase.getDefiningOp(), newPhase, phase, domInfo);
+  replaceAllUsesDominatedBy(newBarrierIdx.getDefiningOp(), newBarrierIdx,
+                            barrierIdx, domInfo);
+
+  // If there is a dist-1 wait, we need to add a wait after the loop
+  if (predicateArgIdx) {
+    builder.setInsertionPointAfter(forOp);
+    Value barrierSlice = barrierAlloc;
+    Value barrierIdx = forOp.getResult(barrierIdxArgIdx);
+    Value phase = forOp.getResult(phaseArgIdx);
+    if (numStages > 1) {
+      barrierSlice =
+          triton::createSingleBufferView(builder, barrierAlloc, barrierIdx);
+    }
+    builder.create<ttng::WaitBarrierOp>(loc, barrierSlice, phase,
+                                        forOp.getResult(*predicateArgIdx));
+  }
+  return forOp;
 }
 
 void multibufferTensorMemory(scf::ForOp forOp, CoarseSchedule &schedule,
-                             ttng::TMEMAllocOp alloc, Value &bufIdx,
-                             int bufIdxArgIdx, int tmemUseNumStages) {
+                             ttng::TMEMAllocOp alloc, int bufIdxArgIdx,
+                             int tmemUseNumStages) {
   DominanceInfo domInfo(forOp);
-
+  Value bufIdx = forOp.getRegionIterArg(bufIdxArgIdx);
   SmallVector<std::pair<Operation *, Value>> bufIdxDefs;
   auto getCurrBufIdx = [&](Operation *op) {
     for (auto [_op, _val] : llvm::reverse(bufIdxDefs)) {
@@ -1077,17 +1157,16 @@ void multibufferTensorMemory(scf::ForOp forOp, CoarseSchedule &schedule,
         "accumulator, and the mma uses the accumulator all the time.");
   }
   alloc->erase();
-  bufIdx = bufIdxDefs.back().second;
+  Value newBufIdx = bufIdxDefs.back().second;
+  replaceAllUsesDominatedBy(newBufIdx.getDefiningOp(), newBufIdx, bufIdx,
+                            domInfo);
 }
 
 scf::ForOp lowerMMA(ttng::MMAv5OpInterface mma, scf::ForOp forOp,
                     CoarseSchedule &schedule) {
   auto isLoadPipelineable = [&](Operation *op) {
     return schedule[mma].first > schedule[op].first;
   };
-  if (!mmaHasPipelineableOperands(mma, forOp, isLoadPipelineable)) {
-    return forOp;
-  }
   auto alloc = mma.getAccumulator().getDefiningOp<ttng::TMEMAllocOp>();
   if (!alloc) {
     return forOp;
@@ -1099,7 +1178,8 @@ scf::ForOp lowerMMA(ttng::MMAv5OpInterface mma, scf::ForOp forOp,
   int tmemUseNumStages =
       tmemUseStageBounds.second - tmemUseStageBounds.first + 1;
   int waitNumStages = tmemUseStageBounds.second - schedule[mma].first + 1;
-  if (waitNumStages == 1 && !hasAccReadModifyWrite(mma, forOp)) {
+  if (waitNumStages == 1 && !hasAccReadModifyWrite(mma, forOp) &&
+      mmaHasPipelineableOperands(mma, forOp, isLoadPipelineable)) {
     // Overlap the mma with itself, even if there is no use of the accumulator
     // after the mma
     waitNumStages = 2;
@@ -1112,35 +1192,40 @@ scf::ForOp lowerMMA(ttng::MMAv5OpInterface mma, scf::ForOp forOp,
   // Add arguments to the forOp
   unsigned newOperandIndex = forOp.getInitArgs().size();
   SmallVector<Value> newOperands = {
-      zero,     // phase
-      zero,     // barrierIdx
-      minusOne, // bufIdx
+      zero, // phase
+      zero, // barrierIdx
   };
+  if (tmemUseNumStages > 1) {
+    newOperands.push_back(minusOne); // bufIdx
+  }
   scf::ForOp newForOp =
       replaceForOpWithNewSignature(builder, forOp, newOperands);
   forOp.erase();
   forOp = newForOp;
 
-  Value phase = forOp.getRegionIterArg(newOperandIndex + 0);
-  Value barrierIdx = forOp.getRegionIterArg(newOperandIndex + 1);
-  Value bufIdx = forOp.getRegionIterArg(newOperandIndex + 2);
+  int phaseArgIdx = newOperandIndex + 0;
+  int barrierIdxArgIdx = newOperandIndex + 1;
+  int bufIdxArgIdx = newOperandIndex + 2;
+  Value phase = forOp.getRegionIterArg(phaseArgIdx);
+  Value barrierIdx = forOp.getRegionIterArg(barrierIdxArgIdx);
 
-  if (waitNumStages > 1) {
-    createBarrierAndWaitOps(forOp, schedule, mma, alloc, phase, barrierIdx,
-                            waitNumStages);
+  SmallVector<Value> newYieldOperands = {phase, barrierIdx};
+  if (tmemUseNumStages > 1) {
+    Value bufIdx = forOp.getRegionIterArg(bufIdxArgIdx);
+    newYieldOperands.push_back(bufIdx);
   }
+  cast<scf::YieldOp>(forOp.getBody()->getTerminator())
+      .getResultsMutable()
+      .append(newYieldOperands);
+
+  forOp = createBarrierAndWaitOps(forOp, schedule, mma, alloc, phaseArgIdx,
+                                  barrierIdxArgIdx, waitNumStages);
 
   if (tmemUseNumStages > 1) {
-    multibufferTensorMemory(forOp, schedule, alloc, bufIdx, newOperandIndex + 2,
+    multibufferTensorMemory(forOp, schedule, alloc, bufIdxArgIdx,
                             tmemUseNumStages);
   }
 
-  SmallVector<Value> newYieldOperands;
-  newYieldOperands.push_back(phase);
-  newYieldOperands.push_back(barrierIdx);
-  newYieldOperands.push_back(bufIdx);
-  appendToForOpYield(forOp, newYieldOperands);
-
   return forOp;
 }
 
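As a rough illustration of the barrier bookkeeping created in createBarrierAndWaitOps (plain C++, not the builder code; numStages = 2 is an assumed value): the barrier index wraps modulo numStages, the phase flips only when the index wraps, and when a dist-1 wait exists both values advance only once the loop-carried predicate has become true, i.e. starting from the iteration after the first mma.

#include <cstdio>

int main() {
  const int numStages = 2; // assumed number of barrier buffers
  int phase = 0, barrierIdx = 0;
  bool pred = false; // dist-1 wait predicate: false until the mma has run once

  for (int i = 0; i < 6; ++i) {
    // ... predicated wait_barrier(barrierIdx, phase), tmem load, mma ...
    int newPhase = phase ^ 1;
    int newBarrierIdx = barrierIdx;
    if (numStages > 1) {
      bool barWrap = (barrierIdx + 1 == numStages);
      newBarrierIdx = barWrap ? 0 : barrierIdx + 1;
      newPhase = barWrap ? newPhase : phase; // phase flips only on wrap
    }
    // With a dist-1 wait, hold phase/index back until the predicate is true.
    phase = pred ? newPhase : phase;
    barrierIdx = pred ? newBarrierIdx : barrierIdx;
    pred = true; // the mma has now executed at least once
    std::printf("after iter %d: phase=%d barrierIdx=%d\n", i, phase, barrierIdx);
  }
  return 0;
}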