@@ -118,7 +118,9 @@ Operation *getFirstUseOfPipelinedOp(SmallVector<Operation *> ops,
118118 auto [_firstUserStage, _firstUserCluster] = schedule[firstUser];
119119 if (_useStage < _firstUserStage ||
120120 (_useStage == _firstUserStage &&
121- schedule.clusters .isBefore (_useCluster, _firstUserCluster))) {
121+ schedule.clusters .isBefore (_useCluster, _firstUserCluster)) ||
122+ (_useStage == _firstUserStage && _useCluster == _firstUserCluster &&
123+ topLevelUser->isBeforeInBlock (firstUser))) {
122124 firstUser = topLevelUser;
123125 }
124126 }
@@ -214,6 +216,8 @@ Value createIncrementModulo(BuilderT &builder, Location loc, Value counter,
214216
215217void replaceAllUsesDominatedBy (Operation *domOp, Value newValue, Value oldValue,
216218 DominanceInfo &domInfo) {
219+ if (newValue == oldValue)
220+ return ;
217221 oldValue.replaceUsesWithIf (newValue, [&](OpOperand &use) {
218222 return domInfo.properlyDominates (domOp, use.getOwner ());
219223 });
@@ -232,15 +236,6 @@ static Value createAlloc(scf::ForOp &forOp, Operation *loadOp,
232236 loadOp->getLoc (), sharedEnc, distance);
233237}
234238
235- template <typename BuilderT, typename ... Args>
236- Operation *createWithStage (BuilderT &builder, Location loc, int stage,
237- CoarseSchedule::Cluster cluster, Args &&...args) {
238- Operation *op = builder.template create <ttg::AsyncCopyGlobalToLocalOp>(
239- loc, std::forward<Args>(args)...);
240-
241- return op;
242- }
243-
244239void createAsyncCopy (scf::ForOp forOp, tt::LoadOp loadOp, Value alloc,
245240 Value insertIdx, Value extractIdx,
246241 CoarseSchedule &schedule) {
@@ -874,12 +869,36 @@ std::pair<int, int> getTmemUseStageBounds(ttng::TMEMAllocOp alloc,
874869 return bounds;
875870}
876871
877- void createBarrierAndWaitOps (scf::ForOp forOp, CoarseSchedule &schedule,
878- ttng::MMAv5OpInterface mma,
879- ttng::TMEMAllocOp alloc, Value &phase,
880- Value &barrierIdx, int numStages) {
 872+ // Create a predicate argument for the dist-1 wait
873+ scf::ForOp prepLoopForDist1Wait (scf::ForOp forOp, CoarseSchedule &schedule,
874+ ttng::MMAv5OpInterface mma) {
875+ OpBuilderForStage builder (forOp, schedule);
876+ Location loc = mma.getLoc ();
877+ Value vFalse = builder.create <arith::ConstantIntOp>(loc, 0 , 1 );
878+
879+ // Create a predicate for the wait (start with false and change to true on the
880+ // first mma execution)
881+ scf::ForOp newForOp = replaceForOpWithNewSignature (builder, forOp, {vFalse});
882+ forOp.erase ();
883+ forOp = newForOp;
884+
885+ builder.setInsertionPointAfter (mma);
886+ builder.setStageCluster (schedule[mma]);
887+ Value vTrue = builder.create <arith::ConstantIntOp>(loc, 1 , 1 );
888+
889+ auto yieldOp = cast<scf::YieldOp>(forOp.getBody ()->getTerminator ());
890+ yieldOp.getResultsMutable ().append (vTrue); // predicate
891+ return forOp;
892+ }
893+
894+ scf::ForOp createBarrierAndWaitOps (scf::ForOp forOp, CoarseSchedule &schedule,
895+ ttng::MMAv5OpInterface mma,
896+ ttng::TMEMAllocOp alloc, int phaseArgIdx,
897+ int barrierIdxArgIdx, int numStages) {
881898 OpBuilderForStage builder (forOp, schedule);
882899 DominanceInfo domInfo (forOp);
900+ Value phase = forOp.getRegionIterArg (phaseArgIdx);
901+ Value barrierIdx = forOp.getRegionIterArg (barrierIdxArgIdx);
883902 Value zero = builder.create <arith::ConstantIntOp>(forOp.getLoc (), 0 , 32 );
884903 Value one = builder.create <arith::ConstantIntOp>(forOp.getLoc (), 1 , 32 );
885904 Value numStagesVal =
@@ -889,8 +908,11 @@ void createBarrierAndWaitOps(scf::ForOp forOp, CoarseSchedule &schedule,
889908 builder.setStageCluster (schedule[mma]);
890909 builder.setInsertionPoint (mma);
891910 Location loc = mma->getLoc ();
892- Value barrierSlice =
893- triton::createSingleBufferView (builder, barrierAlloc, barrierIdx);
911+ Value barrierSlice = barrierAlloc;
912+ if (numStages > 1 ) {
913+ barrierSlice =
914+ triton::createSingleBufferView (builder, barrierAlloc, barrierIdx);
915+ }
894916 mma.setBarrier (barrierSlice);
895917
896918 // List of buffers that may be used until wait completes
@@ -953,27 +975,85 @@ void createBarrierAndWaitOps(scf::ForOp forOp, CoarseSchedule &schedule,
953975 }
954976 }
955977
978+ std::optional<int > predicateArgIdx;
956979 for (auto wp : waitPoints) {
957980 builder.setStageCluster ({wp.stage , wp.cluster });
958981 builder.setInsertionPoint (wp.op );
959- builder.create <ttng::WaitBarrierOp>(loc, barrierSlice, phase, waitBuffers);
982+ SmallVector<Value> currWaitBuffers = waitBuffers;
983+ Value pred = nullptr ;
984+ if (!domInfo.properlyDominates (mma, wp.op )) {
985+ // Waits before the mma are special: they are dist-1 to the mma which
986+ // means two things:
987+ // 1. The wait should not be executed before the mma executes for the
988+ // first time
989+ // 2. The waitBuffers do not dominate this wait. We either need to hoist
990+ // the local allocs out of the loop, or have dist-1 buffer references.
991+ if (!predicateArgIdx) {
992+ forOp = prepLoopForDist1Wait (forOp, schedule, mma);
993+ predicateArgIdx = forOp.getInitArgs ().size () - 1 ;
994+ // HACK: Clear the wait buffers. This is meant to avoid having to have
995+ // dist-1 buffer references, or hoisting the local allocs out of the
996+ // loop. This works as long as the wait is in the same stage as the mma,
997+ // and is the reason for which we prevent pipelining the mma if there
998+ // are loads from the accumulator before the mma in a different stage.
999+ currWaitBuffers.clear ();
1000+ }
1001+ pred = forOp.getRegionIterArg (*predicateArgIdx);
1002+ }
1003+ Value barrierSlice = barrierAlloc;
1004+ if (numStages > 1 ) {
1005+ barrierSlice =
1006+ triton::createSingleBufferView (builder, barrierAlloc, barrierIdx);
1007+ }
1008+ builder.create <ttng::WaitBarrierOp>(loc, barrierSlice, phase, pred,
1009+ currWaitBuffers);
9601010 }
9611011
9621012 builder.setStageCluster (schedule[mma]);
9631013 builder.setInsertionPoint (forOp.getBody ()->getTerminator ());
964- Value barWrap;
965- barrierIdx = createIncrementModulo (builder, loc, barrierIdx, numStagesVal,
966- zero, one, &barWrap);
967- phase = builder.create <arith::SelectOp>(
968- loc, phase.getType (), barWrap,
969- builder.create <arith::XOrIOp>(loc, phase, one), phase);
1014+ Value newPhase = builder.create <arith::XOrIOp>(loc, phase, one);
1015+ Value newBarrierIdx = barrierIdx;
1016+ if (numStages > 1 ) {
1017+ Value barWrap;
1018+ newBarrierIdx = createIncrementModulo (builder, loc, barrierIdx,
1019+ numStagesVal, zero, one, &barWrap);
1020+ newPhase = builder.create <arith::SelectOp>(loc, phase.getType (), barWrap,
1021+ newPhase, phase);
1022+ }
1023+ if (predicateArgIdx) {
 1024+ // If there is a wait predicate, we need to select the phase and
 1025+ // barrierIdx based on the predicate.
1026+ Value pred = forOp.getRegionIterArg (*predicateArgIdx);
1027+ newPhase = builder.create <arith::SelectOp>(loc, phase.getType (), pred,
1028+ newPhase, phase);
1029+ newBarrierIdx = builder.create <arith::SelectOp>(loc, phase.getType (), pred,
1030+ newBarrierIdx, barrierIdx);
1031+ }
1032+ replaceAllUsesDominatedBy (newPhase.getDefiningOp (), newPhase, phase, domInfo);
1033+ replaceAllUsesDominatedBy (newBarrierIdx.getDefiningOp (), newBarrierIdx,
1034+ barrierIdx, domInfo);
1035+
1036+ // If there is a dist-1 wait, we need to add a wait after the loop
1037+ if (predicateArgIdx) {
1038+ builder.setInsertionPointAfter (forOp);
1039+ Value barrierSlice = barrierAlloc;
1040+ Value barrierIdx = forOp.getResult (barrierIdxArgIdx);
1041+ Value phase = forOp.getResult (phaseArgIdx);
1042+ if (numStages > 1 ) {
1043+ barrierSlice =
1044+ triton::createSingleBufferView (builder, barrierAlloc, barrierIdx);
1045+ }
1046+ builder.create <ttng::WaitBarrierOp>(loc, barrierSlice, phase,
1047+ forOp.getResult (*predicateArgIdx));
1048+ }
1049+ return forOp;
9701050}
9711051
9721052void multibufferTensorMemory (scf::ForOp forOp, CoarseSchedule &schedule,
973- ttng::TMEMAllocOp alloc, Value &bufIdx ,
974- int bufIdxArgIdx, int tmemUseNumStages) {
1053+ ttng::TMEMAllocOp alloc, int bufIdxArgIdx ,
1054+ int tmemUseNumStages) {
9751055 DominanceInfo domInfo (forOp);
976-
1056+ Value bufIdx = forOp. getRegionIterArg (bufIdxArgIdx);
9771057 SmallVector<std::pair<Operation *, Value>> bufIdxDefs;
9781058 auto getCurrBufIdx = [&](Operation *op) {
9791059 for (auto [_op, _val] : llvm::reverse (bufIdxDefs)) {
@@ -1077,17 +1157,16 @@ void multibufferTensorMemory(scf::ForOp forOp, CoarseSchedule &schedule,
10771157 " accumulator, and the mma uses the accumulator all the time." );
10781158 }
10791159 alloc->erase ();
1080- bufIdx = bufIdxDefs.back ().second ;
1160+ Value newBufIdx = bufIdxDefs.back ().second ;
1161+ replaceAllUsesDominatedBy (newBufIdx.getDefiningOp (), newBufIdx, bufIdx,
1162+ domInfo);
10811163}
10821164
10831165scf::ForOp lowerMMA (ttng::MMAv5OpInterface mma, scf::ForOp forOp,
10841166 CoarseSchedule &schedule) {
10851167 auto isLoadPipelineable = [&](Operation *op) {
10861168 return schedule[mma].first > schedule[op].first ;
10871169 };
1088- if (!mmaHasPipelineableOperands (mma, forOp, isLoadPipelineable)) {
1089- return forOp;
1090- }
10911170 auto alloc = mma.getAccumulator ().getDefiningOp <ttng::TMEMAllocOp>();
10921171 if (!alloc) {
10931172 return forOp;
@@ -1099,7 +1178,8 @@ scf::ForOp lowerMMA(ttng::MMAv5OpInterface mma, scf::ForOp forOp,
10991178 int tmemUseNumStages =
11001179 tmemUseStageBounds.second - tmemUseStageBounds.first + 1 ;
11011180 int waitNumStages = tmemUseStageBounds.second - schedule[mma].first + 1 ;
1102- if (waitNumStages == 1 && !hasAccReadModifyWrite (mma, forOp)) {
1181+ if (waitNumStages == 1 && !hasAccReadModifyWrite (mma, forOp) &&
1182+ mmaHasPipelineableOperands (mma, forOp, isLoadPipelineable)) {
11031183 // Overlap the mma with itself, even if there is no use of the accumulator
11041184 // after the mma
11051185 waitNumStages = 2 ;
@@ -1112,35 +1192,40 @@ scf::ForOp lowerMMA(ttng::MMAv5OpInterface mma, scf::ForOp forOp,
11121192 // Add arguments to the forOp
11131193 unsigned newOperandIndex = forOp.getInitArgs ().size ();
11141194 SmallVector<Value> newOperands = {
1115- zero, // phase
1116- zero, // barrierIdx
1117- minusOne, // bufIdx
1195+ zero, // phase
1196+ zero, // barrierIdx
11181197 };
1198+ if (tmemUseNumStages > 1 ) {
1199+ newOperands.push_back (minusOne); // bufIdx
1200+ }
11191201 scf::ForOp newForOp =
11201202 replaceForOpWithNewSignature (builder, forOp, newOperands);
11211203 forOp.erase ();
11221204 forOp = newForOp;
11231205
1124- Value phase = forOp.getRegionIterArg (newOperandIndex + 0 );
1125- Value barrierIdx = forOp.getRegionIterArg (newOperandIndex + 1 );
1126- Value bufIdx = forOp.getRegionIterArg (newOperandIndex + 2 );
1206+ int phaseArgIdx = newOperandIndex + 0 ;
1207+ int barrierIdxArgIdx = newOperandIndex + 1 ;
1208+ int bufIdxArgIdx = newOperandIndex + 2 ;
1209+ Value phase = forOp.getRegionIterArg (phaseArgIdx);
1210+ Value barrierIdx = forOp.getRegionIterArg (barrierIdxArgIdx);
11271211
1128- if (waitNumStages > 1 ) {
1129- createBarrierAndWaitOps (forOp, schedule, mma, alloc, phase, barrierIdx,
1130- waitNumStages);
1212+ SmallVector<Value> newYieldOperands = {phase, barrierIdx};
1213+ if (tmemUseNumStages > 1 ) {
1214+ Value bufIdx = forOp.getRegionIterArg (bufIdxArgIdx);
1215+ newYieldOperands.push_back (bufIdx);
11311216 }
1217+ cast<scf::YieldOp>(forOp.getBody ()->getTerminator ())
1218+ .getResultsMutable ()
1219+ .append (newYieldOperands);
1220+
1221+ forOp = createBarrierAndWaitOps (forOp, schedule, mma, alloc, phaseArgIdx,
1222+ barrierIdxArgIdx, waitNumStages);
11321223
11331224 if (tmemUseNumStages > 1 ) {
1134- multibufferTensorMemory (forOp, schedule, alloc, bufIdx, newOperandIndex + 2 ,
1225+ multibufferTensorMemory (forOp, schedule, alloc, bufIdxArgIdx ,
11351226 tmemUseNumStages);
11361227 }
11371228
1138- SmallVector<Value> newYieldOperands;
1139- newYieldOperands.push_back (phase);
1140- newYieldOperands.push_back (barrierIdx);
1141- newYieldOperands.push_back (bufIdx);
1142- appendToForOpYield (forOp, newYieldOperands);
1143-
11441229 return forOp;
11451230}
11461231
0 commit comments