Skip to content

Commit 74d7e4d

Browse files
Mogball and ptillet authored
[Warp Specialization] Miscellaneous fixes to MMA pipelining (#6838)
Fix up the logic in `pipelineMMA` to handle a few more edge cases.

Co-authored-by: Philippe Tillet <[email protected]>
1 parent 54f6d57 commit 74d7e4d

File tree

3 files changed

+188
-48
lines changed

3 files changed

+188
-48
lines changed

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/LoadMMASpecialization.cpp

Lines changed: 65 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "mlir/IR/Dominance.h"
55
#include "mlir/IR/ImplicitLocOpBuilder.h"
66
#include "mlir/Pass/Pass.h"
7+
#include "triton/Analysis/Utility.h"
78
#include "triton/Dialect/Triton/IR/Dialect.h"
89
#include "triton/Dialect/Triton/IR/Utility.h"
910
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -119,7 +120,8 @@ static PartitionScheme getPartitionScheme(scf::ForOp loop) {
119120
// the MMA partition.
120121
auto storeOp = dyn_cast_or_null<ttng::TMEMStoreOp>(
121122
findDefOpInLoop(loop, mmaOp.getAccDep()));
122-
if (!ttng::hasAccReadModifyWrite(mmaOp, loop) && storeOp)
123+
if (!ttng::hasAccReadModifyWrite(mmaOp, loop) && storeOp &&
124+
loop.isDefinedOutsideOfLoop(storeOp.getSrc()))
123125
mma.storeOp = storeOp;
124126

125127
// Look for views into the operands.
@@ -619,9 +621,8 @@ addIndexAndPhase(PartitionBuilder &b, scf::ForOp &loop, unsigned numStages,
619621
return {index, phase};
620622
}
621623

622-
static std::pair<Value, Operation *>
623-
getUserPrecondition(ImplicitLocOpBuilder &b, scf::ForOp loop, Operation *domOp,
624-
Value initialValue = {}) {
624+
static Value getUserPrecondition(ImplicitLocOpBuilder &b, scf::ForOp loop,
625+
Operation *domOp) {
625626
// If the use is inside a loop besides the actual loop being pipelined, we
626627
// have to hoist the use up to that loop, otherwise the barriers will be
627628
// inserted in the loop.
@@ -630,12 +631,13 @@ getUserPrecondition(ImplicitLocOpBuilder &b, scf::ForOp loop, Operation *domOp,
630631
domOp = userLoop;
631632
assert(loop->isProperAncestor(domOp));
632633

633-
Value trueVal = b.create<arith::ConstantOp>(b.getBoolAttr(true));
634634
OpBuilder::InsertionGuard guard(b);
635-
b.setInsertionPoint(loop.getBody()->findAncestorOpInBlock(*domOp));
635+
b.setInsertionPoint(loop);
636+
Value trueVal = b.create<arith::ConstantOp>(b.getBoolAttr(true));
636637

637-
Value precondition = initialValue ? initialValue : trueVal;
638+
Value userPred = trueVal;
638639
Operation *parentOp = domOp;
640+
b.setInsertionPoint(loop.getBody()->findAncestorOpInBlock(*domOp));
639641
while (loop != (parentOp = parentOp->getParentOp())) {
640642
assert(!isa<LoopLikeOpInterface>(parentOp));
641643
auto ifOp = dyn_cast<scf::IfOp>(parentOp);
@@ -646,10 +648,10 @@ getUserPrecondition(ImplicitLocOpBuilder &b, scf::ForOp loop, Operation *domOp,
646648
Value cond = ifOp.getCondition();
647649
if (domOp->getParentRegion() == &ifOp.getElseRegion())
648650
cond = b.create<arith::XOrIOp>(cond, trueVal);
649-
precondition = b.create<arith::AndIOp>(precondition, cond);
651+
userPred = b.create<arith::AndIOp>(userPred, cond);
650652
}
651653

652-
return {precondition, domOp};
654+
return userPred;
653655
}
654656

655657
static MemDescType getAsMutable(MemDescType type) {
@@ -1039,7 +1041,6 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
10391041

10401042
struct Node {
10411043
Operation *op;
1042-
Partition *partition;
10431044
Value barPrev;
10441045
Value barNext;
10451046
Value index;
@@ -1061,18 +1062,15 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
10611062
}
10621063
}
10631064

1064-
Value firstBar;
1065-
for (int i = nodes.size(); i > 0; --i) {
1066-
if ((firstBar = nodes[i % nodes.size()].barPrev))
1067-
break;
1068-
}
1069-
if (firstBar) {
1065+
// If the first node has a barrier, fully initialize it to let it run.
1066+
if (nodes.front().barPrev) {
10701067
for (auto i : llvm::seq(numMmaStages)) {
10711068
b.setInsertionPoint(loop);
1072-
Value bar = createSingleBufferView(b, firstBar, i);
1069+
Value bar = createSingleBufferView(b, nodes.front().barPrev, i);
10731070
b.create<ttng::ArriveBarrierOp>(bar, /*arriveCount=*/1);
10741071
}
10751072
}
1073+
10761074
Value userPred = b.boolCst(true);
10771075
if (readOp == mmaOp) {
10781076
PartitionBuilder b(mmaOp.getLoc(), mmaOp);
@@ -1087,14 +1085,12 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
10871085
Value replTok = b.create<ub::PoisonOp>(b.getType<AsyncTokenType>());
10881086
DenseSet<Operation *> seen;
10891087
std::optional<OpBuilder::InsertPoint> incrementPt;
1088+
Node *firstAfterInc = nullptr;
10901089
for (Node &node : nodes) {
10911090
node.index = curIndex;
10921091
node.phase = curPhase;
1093-
if (incrementPt && node.barPrev && node.barPrev != firstBar) {
1094-
b.setInsertionPoint(loop);
1095-
b.create<ttng::ArriveBarrierOp>(
1096-
createSingleBufferView(b, node.barPrev, 0), /*arriveCount=*/1);
1097-
}
1092+
if (incrementPt && node.barPrev && !firstAfterInc)
1093+
firstAfterInc = &node;
10981094
if (!seen.insert(node.op).second)
10991095
continue;
11001096
b.setInsertionPoint(node.op);
@@ -1115,7 +1111,7 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
11151111
}
11161112
if (node.op == dyn_cast<ttng::TMEMLoadOp>(readOp)) {
11171113
ImplicitLocOpBuilder b(readOp->getLoc(), loop);
1118-
userPred = getUserPrecondition(b, loop, node.op).first;
1114+
userPred = getUserPrecondition(b, loop, node.op);
11191115
b.setInsertionPointAfter(inBody(readOp));
11201116
auto [nextIndex, nextPhase] =
11211117
postIncrementModulo(b, index, phase, numMmaStages);
@@ -1124,28 +1120,57 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
11241120
incrementPt = b.saveInsertionPoint();
11251121
}
11261122
}
1123+
if (firstAfterInc) {
1124+
b.setInsertionPoint(loop);
1125+
if (firstAfterInc->op == mmaOp) {
1126+
Value firstBar = createSingleBufferView(b, firstAfterInc->barPrev, 0);
1127+
b.create<ttng::ArriveBarrierOp>(firstBar, /*arriveCount=*/1);
1128+
} else {
1129+
assert(firstAfterInc->op == dyn_cast<ttng::TMEMStoreOp>(overwriteOp));
1130+
for (auto i : llvm::seq(numMmaStages)) {
1131+
Value firstBar = createSingleBufferView(b, firstAfterInc->barPrev, i);
1132+
b.create<ttng::ArriveBarrierOp>(firstBar, /*arriveCount=*/1);
1133+
}
1134+
}
1135+
}
11271136
oldAllocOp.getToken().replaceAllUsesWith(allocOp.getToken());
11281137
oldAllocOp.erase();
11291138
cast<scf::YieldOp>(loop.getBody()->getTerminator())
11301139
.getResultsMutable()
11311140
.append({curIndex, curPhase});
11321141

11331142
// Find operands that need to be pipelined through shmem.
1143+
SmallVector<Value> incomingOperands;
1144+
llvm::append_range(incomingOperands, mmaOp->getOperands());
11341145
SmallVector<std::pair<Operation *, Partition *>> operandDefs;
1135-
for (Value operand : mma.mmaOp->getOperands()) {
1146+
while (!incomingOperands.empty()) {
1147+
Value operand = incomingOperands.pop_back_val();
1148+
if (!isa<MemDescType>(operand.getType()))
1149+
continue;
11361150
Operation *defOp = operand.getDefiningOp();
1137-
if (!defOp || !loop.getBodyRegion().isAncestor(defOp->getParentRegion()))
1151+
if (!defOp || loop.isDefinedOutsideOfLoop(operand))
11381152
continue;
11391153
defOp = inBody(defOp);
11401154
Partition *defPartition = schedule.getPartition(defOp);
1141-
if (!defPartition)
1155+
1156+
if (!defPartition) {
1157+
// If the MMA operand is coming from outside the loop, move the alloc out.
1158+
auto allocOp = dyn_cast<LocalAllocOp>(defOp);
1159+
if (allocOp && loop.isDefinedOutsideOfLoop(allocOp.getSrc()))
1160+
allocOp->moveBefore(loop);
11421161
continue;
1162+
}
1163+
11431164
if (auto allocOp = operand.getDefiningOp<LocalAllocOp>()) {
11441165
PartitionBuilder b(allocOp.getLoc(), allocOp);
11451166
auto store = b.createInto<LocalStoreOp>(*defPartition, std::nullopt,
11461167
allocOp.getSrc(), allocOp);
1168+
auto fence = b.createInto<ttng::FenceAsyncSharedOp>(
1169+
*defPartition, std::nullopt, /*bCluster=*/false);
11471170
operandDefs.emplace_back(body.findAncestorOpInBlock(*store),
11481171
defPartition);
1172+
operandDefs.emplace_back(body.findAncestorOpInBlock(*fence),
1173+
defPartition);
11491174
allocOp->moveBefore(loop);
11501175
allocOp->removeAttr(kPartitionAttrName);
11511176
allocOp.getSrcMutable().clear();
@@ -1161,6 +1186,8 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
11611186
tmemAllocOp->removeAttr(kPartitionAttrName);
11621187
tmemAllocOp.getSrcMutable().clear();
11631188
tmemAllocOp.getResult().setType(getAsMutable(tmemAllocOp.getType()));
1189+
} else if (defOp->hasTrait<OpTrait::MemDescViewTrait>()) {
1190+
incomingOperands.push_back(defOp->getOperand(0));
11641191
}
11651192
}
11661193

@@ -1187,13 +1214,20 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
11871214

11881215
if (node.barPrev) {
11891216
if (!isa<ttng::TMEMLoadOp>(node.op)) {
1190-
if (incrementPt && domOp->isBeforeInBlock(&*incrementPt->getPoint()))
1217+
// If the user precondition is defined after the MMA, we need to peel
1218+
// the wait for the user.
1219+
if (incrementPt && domOp->isBeforeInBlock(&*incrementPt->getPoint()) &&
1220+
domInfo.properlyDominates(mmaOp, userPred.getDefiningOp())) {
11911221
b.restoreInsertionPoint(*incrementPt);
1192-
else
1222+
Value bar = createSingleBufferView(b, node.barPrev, curIndex);
1223+
b.createInto<ttng::WaitBarrierOp>(*partition, stages.lookup(node.op),
1224+
bar, curPhase, userPred);
1225+
} else {
11931226
b.setInsertionPoint(domOp);
1194-
Value bar = createSingleBufferView(b, node.barPrev, curIndex);
1195-
b.createInto<ttng::WaitBarrierOp>(*partition, stages.lookup(node.op),
1196-
bar, curPhase, userPred);
1227+
Value bar = createSingleBufferView(b, node.barPrev, node.index);
1228+
b.createInto<ttng::WaitBarrierOp>(*partition, stages.lookup(node.op),
1229+
bar, node.phase, userPred);
1230+
}
11971231
} else {
11981232
b.setInsertionPoint(domOp);
11991233
if (isa<scf::IfOp>(domOp->getParentOp()))

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/RewritePartitionDependencies.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,8 @@ void DependencyRewriter::resolveOutputMultiplicity(
321321
AsyncRef DependencyRewriter::allocateAsyncValue(RankedTensorType tensorType,
322322
unsigned multiplicitySize,
323323
unsigned maxDistance) {
324+
OpBuilder::InsertionGuard guard(b);
325+
b.setInsertionPoint(loop);
324326
unsigned numBars = multiplicitySize + maxDistance;
325327
Value alloc = createAlloc(loop, tensorType, b.getLoc(),
326328
getSharedEncoding(tensorType), numBars);
@@ -340,6 +342,9 @@ AsyncRef DependencyRewriter::allocateAsyncValue(RankedTensorType tensorType,
340342
// for the buffer, store it and mark the buffer as ready to be consumed.
341343
void DependencyRewriter::initializeBarriers(int index, const AsyncRef &aref,
342344
unsigned numConsumers, Value init) {
345+
OpBuilder::InsertionGuard guard(b);
346+
b.setInsertionPoint(loop);
347+
343348
Value idx = intCst(index);
344349
if (init) {
345350
Value view = aref.getValueView(b, idx);

0 commit comments

Comments
 (0)