@@ -121,7 +121,7 @@ static Operation *getFirstUseOfPipelinedLoad(Operation *loadOp) {
 static int createAsyncCopy(scf::ForOp forOp, tt::LoadOp loadOp, Value alloc,
                            Value insertIdx, Value extractIdx,
                            llvm::MapVector<Operation *, LoadInfo> &loadToInfo,
-                           int numStages, int maxClusterId) {
+                           int maxClusterId) {
   int retCode = -1;
   OpBuilderWithStage builder(forOp);
   auto opPair = tt::getStageCluster(loadOp);
@@ -234,8 +234,7 @@ static void
 createTMAAsyncCopy(scf::ForOp &forOp, tt::ExperimentalDescriptorLoadOp loadOp,
                    Value alloc, Value insertIdx, Value extractIdx,
                    Value barrier, Operation *waitOp, Value phase,
-                   llvm::MapVector<Operation *, LoadInfo> &loadToInfo,
-                   int numStages) {
+                   llvm::MapVector<Operation *, LoadInfo> &loadToInfo) {
   assert(phase && "Phase value is required for TMA async copy.");
   OpBuilderWithStage builder(forOp);
   auto [stage, clusterId] = tt::getStageCluster(loadOp);
@@ -585,21 +584,28 @@ static Value createBarrierAlloc(scf::ForOp &forOp, unsigned distance) {
   return barrierAlloc;
 }
 
+struct StageGroup {
+  Value insertIdx;
+  Value extractIdx;
+  Value phase;
+  bool hasTMALoad = false;
+};
 struct AsyncLoad {
-  AsyncLoad(Operation *loadOp, Value alloc) : loadOp(loadOp), alloc(alloc) {}
   Operation *loadOp;
   Value alloc;
   Value barrier;
   Operation *waitOp = nullptr;
   int firstUseStage, firstUseCluster;
   bool isTMALoad = false;
+  int numBuffers = 0;
 };
 
 // Create barriers and wait ops for the async loads. Barriers may be shared by
-// multiple loads is the schedule allows it.
+// multiple loads if the schedule allows it.
 static void createTMABarrierAndWait(
-    scf::ForOp &forOp, SmallVector<AsyncLoad> &asyncLoads, Value insertIdx,
-    Value extractIdx, Value phase, int numBuffers, SmallVector<Value> &barriers,
+    scf::ForOp &forOp, SmallVector<AsyncLoad> &asyncLoads,
+    SmallVector<Value> &barriers,
+    const llvm::MapVector<int, StageGroup> &stageGroups,
     const llvm::MapVector<Operation *, LoadInfo> &loadToInfo) {
   llvm::SmallDenseMap<Operation *, AsyncLoad *> loadToAsyncLoad;
   for (AsyncLoad &asyncLoad : asyncLoads) {
@@ -639,12 +645,15 @@ static void createTMABarrierAndWait(
     };
     addToGroup(&asyncLoad);
     Operation *nextOp = asyncLoad.loadOp->getNextNode();
+    int numBuffers = asyncLoad.numBuffers;
     while (nextOp) {
       if (users.count(nextOp) || visited.count(nextOp))
        break;
       if (isa<tt::ExperimentalDescriptorLoadOp>(nextOp)) {
         auto it = loadToAsyncLoad.find(nextOp);
         if (it != loadToAsyncLoad.end() && it->second->isTMALoad) {
+          if (it->second->numBuffers != numBuffers)
+            break;
           if (group.size() > 0 &&
               sameStageCluster(group[0]->loadOp, it->second->loadOp))
             addToGroup(it->second);
@@ -659,6 +668,8 @@ static void createTMABarrierAndWait(
   // load.
   for (SmallVector<AsyncLoad *> &group : loadGroups) {
     int sizeInBytes = 0;
+    int numBuffers = group[0]->numBuffers;
+    const StageGroup &stageGroup = stageGroups.find(numBuffers)->second;
     for (AsyncLoad *asyncLoad : group) {
       auto tensorTy =
           cast<RankedTensorType>(asyncLoad->loadOp->getResult(0).getType());
@@ -682,7 +693,7 @@ static void createTMABarrierAndWait(
     builder.setInsertionPoint(group[0]->loadOp);
     Value barrier = builder.createWithStage<ttg::MemDescSubviewOp>(
         loc, stage, cluster, barrierTy, barrierAlloc,
-        ArrayRef<Value>({insertIdx}));
+        ArrayRef<Value>({stageGroup.insertIdx}));
     Value pred = builder.createWithStage<arith::ConstantIntOp>(loc, stage,
                                                                cluster, 1, 1);
     Operation *expect = builder.createWithStage<ttng::BarrierExpectOp>(
@@ -691,10 +702,10 @@ static void createTMABarrierAndWait(
     builder.setInsertionPointAfter(group.back()->loadOp);
     Value barrierViewWait = builder.createWithStage<ttg::MemDescSubviewOp>(
         loc, group[0]->firstUseStage, group[0]->firstUseCluster, barrierTy,
-        barrierAlloc, ArrayRef<Value>({extractIdx}));
+        barrierAlloc, ArrayRef<Value>({stageGroup.extractIdx}));
     Operation *wait = builder.createWithStage<ttng::WaitBarrierOp>(
         loc, group[0]->firstUseStage, group[0]->firstUseCluster,
-        barrierViewWait, phase);
+        barrierViewWait, stageGroup.phase);
     // Update the async loads info.
     for (AsyncLoad *asyncLoad : group) {
       asyncLoad->barrier = barrier;
@@ -855,46 +866,47 @@ static SmallVector<Value>
 createAsyncOps(scf::ForOp &forOp,
                llvm::MapVector<Operation *, LoadInfo> &loadToInfo,
                SmallVector<Value> &barriers, int numStages) {
-  // Calculate the number of buffers needed for each load.
-  // TODO pawel: we could do more fine-grained allocation here and
-  // allocate only the number of buffers that specific loads need.
-  // Instead, we allocate the maximum number of buffers needed by any load.
-  int numBuffers =
-      llvm::max_element(llvm::make_second_range(loadToInfo), [](auto &lhs,
-                                                                auto &rhs) {
-        return lhs.distToUse < rhs.distToUse;
-      })->distToUse;
-  bool hasMMAV3 = llvm::any_of(loadToInfo, [](auto &kv) {
-    return kv.second.isMMAv3Shared || kv.second.isMMAv3Registers;
-  });
-  if (hasMMAV3) {
-    // For MMAv3, we need an extra buffer as this is assumed in the wgmma
-    // pipelining post-processing.
-    numBuffers++;
-  };
-
   llvm::MapVector<Operation *, Value> tmaBufferMapping;
   if (failed(allocTMABuffers(forOp, tmaBufferMapping, numStages))) {
     llvm_unreachable("TMA pipelining failed");
   }
 
+  // Each group of loads/allocs with the same number of buffers (and stages)
+  // will share the indices and barriers.
+
   SmallVector<AsyncLoad> asyncLoads;
   SmallVector<Value> allocs;
-  bool hasTMALoad = false;
+  llvm::MapVector<int, StageGroup> stageGroups;
+
   for (auto &[loadOp, info] : loadToInfo) {
+    AsyncLoad asyncLoad = {.loadOp = loadOp};
+    bool isTMALoad = false;
+    int numBuffers = info.distToUse;
+    // For MMAv3, we need an extra buffer as this is assumed in the wgmma
+    // pipelining post-processing.
+    if (info.isMMAv3Shared || info.isMMAv3Registers) {
+      ++numBuffers;
+    }
+    if (isa<tt::ExperimentalDescriptorLoadOp>(loadOp)) {
+      isTMALoad = true;
+      asyncLoad.isTMALoad = isTMALoad;
+    }
     assert(info.sharedEncoding && "LoadOp shared encoding not defined.");
     Value alloc = createAlloc(forOp, loadOp, info.sharedEncoding, numBuffers);
     assert(alloc && "Failed to create alloc for the async load.");
     allocs.push_back(alloc);
-    asyncLoads.emplace_back(loadOp, alloc);
-    if (isa<tt::ExperimentalDescriptorLoadOp>(loadOp)) {
-      hasTMALoad = true;
-      asyncLoads.back().isTMALoad = true;
-    }
+    asyncLoad.alloc = alloc;
+
     auto *firstUse = getFirstUseOfPipelinedLoad(loadOp);
     auto [firstUseStage, firstUseCluster] = tt::getStageCluster(firstUse);
-    asyncLoads.back().firstUseStage = firstUseStage;
-    asyncLoads.back().firstUseCluster = firstUseCluster;
+    asyncLoad.firstUseStage = firstUseStage;
+    asyncLoad.firstUseCluster = firstUseCluster;
+    asyncLoad.numBuffers = numBuffers;
+    stageGroups.insert({numBuffers, {}});
+    if (isTMALoad) {
+      stageGroups[numBuffers].hasTMALoad = true;
+    }
+    asyncLoads.push_back(asyncLoad);
   }
 
   IRRewriter builder(forOp.getContext());
@@ -908,41 +920,34 @@ createAsyncOps(scf::ForOp &forOp,
   Value minusOne = builder.create<arith::ConstantIntOp>(loc, -1, 32);
   Value zero = builder.create<arith::ConstantIntOp>(loc, 0, 32);
   Value one = builder.create<arith::ConstantIntOp>(loc, 1, 32);
-  Value insertIdx = minusOne;
-  Value extractIdx = minusOne;
-  Value phase = Value();
-  Value numBuffersVal =
-      builder.create<arith::ConstantIntOp>(loc, numBuffers, 32);
   SmallVector<Value> newOperands;
-  newOperands.push_back(insertIdx);
-  newOperands.push_back(extractIdx);
-  if (hasTMALoad) {
-    // A single barrier arrival sequence is a "phase" and two phases can
-    // overlap, provided the phases are differentiated with an alternating
-    // boolean value.
-    phase = builder.create<arith::ConstantIntOp>(loc, 0, 32);
-    newOperands.push_back(phase);
+  unsigned newOperandIndex = forOp.getBody()->getNumArguments();
+  for (auto [_, stageGroup] : stageGroups) {
+    newOperands.push_back(minusOne); // insertIdx
+    newOperands.push_back(minusOne); // extractIdx
+    if (stageGroup.hasTMALoad) {
+      // A single barrier arrival sequence is a "phase" and two phases can
+      // overlap, provided the phases are differentiated with an alternating
+      // boolean value.
+      newOperands.push_back(zero); // phase
+    }
   }
   // Also create one counter per TMA buffer. This allows the descriptors to be
   // updated independently without needing to write duplicate of existing tma
   // descriptors.
+  unsigned tmaCounterArgsStartIdx = newOperandIndex + newOperands.size();
   for (int i = 0; i < tmaBufferMapping.size(); ++i) {
     newOperands.push_back(zero);
   }
 
-  unsigned newOperandIndex = forOp.getBody()->getNumArguments();
   // Patch the loop to add the new loop carried dependencies.
   scf::ForOp newForOp =
       replaceForOpWithNewSignature(builder, forOp, newOperands);
   forOp.erase();
   forOp = newForOp;
-  insertIdx = newForOp.getBody()->getArgument(newOperandIndex);
-  extractIdx = newForOp.getBody()->getArgument(newOperandIndex + 1);
-  if (phase) {
-    phase = newForOp.getBody()->getArgument(newOperandIndex + 2);
-  }
+
   auto tmaCounters = ArrayRef<BlockArgument>(newForOp.getBody()->getArguments())
-                         .slice(newOperandIndex + (phase ? 3 : 2));
+                         .slice(tmaCounterArgsStartIdx);
 
   // Update yield op with temporary yield values
   auto forYield = cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
@@ -956,44 +961,70 @@ createAsyncOps(scf::ForOp &forOp,
   }
   tmaBufferMapping.clear();
 
-  // FIXME: loads can be in different (stage, cluster)
-  // Create two counters for the insert and extract indices to avoid creating
-  // long liverange.
-  builder.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
-  insertIdx = builder.create<arith::AddIOp>(loc, insertIdx, one);
-  Value cndIns = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
-                                               insertIdx, numBuffersVal);
-  insertIdx = builder.create<arith::SelectOp>(loc, cndIns, insertIdx, zero);
-
-  extractIdx = builder.create<arith::AddIOp>(loc, extractIdx, one);
-  Value cndExt = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
-                                               extractIdx, numBuffersVal);
-  extractIdx = builder.create<arith::SelectOp>(loc, cndExt, extractIdx, zero);
-  if (phase) {
-    Value nextPhase = builder.create<arith::XOrIOp>(loc, phase, one);
-    phase = builder.create<arith::SelectOp>(loc, cndExt, phase, nextPhase);
+  builder.setInsertionPoint(forOp);
+  loc = forOp.getLoc();
+  int argIdx = newOperandIndex;
+  for (auto &[numBuffers, stageGroup] : stageGroups) {
+    Value insertIdx = newForOp.getBody()->getArgument(argIdx);
+    argIdx++;
+    Value extractIdx = newForOp.getBody()->getArgument(argIdx);
+    argIdx++;
+    Value phase = nullptr;
+    if (stageGroup.hasTMALoad) {
+      phase = newForOp.getBody()->getArgument(argIdx);
+      argIdx++;
+    }
+
+    // Create two counters for the insert and extract indices to avoid creating
+    // long liverange.
+    builder.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
+
+    Value numBuffersVal =
+        builder.create<arith::ConstantIntOp>(loc, numBuffers, 32);
+    insertIdx = builder.create<arith::AddIOp>(loc, insertIdx, one);
+    Value cndIns = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
+                                                 insertIdx, numBuffersVal);
+    insertIdx = builder.create<arith::SelectOp>(loc, cndIns, insertIdx, zero);
+    stageGroup.insertIdx = insertIdx;
+
+    extractIdx = builder.create<arith::AddIOp>(loc, extractIdx, one);
+    // Duplicate the constant to keep it from being carried across loops.
+    numBuffersVal = builder.create<arith::ConstantIntOp>(loc, numBuffers, 32);
+    Value cndExt = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
+                                                 extractIdx, numBuffersVal);
+    extractIdx = builder.create<arith::SelectOp>(loc, cndExt, extractIdx, zero);
+    stageGroup.extractIdx = extractIdx;
+    if (phase) {
+      Value nextPhase = builder.create<arith::XOrIOp>(loc, phase, one);
+      phase = builder.create<arith::SelectOp>(loc, cndExt, phase, nextPhase);
+      stageGroup.phase = phase;
+    }
   }
-  createTMABarrierAndWait(forOp, asyncLoads, insertIdx, extractIdx, phase,
-                          numBuffers, barriers, loadToInfo);
+  createTMABarrierAndWait(forOp, asyncLoads, barriers, stageGroups, loadToInfo);
 
   auto [_, maxClusterId] = tt::getMinMaxCluster(forOp);
   for (AsyncLoad &asyncLoad : asyncLoads) {
+    auto [insertIdx, extractIdx, phase, _] = stageGroups[asyncLoad.numBuffers];
     if (auto loadOp = dyn_cast<tt::LoadOp>(asyncLoad.loadOp)) {
       createAsyncCopy(forOp, loadOp, asyncLoad.alloc, insertIdx, extractIdx,
-                      loadToInfo, numStages, maxClusterId);
+                      loadToInfo, maxClusterId);
     } else {
       auto descLoad = cast<tt::ExperimentalDescriptorLoadOp>(asyncLoad.loadOp);
       createTMAAsyncCopy(forOp, descLoad, asyncLoad.alloc, insertIdx,
                          extractIdx, asyncLoad.barrier, asyncLoad.waitOp, phase,
-                         loadToInfo, numStages);
+                         loadToInfo);
     }
   }
-  // Patch the yield with the updated counters.
-  forYield.setOperand(newOperandIndex + -1, insertIdx);
-  forYield.setOperand(newOperandIndex + 0, extractIdx);
-  if (phase) {
-    forYield.setOperand(newOperandIndex + 1, phase);
+  // Patch the yield with the updated counters. Subtract to account for the loop
+  // counter.
+  argIdx = newOperandIndex - 1;
+  for (auto &[numBuffers, stageGroup] : stageGroups) {
+    forYield.setOperand(argIdx++, stageGroup.insertIdx);
+    forYield.setOperand(argIdx++, stageGroup.extractIdx);
+    if (stageGroup.phase)
+      forYield.setOperand(argIdx++, stageGroup.phase);
   }
+  assert(argIdx + 1 == tmaCounterArgsStartIdx);
 
   tt::CoarseSchedule coarseSchedule(numStages);
   coarseSchedule.deSerialize(forOp);
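
For intuition, the per-group ring-buffer bookkeeping that the new loop above builds out of arith::AddIOp/CmpIOp/SelectOp/XOrIOp corresponds to the scalar logic below. This is a minimal standalone sketch, not code from the commit; StageGroupState and advance are illustrative names, and each stage group (keyed by its buffer count) would carry one such state.

#include <cstdio>

// One state per stage group: a write index, a read index, and a phase bit
// that flips whenever the read index wraps around the ring of buffers.
struct StageGroupState {
  int numBuffers;
  int insertIdx = -1;
  int extractIdx = -1;
  int phase = 0;
};

// Mirrors the AddI/CmpI/Select/XOr sequence emitted at the top of the loop body.
static void advance(StageGroupState &g) {
  g.insertIdx = (g.insertIdx + 1 < g.numBuffers) ? g.insertIdx + 1 : 0;
  int nextExtract = g.extractIdx + 1;
  bool wraps = nextExtract >= g.numBuffers;
  g.extractIdx = wraps ? 0 : nextExtract;
  if (wraps)
    g.phase ^= 1; // a new barrier phase starts each time the extract index wraps
}

int main() {
  StageGroupState g{/*numBuffers=*/3};
  for (int iter = 0; iter < 7; ++iter) {
    advance(g);
    std::printf("iter %d: insert=%d extract=%d phase=%d\n", iter, g.insertIdx,
                g.extractIdx, g.phase);
  }
  return 0;
}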