Skip to content

Commit 9735dd4

Browse files
Merge commit '8de17d2525531c9cb9690405add1ed615c0410cd'
2 parents 3ba7d1e + 8de17d2 commit 9735dd4

File tree

13 files changed

+367
-287
lines changed

13 files changed

+367
-287
lines changed

include/triton/Dialect/TritonGPU/Transforms/Schedule.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,6 @@ void lowerLoops(ModuleOp moduleOp);
2828
/// Pipeline the TMA stores in the loop.
2929
bool pipelineTMAStores(scf::ForOp forOp);
3030

31-
/// Simple pipelining for the MMA ops whose accumulator is modified in the loop.
32-
scf::ForOp pipelineMMAWithScaledAcc(scf::ForOp forOp);
33-
3431
/// This does post-processing on the pipelined loop to try to pipeline wgmma
3532
/// ops.
3633
// TODO: this should be included as part of the pipeline but currently the wgmma
@@ -75,6 +72,8 @@ class CoarseSchedule {
7572
}
7673

7774
bool isBefore(iterator a, iterator b) const {
75+
if (a == b)
76+
return false;
7877
for (auto it = begin(); it != end(); ++it) {
7978
if (it == a)
8079
return true;

lib/Dialect/TritonGPU/Transforms/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ add_triton_library(TritonGPUTransforms
2222
Pipeliner/SoftwarePipeliner.cpp
2323
Pipeliner/TMAStoresPipeline.cpp
2424
Pipeliner/MMAv5PipelineUtility.cpp
25-
Pipeliner/ModifiedAccMMAPipeline.cpp
2625
Pipeliner/Partition.cpp
2726
Pipeliner/PipeliningUtility.cpp
2827
Pipeliner/Schedule.cpp

lib/Dialect/TritonGPU/Transforms/HoistTMEMAlloc.cpp

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,53 @@ class RotateTMEMStoreInLoop : public OpRewritePattern<ttng::TMEMStoreOp> {
229229
}
230230
};
231231

232+
class RotateTMEMLoadInLoop : public OpRewritePattern<ttng::TMEMLoadOp> {
233+
public:
234+
using OpRewritePattern::OpRewritePattern;
235+
236+
LogicalResult matchAndRewrite(ttng::TMEMLoadOp load,
237+
PatternRewriter &rewriter) const override {
238+
scf::ForOp forOp = dyn_cast<scf::ForOp>(load->getParentOp());
239+
if (!forOp) {
240+
return failure();
241+
}
242+
if (!forOp.isDefinedOutsideOfLoop(load.getSrc())) {
243+
return failure();
244+
}
245+
if (!load.getResult().hasOneUse()) {
246+
return failure();
247+
}
248+
OpOperand &use = *load.getResult().getUses().begin();
249+
auto yield = dyn_cast<scf::YieldOp>(use.getOwner());
250+
if (!yield) {
251+
return failure();
252+
}
253+
// Store the initial value into the tmem buffer before the loop, then move
254+
// the load to the top of the loop body so uses of the loop-carried value
255+
// read from tmem instead.
256+
int argNo = use.getOperandNumber();
257+
Value initVal = forOp.getInitArgs()[argNo];
258+
rewriter.setInsertionPoint(forOp);
259+
auto vTrue = rewriter.create<arith::ConstantIntOp>(load.getLoc(), 1, 1);
260+
rewriter.create<ttng::TMEMStoreOp>(load.getLoc(), load.getSrc(), initVal,
261+
vTrue);
262+
auto attributes = load->getAttrDictionary();
263+
rewriter.moveOpBefore(load, &forOp.getBody()->front());
264+
forOp.getRegionIterArg(argNo).replaceAllUsesWith(load.getResult());
265+
load->setAttrs(attributes);
266+
267+
// Load from the tmem after the loop, and use it instead of the loop carried
268+
// value.
269+
rewriter.setInsertionPointAfter(forOp);
270+
auto loadAfterLoop = rewriter.create<ttng::TMEMLoadOp>(
271+
load.getLoc(), load.getResult().getType(), load.getSrc());
272+
forOp->getResult(argNo).replaceAllUsesWith(loadAfterLoop.getResult());
273+
// Loop carried value is no longer used, short-circuit it.
274+
yield.setOperand(argNo, forOp.getRegionIterArg(argNo));
275+
return success();
276+
}
277+
};
278+
232279
ttng::TMEMAllocOp hoistTMEMAlloc(ttng::TMEMAllocOp alloc, scf::ForOp forOp) {
233280
OpBuilder builder(alloc);
234281
builder.setInsertionPoint(forOp);
@@ -300,8 +347,10 @@ struct HoistTMEMAlloc
300347
}
301348

302349
mlir::RewritePatternSet patterns(&getContext());
303-
patterns.add<RotateTMEMStoreInLoop, CombineTMEMLoadAndStore,
304-
CombineTMEMStoreAndSelect, SinkTMEMLoad>(&getContext());
350+
patterns
351+
.add<RotateTMEMStoreInLoop, RotateTMEMLoadInLoop,
352+
CombineTMEMLoadAndStore, CombineTMEMStoreAndSelect, SinkTMEMLoad>(
353+
&getContext());
305354
scf::ForOp::getCanonicalizationPatterns(patterns, &getContext());
306355
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
307356
llvm_unreachable("Failed to hoist tmem_store");

lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp

Lines changed: 132 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,9 @@ Operation *getFirstUseOfPipelinedOp(SmallVector<Operation *> ops,
118118
auto [_firstUserStage, _firstUserCluster] = schedule[firstUser];
119119
if (_useStage < _firstUserStage ||
120120
(_useStage == _firstUserStage &&
121-
schedule.clusters.isBefore(_useCluster, _firstUserCluster))) {
121+
schedule.clusters.isBefore(_useCluster, _firstUserCluster)) ||
122+
(_useStage == _firstUserStage && _useCluster == _firstUserCluster &&
123+
topLevelUser->isBeforeInBlock(firstUser))) {
122124
firstUser = topLevelUser;
123125
}
124126
}
@@ -214,6 +216,8 @@ Value createIncrementModulo(BuilderT &builder, Location loc, Value counter,
214216

215217
void replaceAllUsesDominatedBy(Operation *domOp, Value newValue, Value oldValue,
216218
DominanceInfo &domInfo) {
219+
if (newValue == oldValue)
220+
return;
217221
oldValue.replaceUsesWithIf(newValue, [&](OpOperand &use) {
218222
return domInfo.properlyDominates(domOp, use.getOwner());
219223
});
@@ -232,15 +236,6 @@ static Value createAlloc(scf::ForOp &forOp, Operation *loadOp,
232236
loadOp->getLoc(), sharedEnc, distance);
233237
}
234238

235-
template <typename BuilderT, typename... Args>
236-
Operation *createWithStage(BuilderT &builder, Location loc, int stage,
237-
CoarseSchedule::Cluster cluster, Args &&...args) {
238-
Operation *op = builder.template create<ttg::AsyncCopyGlobalToLocalOp>(
239-
loc, std::forward<Args>(args)...);
240-
241-
return op;
242-
}
243-
244239
void createAsyncCopy(scf::ForOp forOp, tt::LoadOp loadOp, Value alloc,
245240
Value insertIdx, Value extractIdx,
246241
CoarseSchedule &schedule) {
@@ -874,12 +869,36 @@ std::pair<int, int> getTmemUseStageBounds(ttng::TMEMAllocOp alloc,
874869
return bounds;
875870
}
876871

877-
void createBarrierAndWaitOps(scf::ForOp forOp, CoarseSchedule &schedule,
878-
ttng::MMAv5OpInterface mma,
879-
ttng::TMEMAllocOp alloc, Value &phase,
880-
Value &barrierIdx, int numStages) {
872+
// Create a predicate argument for the dist-1 wait
873+
scf::ForOp prepLoopForDist1Wait(scf::ForOp forOp, CoarseSchedule &schedule,
874+
ttng::MMAv5OpInterface mma) {
875+
OpBuilderForStage builder(forOp, schedule);
876+
Location loc = mma.getLoc();
877+
Value vFalse = builder.create<arith::ConstantIntOp>(loc, 0, 1);
878+
879+
// Create a predicate for the wait (start with false and change to true on the
880+
// first mma execution)
881+
scf::ForOp newForOp = replaceForOpWithNewSignature(builder, forOp, {vFalse});
882+
forOp.erase();
883+
forOp = newForOp;
884+
885+
builder.setInsertionPointAfter(mma);
886+
builder.setStageCluster(schedule[mma]);
887+
Value vTrue = builder.create<arith::ConstantIntOp>(loc, 1, 1);
888+
889+
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
890+
yieldOp.getResultsMutable().append(vTrue); // predicate
891+
return forOp;
892+
}
893+
894+
scf::ForOp createBarrierAndWaitOps(scf::ForOp forOp, CoarseSchedule &schedule,
895+
ttng::MMAv5OpInterface mma,
896+
ttng::TMEMAllocOp alloc, int phaseArgIdx,
897+
int barrierIdxArgIdx, int numStages) {
881898
OpBuilderForStage builder(forOp, schedule);
882899
DominanceInfo domInfo(forOp);
900+
Value phase = forOp.getRegionIterArg(phaseArgIdx);
901+
Value barrierIdx = forOp.getRegionIterArg(barrierIdxArgIdx);
883902
Value zero = builder.create<arith::ConstantIntOp>(forOp.getLoc(), 0, 32);
884903
Value one = builder.create<arith::ConstantIntOp>(forOp.getLoc(), 1, 32);
885904
Value numStagesVal =
@@ -889,8 +908,11 @@ void createBarrierAndWaitOps(scf::ForOp forOp, CoarseSchedule &schedule,
889908
builder.setStageCluster(schedule[mma]);
890909
builder.setInsertionPoint(mma);
891910
Location loc = mma->getLoc();
892-
Value barrierSlice =
893-
triton::createSingleBufferView(builder, barrierAlloc, barrierIdx);
911+
Value barrierSlice = barrierAlloc;
912+
if (numStages > 1) {
913+
barrierSlice =
914+
triton::createSingleBufferView(builder, barrierAlloc, barrierIdx);
915+
}
894916
mma.setBarrier(barrierSlice);
895917

896918
// List of buffers that may be used until wait completes
@@ -953,27 +975,85 @@ void createBarrierAndWaitOps(scf::ForOp forOp, CoarseSchedule &schedule,
953975
}
954976
}
955977

978+
std::optional<int> predicateArgIdx;
956979
for (auto wp : waitPoints) {
957980
builder.setStageCluster({wp.stage, wp.cluster});
958981
builder.setInsertionPoint(wp.op);
959-
builder.create<ttng::WaitBarrierOp>(loc, barrierSlice, phase, waitBuffers);
982+
SmallVector<Value> currWaitBuffers = waitBuffers;
983+
Value pred = nullptr;
984+
if (!domInfo.properlyDominates(mma, wp.op)) {
985+
// Waits before the mma are special: they are dist-1 to the mma which
986+
// means two things:
987+
// 1. The wait should not be executed before the mma executes for the
988+
// first time
989+
// 2. The waitBuffers do not dominate this wait. We either need to hoist
990+
// the local allocs out of the loop, or have dist-1 buffer references.
991+
if (!predicateArgIdx) {
992+
forOp = prepLoopForDist1Wait(forOp, schedule, mma);
993+
predicateArgIdx = forOp.getInitArgs().size() - 1;
994+
// HACK: Clear the wait buffers. This is meant to avoid having to have
995+
// dist-1 buffer references, or hoisting the local allocs out of the
996+
// loop. This works as long as the wait is in the same stage as the mma,
997+
// and is the reason for which we prevent pipelining the mma if there
998+
// are loads from the accumulator before the mma in a different stage.
999+
currWaitBuffers.clear();
1000+
}
1001+
pred = forOp.getRegionIterArg(*predicateArgIdx);
1002+
}
1003+
Value barrierSlice = barrierAlloc;
1004+
if (numStages > 1) {
1005+
barrierSlice =
1006+
triton::createSingleBufferView(builder, barrierAlloc, barrierIdx);
1007+
}
1008+
builder.create<ttng::WaitBarrierOp>(loc, barrierSlice, phase, pred,
1009+
currWaitBuffers);
9601010
}
9611011

9621012
builder.setStageCluster(schedule[mma]);
9631013
builder.setInsertionPoint(forOp.getBody()->getTerminator());
964-
Value barWrap;
965-
barrierIdx = createIncrementModulo(builder, loc, barrierIdx, numStagesVal,
966-
zero, one, &barWrap);
967-
phase = builder.create<arith::SelectOp>(
968-
loc, phase.getType(), barWrap,
969-
builder.create<arith::XOrIOp>(loc, phase, one), phase);
1014+
Value newPhase = builder.create<arith::XOrIOp>(loc, phase, one);
1015+
Value newBarrierIdx = barrierIdx;
1016+
if (numStages > 1) {
1017+
Value barWrap;
1018+
newBarrierIdx = createIncrementModulo(builder, loc, barrierIdx,
1019+
numStagesVal, zero, one, &barWrap);
1020+
newPhase = builder.create<arith::SelectOp>(loc, phase.getType(), barWrap,
1021+
newPhase, phase);
1022+
}
1023+
if (predicateArgIdx) {
1024+
// If there is a wait predicate, we need to select the phase and barrierIdx
1025+
// based on the predicate
1026+
Value pred = forOp.getRegionIterArg(*predicateArgIdx);
1027+
newPhase = builder.create<arith::SelectOp>(loc, phase.getType(), pred,
1028+
newPhase, phase);
1029+
newBarrierIdx = builder.create<arith::SelectOp>(loc, phase.getType(), pred,
1030+
newBarrierIdx, barrierIdx);
1031+
}
1032+
replaceAllUsesDominatedBy(newPhase.getDefiningOp(), newPhase, phase, domInfo);
1033+
replaceAllUsesDominatedBy(newBarrierIdx.getDefiningOp(), newBarrierIdx,
1034+
barrierIdx, domInfo);
1035+
1036+
// If there is a dist-1 wait, we need to add a wait after the loop
1037+
if (predicateArgIdx) {
1038+
builder.setInsertionPointAfter(forOp);
1039+
Value barrierSlice = barrierAlloc;
1040+
Value barrierIdx = forOp.getResult(barrierIdxArgIdx);
1041+
Value phase = forOp.getResult(phaseArgIdx);
1042+
if (numStages > 1) {
1043+
barrierSlice =
1044+
triton::createSingleBufferView(builder, barrierAlloc, barrierIdx);
1045+
}
1046+
builder.create<ttng::WaitBarrierOp>(loc, barrierSlice, phase,
1047+
forOp.getResult(*predicateArgIdx));
1048+
}
1049+
return forOp;
9701050
}
9711051

9721052
void multibufferTensorMemory(scf::ForOp forOp, CoarseSchedule &schedule,
973-
ttng::TMEMAllocOp alloc, Value &bufIdx,
974-
int bufIdxArgIdx, int tmemUseNumStages) {
1053+
ttng::TMEMAllocOp alloc, int bufIdxArgIdx,
1054+
int tmemUseNumStages) {
9751055
DominanceInfo domInfo(forOp);
976-
1056+
Value bufIdx = forOp.getRegionIterArg(bufIdxArgIdx);
9771057
SmallVector<std::pair<Operation *, Value>> bufIdxDefs;
9781058
auto getCurrBufIdx = [&](Operation *op) {
9791059
for (auto [_op, _val] : llvm::reverse(bufIdxDefs)) {
@@ -1077,17 +1157,16 @@ void multibufferTensorMemory(scf::ForOp forOp, CoarseSchedule &schedule,
10771157
"accumulator, and the mma uses the accumulator all the time.");
10781158
}
10791159
alloc->erase();
1080-
bufIdx = bufIdxDefs.back().second;
1160+
Value newBufIdx = bufIdxDefs.back().second;
1161+
replaceAllUsesDominatedBy(newBufIdx.getDefiningOp(), newBufIdx, bufIdx,
1162+
domInfo);
10811163
}
10821164

10831165
scf::ForOp lowerMMA(ttng::MMAv5OpInterface mma, scf::ForOp forOp,
10841166
CoarseSchedule &schedule) {
10851167
auto isLoadPipelineable = [&](Operation *op) {
10861168
return schedule[mma].first > schedule[op].first;
10871169
};
1088-
if (!mmaHasPipelineableOperands(mma, forOp, isLoadPipelineable)) {
1089-
return forOp;
1090-
}
10911170
auto alloc = mma.getAccumulator().getDefiningOp<ttng::TMEMAllocOp>();
10921171
if (!alloc) {
10931172
return forOp;
@@ -1099,7 +1178,8 @@ scf::ForOp lowerMMA(ttng::MMAv5OpInterface mma, scf::ForOp forOp,
10991178
int tmemUseNumStages =
11001179
tmemUseStageBounds.second - tmemUseStageBounds.first + 1;
11011180
int waitNumStages = tmemUseStageBounds.second - schedule[mma].first + 1;
1102-
if (waitNumStages == 1 && !hasAccReadModifyWrite(mma, forOp)) {
1181+
if (waitNumStages == 1 && !hasAccReadModifyWrite(mma, forOp) &&
1182+
mmaHasPipelineableOperands(mma, forOp, isLoadPipelineable)) {
11031183
// Overlap the mma with itself, even if there is no use of the accumulator
11041184
// after the mma
11051185
waitNumStages = 2;
@@ -1112,35 +1192,40 @@ scf::ForOp lowerMMA(ttng::MMAv5OpInterface mma, scf::ForOp forOp,
11121192
// Add arguments to the forOp
11131193
unsigned newOperandIndex = forOp.getInitArgs().size();
11141194
SmallVector<Value> newOperands = {
1115-
zero, // phase
1116-
zero, // barrierIdx
1117-
minusOne, // bufIdx
1195+
zero, // phase
1196+
zero, // barrierIdx
11181197
};
1198+
if (tmemUseNumStages > 1) {
1199+
newOperands.push_back(minusOne); // bufIdx
1200+
}
11191201
scf::ForOp newForOp =
11201202
replaceForOpWithNewSignature(builder, forOp, newOperands);
11211203
forOp.erase();
11221204
forOp = newForOp;
11231205

1124-
Value phase = forOp.getRegionIterArg(newOperandIndex + 0);
1125-
Value barrierIdx = forOp.getRegionIterArg(newOperandIndex + 1);
1126-
Value bufIdx = forOp.getRegionIterArg(newOperandIndex + 2);
1206+
int phaseArgIdx = newOperandIndex + 0;
1207+
int barrierIdxArgIdx = newOperandIndex + 1;
1208+
int bufIdxArgIdx = newOperandIndex + 2;
1209+
Value phase = forOp.getRegionIterArg(phaseArgIdx);
1210+
Value barrierIdx = forOp.getRegionIterArg(barrierIdxArgIdx);
11271211

1128-
if (waitNumStages > 1) {
1129-
createBarrierAndWaitOps(forOp, schedule, mma, alloc, phase, barrierIdx,
1130-
waitNumStages);
1212+
SmallVector<Value> newYieldOperands = {phase, barrierIdx};
1213+
if (tmemUseNumStages > 1) {
1214+
Value bufIdx = forOp.getRegionIterArg(bufIdxArgIdx);
1215+
newYieldOperands.push_back(bufIdx);
11311216
}
1217+
cast<scf::YieldOp>(forOp.getBody()->getTerminator())
1218+
.getResultsMutable()
1219+
.append(newYieldOperands);
1220+
1221+
forOp = createBarrierAndWaitOps(forOp, schedule, mma, alloc, phaseArgIdx,
1222+
barrierIdxArgIdx, waitNumStages);
11321223

11331224
if (tmemUseNumStages > 1) {
1134-
multibufferTensorMemory(forOp, schedule, alloc, bufIdx, newOperandIndex + 2,
1225+
multibufferTensorMemory(forOp, schedule, alloc, bufIdxArgIdx,
11351226
tmemUseNumStages);
11361227
}
11371228

1138-
SmallVector<Value> newYieldOperands;
1139-
newYieldOperands.push_back(phase);
1140-
newYieldOperands.push_back(barrierIdx);
1141-
newYieldOperands.push_back(bufIdx);
1142-
appendToForOpYield(forOp, newYieldOperands);
1143-
11441229
return forOp;
11451230
}
11461231

0 commit comments

Comments
 (0)