@@ -35,12 +35,17 @@ void annotateWithPipelineStage(IRRewriter &builder, Operation *op, int stage) {
3535 IntegerAttr::get (builder.getI32Type (), stage));
3636}
3737
38+ int getPipelineStage (Operation *op) {
39+ return op->getAttrOfType <IntegerAttr>(kPipelineStageAttrName ).getInt ();
40+ }
41+
3842struct MMAInfo {
3943 struct AccOverridePoint {
4044 Operation *op;
4145 Value condition = nullptr ;
4246 Value initValue = nullptr ;
4347 int distance = 0 ;
48+ bool isFlag = false ;
4449 };
4550
4651 ttng::TMEMAllocOp accAlloc; // Directly precedes the dot, allocating tmem
@@ -136,6 +141,7 @@ std::optional<MMAInfo::AccOverridePoint>
136141getAccOverridePointInLoop (scf::ForOp forOp, ttng::TMEMAllocOp accUse,
137142 ttng::TMEMLoadOp accDef) {
138143 MMAInfo::AccOverridePoint accOverridePoint;
144+ accOverridePoint.isFlag = false ;
139145 DenseSet<Value> seen;
140146 Value v = accUse.getSrc ();
141147 if (v == nullptr ) {
@@ -219,6 +225,7 @@ getAccUseFlagFalseInLoop(scf::ForOp forOp, Value useAccFlagUse) {
219225
220226 IRRewriter builder (v.getDefiningOp ()->getNextNode ());
221227 MMAInfo::AccOverridePoint accOverridePoint;
228+ accOverridePoint.isFlag = true ;
222229 accOverridePoint.distance = dist;
223230 Location loc = v.getDefiningOp ()->getLoc ();
224231 auto vTrue =
@@ -374,9 +381,12 @@ void updateAccUsesInLoop(IRRewriter &builder, scf::ForOp forOp, MMAInfo &info,
374381 }
375382 auto load = builder.create <ttng::TMEMLoadOp>(
376383 domOp->getLoc (), info.accLoad .getType (), extractSlice);
384+ // If accumulator is multi-buffered, it is implicit that we put the load
385+ // in the last stage.
386+ int pipelineStage = info.accIsMultiBuffered ? numStages - 1 : 0 ;
377387 annotateWithPipelineStage (
378388 builder, forOp.getBody ()->findAncestorOpInBlock (*load.getOperation ()),
379- numStages - 1 );
389+ pipelineStage );
380390 for (auto user : directUses) {
381391 user->replaceUsesOfWith (info.accLoad , load);
382392 }
@@ -574,12 +584,45 @@ void createBarrierAndWaitOps(IRRewriter &builder, scf::ForOp forOp,
574584 info.barrierIdx = newBarrierIdx;
575585 annotateWithPipelineStage (builder, info.barrierIdx .getDefiningOp (), 0 );
576586
587+ Value originalPhase = info.phase ;
577588 Value newPhase = builder.create <arith::SelectOp>(
578589 loc, info.phase .getType (), barWrap,
579590 builder.create <arith::XOrIOp>(loc, info.phase , one), info.phase );
580591 replaceAllUsesDominatedBy (newPhase.getDefiningOp (), newPhase, info.phase );
581592 info.phase = newPhase;
582593 annotateWithPipelineStage (builder, info.phase .getDefiningOp (), 0 );
594+
595+ // We need to add a barrier before load from the accumulator, if it is in the
596+ // same stage as the dot.
597+ ttng::TMEMLoadOp tmemLoad = nullptr ;
598+ SmallVector<Operation *> users = {info.accAlloc ->getUsers ().begin (),
599+ info.accAlloc ->getUsers ().end ()};
600+ while (!users.empty ()) {
601+ auto user = users.pop_back_val ();
602+ if (isa<ttg::MemDescSubviewOp>(user)) {
603+ users.append (user->getUsers ().begin (), user->getUsers ().end ());
604+ }
605+ if (isa<ttng::TMEMLoadOp>(user) && forOp->isAncestor (user)) {
606+ if (tmemLoad) {
607+ assert (tmemLoad == cast<ttng::TMEMLoadOp>(user) &&
608+ " Should have only one tmem load from the accumulator" );
609+ }
610+ tmemLoad = cast<ttng::TMEMLoadOp>(user);
611+ }
612+ }
613+ if (tmemLoad) {
614+ int loadStage =
615+ getPipelineStage (forOp.getBody ()->findAncestorOpInBlock (*tmemLoad));
616+ int mmaOpStage = getPipelineStage (mmaOp);
617+ if (loadStage == mmaOpStage) {
618+ builder.setInsertionPoint (tmemLoad);
619+ auto barrier =
620+ builder.create <ttng::WaitBarrierOp>(loc, barrierSlice, originalPhase);
621+ annotateWithPipelineStage (
622+ builder, forOp.getBody ()->findAncestorOpInBlock (*barrier),
623+ mmaOpStage);
624+ }
625+ }
583626}
584627
585628bool isSafeToPipeline (ttng::TCGen5MMAScaledOp scaledDot) {
@@ -684,17 +727,33 @@ FailureOr<scf::ForOp> preProcessLoopForTC05MMAPipelining(scf::ForOp forOp,
684727 continue ;
685728 }
686729
730+ SmallVector<Operation *> accUses = getDirectAccUses (accLoad);
731+ DominanceInfo domOpInfo (forOp);
732+ Operation *newAccLoadInsertPoint =
733+ findNearestCommonDominator (accUses, domOpInfo);
687734 // Check pipelining and multi-buffering constraints
688- // 1. If the acc is used by an op in the loop (other than the dot) it
689- // requires multi-buffering to pipeline, as different stages cannot operate
690- // on the same buffer.
691- bool requiresMultiBuffer = !getDirectAccUses (accLoad).empty ();
735+ // 1. Really needs multibuffering - if the acc is used unconditionally in
736+ // the loop, or under different conditions. If we cannot multibuffer in this
737+ // case, we may as well not pipeline at all, as we will have to wait after
738+ // the dot in every loop iteration.
739+ scf::IfOp topLevelIf =
740+ newAccLoadInsertPoint
741+ ? dyn_cast<scf::IfOp>(forOp.getBody ()->findAncestorOpInBlock (
742+ *newAccLoadInsertPoint))
743+ : nullptr ;
744+ bool requiresMultiBuffer = accUses.size () > 0 && !topLevelIf;
745+ // If we override the acc in the loop, it is generally hard to handle it
746+ // without multibuffering. We make an exception if it not a physical
747+ // override of a value, but just setting a flag that acc is not used. In
748+ // this case we don't need different buffer to store init value.
749+ requiresMultiBuffer |=
750+ accOverridePoint.has_value () && !accOverridePoint->isFlag ;
692751
693752 // 2. If the acc is not owerwritten in the loop (by op other than the dot),
694753 // it cannot be multi-buffered. This is because the overwrite is the only
695754 // way to initialize next buffer without incurring a copy.
696- bool canMultiBuffer = accOverridePoint.has_value ();
697-
755+ bool canMultiBuffer = accOverridePoint.has_value () &&
756+ ! mlir::triton::getDisallowAccMultiBuffer (forOp);
698757 if (requiresMultiBuffer && !canMultiBuffer) {
699758 continue ;
700759 }
@@ -703,7 +762,7 @@ FailureOr<scf::ForOp> preProcessLoopForTC05MMAPipelining(scf::ForOp forOp,
703762 .accLoad = accLoad,
704763 .accDef = accOverridePoint,
705764 .yieldArgNo = yieldArgNo,
706- .accIsMultiBuffered = requiresMultiBuffer };
765+ .accIsMultiBuffered = canMultiBuffer };
707766
708767 builder.setInsertionPoint (forOp);
709768 Value zero = builder.create <arith::ConstantIntOp>(forOp.getLoc (), 0 , 32 );
0 commit comments