@@ -1098,29 +1098,80 @@ bool RewriteScheduleStage::initGCNSchedStage() {
10981098 }
10991099 }
11001100
1101- bool ShouldRewrite = false ;
1101+ unsigned ArchVGPRThreshold =
1102+ ST.getMaxNumVectorRegs (DAG.MF .getFunction ()).first ;
1103+
1104+ int64_t Cost = 0 ;
1105+ MBFI.calculate (MF, MBPI, *DAG.MLI );
11021106 for (unsigned RegionIdx = 0 ; RegionIdx < DAG.Regions .size (); RegionIdx++) {
11031107 if (!DAG.RegionsWithExcessArchVGPR [RegionIdx])
11041108 continue ;
11051109
1110+ unsigned MaxCombinedVGPRs = ST.getMaxNumVGPRs (MF);
1111+
1112+ auto PressureBefore = DAG.Pressure [RegionIdx];
1113+ unsigned UnifiedPressureBefore =
1114+ PressureBefore.getVGPRNum (true , ArchVGPRThreshold);
1115+ unsigned ArchPressureBefore =
1116+ PressureBefore.getArchVGPRNum (ArchVGPRThreshold);
1117+ unsigned AGPRPressureBefore = PressureBefore.getAGPRNum (ArchVGPRThreshold);
1118+ unsigned UnifiedSpillBefore =
1119+ UnifiedPressureBefore > MaxCombinedVGPRs
1120+ ? (UnifiedPressureBefore - MaxCombinedVGPRs)
1121+ : 0 ;
1122+ unsigned ArchSpillBefore =
1123+ ArchPressureBefore > ST.getAddressableNumArchVGPRs ()
1124+ ? (ArchPressureBefore - ST.getAddressableNumArchVGPRs ())
1125+ : 0 ;
1126+ unsigned AGPRSpillBefore =
1127+ AGPRPressureBefore > ST.getAddressableNumArchVGPRs ()
1128+ ? (AGPRPressureBefore - ST.getAddressableNumArchVGPRs ())
1129+ : 0 ;
1130+
1131+ unsigned SpillCostBefore =
1132+ std::max (UnifiedSpillBefore, (ArchSpillBefore + AGPRSpillBefore));
1133+
1134+
11061135 // For the cases we care about (i.e. ArchVGPR usage is greater than the
11071136 // addressable limit), rewriting alone should bring pressure to manageable
11081137 // level. If we find any such region, then the rewrite is potentially
11091138 // beneficial.
11101139 auto PressureAfter = DAG.getRealRegPressure (RegionIdx);
1111- unsigned MaxCombinedVGPRs = ST.getMaxNumVGPRs (MF);
1112- if (PressureAfter.getArchVGPRNum () <= ST.getAddressableNumArchVGPRs () &&
1113- PressureAfter.getVGPRNum (true ) <= MaxCombinedVGPRs) {
1114- ShouldRewrite = true ;
1115- break ;
1116- }
1140+ unsigned UnifiedPressureAfter =
1141+ PressureAfter.getVGPRNum (true , ArchVGPRThreshold);
1142+ unsigned ArchPressureAfter =
1143+ PressureAfter.getArchVGPRNum (ArchVGPRThreshold);
1144+ unsigned AGPRPressureAfter = PressureAfter.getAGPRNum (ArchVGPRThreshold);
1145+ unsigned UnifiedSpillAfter = UnifiedPressureAfter > MaxCombinedVGPRs
1146+ ? (UnifiedPressureAfter - MaxCombinedVGPRs)
1147+ : 0 ;
1148+ unsigned ArchSpillAfter =
1149+ ArchPressureAfter > ST.getAddressableNumArchVGPRs ()
1150+ ? (ArchPressureAfter - ST.getAddressableNumArchVGPRs ())
1151+ : 0 ;
1152+ unsigned AGPRSpillAfter =
1153+ AGPRPressureAfter > ST.getAddressableNumArchVGPRs ()
1154+ ? (AGPRPressureAfter - ST.getAddressableNumArchVGPRs ())
1155+ : 0 ;
1156+
1157+ unsigned SpillCostAfter =
1158+ std::max (UnifiedSpillAfter, (ArchSpillAfter + AGPRSpillAfter));
1159+
1160+ uint64_t EntryFreq = MBFI.getEntryFreq ().getFrequency ();
1161+ uint64_t BlockFreq =
1162+ EntryFreq ? MBFI.getBlockFreq (DAG.Regions [RegionIdx].first ->getParent ())
1163+ .getFrequency () / EntryFreq
1164+ : 1 ;
1165+
1166+ // Assumes perfect spilling -- giving edge to VGPR form.
1167+ Cost += ((int )SpillCostAfter - (int )SpillCostBefore) * (int )BlockFreq;
11171168 }
11181169
11191170 // If we find that we'll need to insert cross RC copies inside loop bodies,
11201171 // then bail
1172+ bool ShouldRewrite = Cost < 0 ;
11211173 if (ShouldRewrite) {
1122- CI.clear ();
1123- CI.compute (MF);
1174+ uint64_t EntryFreq = MBFI.getEntryFreq ().getFrequency ();
11241175
11251176 for (auto *DefMI : CrossRCUseCopies) {
11261177 auto DefReg = DefMI->getOperand (0 ).getReg ();
@@ -1137,11 +1188,16 @@ bool RewriteScheduleStage::initGCNSchedStage() {
11371188 if (!RequiredRC || SRI->hasAGPRs (RequiredRC))
11381189 continue ;
11391190
1140- unsigned DefDepth = CI.getCycleDepth (DefMI->getParent ());
1141- if (DefDepth && CI.getCycleDepth (UseMI.getParent ()) >= DefDepth) {
1142- ShouldRewrite = false ;
1191+ uint64_t UseFreq =
1192+ EntryFreq ? MBFI.getBlockFreq (UseMI.getParent ()).getFrequency () /
1193+ EntryFreq
1194+ : 1 ;
1195+
1196+ // Assumes no copy-reuse, giving edge to VGPR form.
1197+ Cost += UseFreq;
1198+ ShouldRewrite = Cost < 0 ;
1199+ if (!ShouldRewrite)
11431200 break ;
1144- }
11451201 }
11461202 if (!ShouldRewrite)
11471203 break ;
@@ -1596,7 +1652,8 @@ void GCNSchedStage::checkScheduling() {
15961652 DAG.RegionsWithExcessRP [RegionIdx] = true ;
15971653 }
15981654
1599- if (PressureAfter.getArchVGPRNum () > ST.getAddressableNumArchVGPRs ())
1655+ if (PressureAfter.getArchVGPRNum (ArchVGPRThreshold) >
1656+ ST.getAddressableNumArchVGPRs ())
16001657 DAG.RegionsWithExcessArchVGPR [RegionIdx] = true ;
16011658
16021659 // Revert if this region's schedule would cause a drop in occupancy or
0 commit comments