@@ -573,7 +573,7 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
573573 for (int i = 0 ; i < nodes.size (); ++i) {
574574 Node &cur = nodes[i];
575575 Node &next = nodes[(i + 1 ) % nodes.size ()];
576- if (!samePartition (inBody ( cur.op ), inBody ( next.op ) )) {
576+ if (!samePartition (cur.op , next.op )) {
577577 cur.barNext = createBarrierAlloc (loop, numMmaStages);
578578 next.barPrev = cur.barNext ;
579579 }
@@ -616,6 +616,7 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
616616 continue ;
617617 b.setInsertionPoint (node.op );
618618 Value view = createSingleBufferView (b, allocOp, node.index );
619+ b.assignPartition (view.getDefiningOp (), *partitions.getPartition (node.op ));
619620 if (auto storeOp = dyn_cast<ttng::TMEMStoreOp>(node.op )) {
620621 storeOp.getDstMutable ().assign (view);
621622 storeOp.getDepMutable ().clear ();
@@ -671,7 +672,6 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
671672 Operation *defOp = operand.getDefiningOp ();
672673 if (!defOp || loop.isDefinedOutsideOfLoop (operand))
673674 continue ;
674- defOp = inBody (defOp);
675675
676676 if (partitions.isInRootPartition (defOp)) {
677677 // If the MMA operand is coming from outside the loop, move the alloc out.
@@ -717,7 +717,7 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
717717 }
718718
719719 for (Node &node : nodes) {
720- Partition *partition = partitions.getPartition (inBody ( node.op ) );
720+ Partition *partition = partitions.getPartition (node.op );
721721 PartitionBuilder b (node.op ->getLoc (), loop);
722722
723723 SmallVector<Operation *> defs;
@@ -746,11 +746,13 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
746746 domInfo.properlyDominates (mmaOp, userPred.getDefiningOp ())) {
747747 b.restoreInsertionPoint (*incrementPt);
748748 Value bar = createSingleBufferView (b, node.barPrev , curIndex);
749+ b.assignPartition (bar.getDefiningOp (), *partition);
749750 b.createInto <ttng::WaitBarrierOp>(*partition, nodeStageCluster, bar,
750751 curPhase, userPred);
751752 } else {
752753 b.setInsertionPoint (domOp);
753754 Value bar = createSingleBufferView (b, node.barPrev , node.index );
755+ b.assignPartition (bar.getDefiningOp (), *partition);
754756 b.createInto <ttng::WaitBarrierOp>(*partition, nodeStageCluster, bar,
755757 node.phase , userPred);
756758 }
@@ -759,6 +761,7 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
759761 if (isa<scf::IfOp>(domOp->getParentOp ()) && accIsMultiBuffered)
760762 b.setInsertionPointToStart (domOp->getBlock ());
761763 Value bar = createSingleBufferView (b, node.barPrev , node.index );
764+ b.assignPartition (bar.getDefiningOp (), *partition);
762765 b.createInto <ttng::WaitBarrierOp>(*partition, nodeStageCluster, bar,
763766 node.phase );
764767 }
@@ -767,13 +770,15 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
767770 if (mmaOp == node.op ) {
768771 b.setInsertionPoint (mmaOp);
769772 Value bar = createSingleBufferView (b, node.barNext , node.index );
773+ b.assignPartition (bar.getDefiningOp (), *partitions.getPartition (mmaOp));
770774 mmaOp.addCompletionBarrier (bar, userPred);
771775 mmaOp.setIsAsync (true );
772776 } else {
773777 b.setInsertionPointAfter (lastOp);
774778 if (isa<scf::IfOp>(lastOp->getParentOp ()) && accIsMultiBuffered)
775779 b.setInsertionPoint (lastOp->getBlock ()->getTerminator ());
776780 Value bar = createSingleBufferView (b, node.barNext , node.index );
781+ b.assignPartition (bar.getDefiningOp (), *partition);
777782 b.createInto <ttng::ArriveBarrierOp>(*partition, nodeStageCluster, bar,
778783 1 );
779784 }
@@ -799,20 +804,26 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
799804 StageCluster srcStageCluster = getStageCluster (domOp);
800805 b.setInsertionPoint (domOp);
801806 Value emptyView = createSingleBufferView (b, emptyBar, index);
807+ b.assignPartition (emptyView.getDefiningOp (), *partition);
802808 b.createInto <ttng::WaitBarrierOp>(*partition, srcStageCluster, emptyView,
803809 phase);
804810
805811 b.setInsertionPointAfter (lastOp);
806812 Value readyView = createSingleBufferView (b, readyBar, index);
813+ b.assignPartition (readyView.getDefiningOp (), *partition);
807814 b.createInto <ttng::ArriveBarrierOp>(*partition, srcStageCluster, readyView,
808815 1 );
809816
810817 b.setInsertionPoint (mmaOp);
811818 Value readyView2 = createSingleBufferView (b, readyBar, index);
819+ b.assignPartition (readyView2.getDefiningOp (),
820+ *partitions.getPartition (mmaOp));
812821 b.createInto <ttng::WaitBarrierOp>(*partitions.getPartition (mmaOp),
813822 getStageCluster (mmaOp), readyView2,
814823 phase);
815824 Value emptyView2 = createSingleBufferView (b, emptyBar, index);
825+ b.assignPartition (emptyView2.getDefiningOp (),
826+ *partitions.getPartition (mmaOp));
816827 mmaOp.addCompletionBarrier (emptyView2, b.boolCst (true ));
817828 mmaOp.setIsAsync (true );
818829 }
0 commit comments