4
4
#include " mlir/IR/Dominance.h"
5
5
#include " mlir/IR/ImplicitLocOpBuilder.h"
6
6
#include " mlir/Pass/Pass.h"
7
+ #include " triton/Analysis/Utility.h"
7
8
#include " triton/Dialect/Triton/IR/Dialect.h"
8
9
#include " triton/Dialect/Triton/IR/Utility.h"
9
10
#include " triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -119,7 +120,8 @@ static PartitionScheme getPartitionScheme(scf::ForOp loop) {
119
120
// the MMA partition.
120
121
auto storeOp = dyn_cast_or_null<ttng::TMEMStoreOp>(
121
122
findDefOpInLoop (loop, mmaOp.getAccDep ()));
122
- if (!ttng::hasAccReadModifyWrite (mmaOp, loop) && storeOp)
123
+ if (!ttng::hasAccReadModifyWrite (mmaOp, loop) && storeOp &&
124
+ loop.isDefinedOutsideOfLoop (storeOp.getSrc ()))
123
125
mma.storeOp = storeOp;
124
126
125
127
// Look for views into the operands.
@@ -619,9 +621,8 @@ addIndexAndPhase(PartitionBuilder &b, scf::ForOp &loop, unsigned numStages,
619
621
return {index, phase};
620
622
}
621
623
622
- static std::pair<Value, Operation *>
623
- getUserPrecondition (ImplicitLocOpBuilder &b, scf::ForOp loop, Operation *domOp,
624
- Value initialValue = {}) {
624
+ static Value getUserPrecondition (ImplicitLocOpBuilder &b, scf::ForOp loop,
625
+ Operation *domOp) {
625
626
// If the use is inside a loop besides the actual loop being pipelined, we
626
627
// have to hoist the use up to that loop, otherwise the barriers will be
627
628
// inserted in the loop.
@@ -630,12 +631,13 @@ getUserPrecondition(ImplicitLocOpBuilder &b, scf::ForOp loop, Operation *domOp,
630
631
domOp = userLoop;
631
632
assert (loop->isProperAncestor (domOp));
632
633
633
- Value trueVal = b.create <arith::ConstantOp>(b.getBoolAttr (true ));
634
634
OpBuilder::InsertionGuard guard (b);
635
- b.setInsertionPoint (loop.getBody ()->findAncestorOpInBlock (*domOp));
635
+ b.setInsertionPoint (loop);
636
+ Value trueVal = b.create <arith::ConstantOp>(b.getBoolAttr (true ));
636
637
637
- Value precondition = initialValue ? initialValue : trueVal;
638
+ Value userPred = trueVal;
638
639
Operation *parentOp = domOp;
640
+ b.setInsertionPoint (loop.getBody ()->findAncestorOpInBlock (*domOp));
639
641
while (loop != (parentOp = parentOp->getParentOp ())) {
640
642
assert (!isa<LoopLikeOpInterface>(parentOp));
641
643
auto ifOp = dyn_cast<scf::IfOp>(parentOp);
@@ -646,10 +648,10 @@ getUserPrecondition(ImplicitLocOpBuilder &b, scf::ForOp loop, Operation *domOp,
646
648
Value cond = ifOp.getCondition ();
647
649
if (domOp->getParentRegion () == &ifOp.getElseRegion ())
648
650
cond = b.create <arith::XOrIOp>(cond, trueVal);
649
- precondition = b.create <arith::AndIOp>(precondition , cond);
651
+ userPred = b.create <arith::AndIOp>(userPred , cond);
650
652
}
651
653
652
- return {precondition, domOp} ;
654
+ return userPred ;
653
655
}
654
656
655
657
static MemDescType getAsMutable (MemDescType type) {
@@ -1039,7 +1041,6 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
1039
1041
1040
1042
struct Node {
1041
1043
Operation *op;
1042
- Partition *partition;
1043
1044
Value barPrev;
1044
1045
Value barNext;
1045
1046
Value index;
@@ -1061,18 +1062,15 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
1061
1062
}
1062
1063
}
1063
1064
1064
- Value firstBar;
1065
- for (int i = nodes.size (); i > 0 ; --i) {
1066
- if ((firstBar = nodes[i % nodes.size ()].barPrev ))
1067
- break ;
1068
- }
1069
- if (firstBar) {
1065
+ // If the first node has a barrier, fully initialize it to let it run.
1066
+ if (nodes.front ().barPrev ) {
1070
1067
for (auto i : llvm::seq (numMmaStages)) {
1071
1068
b.setInsertionPoint (loop);
1072
- Value bar = createSingleBufferView (b, firstBar , i);
1069
+ Value bar = createSingleBufferView (b, nodes. front (). barPrev , i);
1073
1070
b.create <ttng::ArriveBarrierOp>(bar, /* arriveCount=*/ 1 );
1074
1071
}
1075
1072
}
1073
+
1076
1074
Value userPred = b.boolCst (true );
1077
1075
if (readOp == mmaOp) {
1078
1076
PartitionBuilder b (mmaOp.getLoc (), mmaOp);
@@ -1087,14 +1085,12 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
1087
1085
Value replTok = b.create <ub::PoisonOp>(b.getType <AsyncTokenType>());
1088
1086
DenseSet<Operation *> seen;
1089
1087
std::optional<OpBuilder::InsertPoint> incrementPt;
1088
+ Node *firstAfterInc = nullptr ;
1090
1089
for (Node &node : nodes) {
1091
1090
node.index = curIndex;
1092
1091
node.phase = curPhase;
1093
- if (incrementPt && node.barPrev && node.barPrev != firstBar) {
1094
- b.setInsertionPoint (loop);
1095
- b.create <ttng::ArriveBarrierOp>(
1096
- createSingleBufferView (b, node.barPrev , 0 ), /* arriveCount=*/ 1 );
1097
- }
1092
+ if (incrementPt && node.barPrev && !firstAfterInc)
1093
+ firstAfterInc = &node;
1098
1094
if (!seen.insert (node.op ).second )
1099
1095
continue ;
1100
1096
b.setInsertionPoint (node.op );
@@ -1115,7 +1111,7 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
1115
1111
}
1116
1112
if (node.op == dyn_cast<ttng::TMEMLoadOp>(readOp)) {
1117
1113
ImplicitLocOpBuilder b (readOp->getLoc (), loop);
1118
- userPred = getUserPrecondition (b, loop, node.op ). first ;
1114
+ userPred = getUserPrecondition (b, loop, node.op );
1119
1115
b.setInsertionPointAfter (inBody (readOp));
1120
1116
auto [nextIndex, nextPhase] =
1121
1117
postIncrementModulo (b, index, phase, numMmaStages);
@@ -1124,28 +1120,57 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
1124
1120
incrementPt = b.saveInsertionPoint ();
1125
1121
}
1126
1122
}
1123
+ if (firstAfterInc) {
1124
+ b.setInsertionPoint (loop);
1125
+ if (firstAfterInc->op == mmaOp) {
1126
+ Value firstBar = createSingleBufferView (b, firstAfterInc->barPrev , 0 );
1127
+ b.create <ttng::ArriveBarrierOp>(firstBar, /* arriveCount=*/ 1 );
1128
+ } else {
1129
+ assert (firstAfterInc->op == dyn_cast<ttng::TMEMStoreOp>(overwriteOp));
1130
+ for (auto i : llvm::seq (numMmaStages)) {
1131
+ Value firstBar = createSingleBufferView (b, firstAfterInc->barPrev , i);
1132
+ b.create <ttng::ArriveBarrierOp>(firstBar, /* arriveCount=*/ 1 );
1133
+ }
1134
+ }
1135
+ }
1127
1136
oldAllocOp.getToken ().replaceAllUsesWith (allocOp.getToken ());
1128
1137
oldAllocOp.erase ();
1129
1138
cast<scf::YieldOp>(loop.getBody ()->getTerminator ())
1130
1139
.getResultsMutable ()
1131
1140
.append ({curIndex, curPhase});
1132
1141
1133
1142
// Find operands that need to be pipelined through shmem.
1143
+ SmallVector<Value> incomingOperands;
1144
+ llvm::append_range (incomingOperands, mmaOp->getOperands ());
1134
1145
SmallVector<std::pair<Operation *, Partition *>> operandDefs;
1135
- for (Value operand : mma.mmaOp ->getOperands ()) {
1146
+ while (!incomingOperands.empty ()) {
1147
+ Value operand = incomingOperands.pop_back_val ();
1148
+ if (!isa<MemDescType>(operand.getType ()))
1149
+ continue ;
1136
1150
Operation *defOp = operand.getDefiningOp ();
1137
- if (!defOp || ! loop.getBodyRegion (). isAncestor (defOp-> getParentRegion () ))
1151
+ if (!defOp || loop.isDefinedOutsideOfLoop (operand ))
1138
1152
continue ;
1139
1153
defOp = inBody (defOp);
1140
1154
Partition *defPartition = schedule.getPartition (defOp);
1141
- if (!defPartition)
1155
+
1156
+ if (!defPartition) {
1157
+ // If the MMA operand is coming from outside the loop, move the alloc out.
1158
+ auto allocOp = dyn_cast<LocalAllocOp>(defOp);
1159
+ if (allocOp && loop.isDefinedOutsideOfLoop (allocOp.getSrc ()))
1160
+ allocOp->moveBefore (loop);
1142
1161
continue ;
1162
+ }
1163
+
1143
1164
if (auto allocOp = operand.getDefiningOp <LocalAllocOp>()) {
1144
1165
PartitionBuilder b (allocOp.getLoc (), allocOp);
1145
1166
auto store = b.createInto <LocalStoreOp>(*defPartition, std::nullopt,
1146
1167
allocOp.getSrc (), allocOp);
1168
+ auto fence = b.createInto <ttng::FenceAsyncSharedOp>(
1169
+ *defPartition, std::nullopt, /* bCluster=*/ false );
1147
1170
operandDefs.emplace_back (body.findAncestorOpInBlock (*store),
1148
1171
defPartition);
1172
+ operandDefs.emplace_back (body.findAncestorOpInBlock (*fence),
1173
+ defPartition);
1149
1174
allocOp->moveBefore (loop);
1150
1175
allocOp->removeAttr (kPartitionAttrName );
1151
1176
allocOp.getSrcMutable ().clear ();
@@ -1161,6 +1186,8 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
1161
1186
tmemAllocOp->removeAttr (kPartitionAttrName );
1162
1187
tmemAllocOp.getSrcMutable ().clear ();
1163
1188
tmemAllocOp.getResult ().setType (getAsMutable (tmemAllocOp.getType ()));
1189
+ } else if (defOp->hasTrait <OpTrait::MemDescViewTrait>()) {
1190
+ incomingOperands.push_back (defOp->getOperand (0 ));
1164
1191
}
1165
1192
}
1166
1193
@@ -1187,13 +1214,20 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
1187
1214
1188
1215
if (node.barPrev ) {
1189
1216
if (!isa<ttng::TMEMLoadOp>(node.op )) {
1190
- if (incrementPt && domOp->isBeforeInBlock (&*incrementPt->getPoint ()))
1217
+ // If the user precondition is defined after the MMA, we need to peel
1218
+ // the wait for the user.
1219
+ if (incrementPt && domOp->isBeforeInBlock (&*incrementPt->getPoint ()) &&
1220
+ domInfo.properlyDominates (mmaOp, userPred.getDefiningOp ())) {
1191
1221
b.restoreInsertionPoint (*incrementPt);
1192
- else
1222
+ Value bar = createSingleBufferView (b, node.barPrev , curIndex);
1223
+ b.createInto <ttng::WaitBarrierOp>(*partition, stages.lookup (node.op ),
1224
+ bar, curPhase, userPred);
1225
+ } else {
1193
1226
b.setInsertionPoint (domOp);
1194
- Value bar = createSingleBufferView (b, node.barPrev , curIndex);
1195
- b.createInto <ttng::WaitBarrierOp>(*partition, stages.lookup (node.op ),
1196
- bar, curPhase, userPred);
1227
+ Value bar = createSingleBufferView (b, node.barPrev , node.index );
1228
+ b.createInto <ttng::WaitBarrierOp>(*partition, stages.lookup (node.op ),
1229
+ bar, node.phase , userPred);
1230
+ }
1197
1231
} else {
1198
1232
b.setInsertionPoint (domOp);
1199
1233
if (isa<scf::IfOp>(domOp->getParentOp ()))
0 commit comments