@@ -1371,27 +1371,45 @@ class LoopVectorizationCostModel {
13711371 return InterleaveInfo.getInterleaveGroup (Instr);
13721372 }
13731373
1374+ // / Calculate in advance whether a scalar epilogue is required when
1375+ // / vectorizing and not vectorizing. If \p Invalidate is true then
1376+ // / invalidate a previous decision.
1377+ void collectScalarEpilogueRequirements (bool Invalidate) {
1378+ auto NeedsScalarEpilogue = [&](bool IsVectorizing) -> bool {
1379+ if (!isScalarEpilogueAllowed ()) {
1380+ LLVM_DEBUG (dbgs () << " LV: Loop does not require scalar epilogue" );
1381+ return false ;
1382+ }
1383+ // If we might exit from anywhere but the latch, must run the exiting
1384+ // iteration in scalar form.
1385+ if (TheLoop->getExitingBlock () != TheLoop->getLoopLatch ()) {
1386+ LLVM_DEBUG (dbgs () << " LV: Loop requires scalar epilogue: multiple exits" );
1387+ return true ;
1388+ }
1389+ if (IsVectorizing && InterleaveInfo.requiresScalarEpilogue ()) {
1390+ LLVM_DEBUG (dbgs () << " LV: Loop requires scalar epilogue: "
1391+ " interleaved group requires scalar epilogue" );
1392+ return true ;
1393+ }
1394+ LLVM_DEBUG (dbgs () << " LV: Loop does not require scalar epilogue" );
1395+ return false ;
1396+ };
1397+
1398+ assert ((Invalidate || !RequiresScalarEpilogue) &&
1399+ " Already determined scalar epilogue requirements!" );
1400+ std::pair<bool ,bool > Result;
1401+ Result.first = NeedsScalarEpilogue (true );
1402+ LLVM_DEBUG (dbgs () << " , when vectorizing\n " );
1403+ Result.second = NeedsScalarEpilogue (false );
1404+ LLVM_DEBUG (dbgs () << " , when not vectorizing\n " );
1405+ RequiresScalarEpilogue = Result;
1406+ }
1407+
13741408 // / Returns true if we're required to use a scalar epilogue for at least
13751409 // / the final iteration of the original loop.
13761410 bool requiresScalarEpilogue (bool IsVectorizing) const {
1377- if (!isScalarEpilogueAllowed ()) {
1378- LLVM_DEBUG (dbgs () << " LV: Loop does not require scalar epilogue\n " );
1379- return false ;
1380- }
1381- // If we might exit from anywhere but the latch, must run the exiting
1382- // iteration in scalar form.
1383- if (TheLoop->getExitingBlock () != TheLoop->getLoopLatch ()) {
1384- LLVM_DEBUG (
1385- dbgs () << " LV: Loop requires scalar epilogue: multiple exits\n " );
1386- return true ;
1387- }
1388- if (IsVectorizing && InterleaveInfo.requiresScalarEpilogue ()) {
1389- LLVM_DEBUG (dbgs () << " LV: Loop requires scalar epilogue: "
1390- " interleaved group requires scalar epilogue\n " );
1391- return true ;
1392- }
1393- LLVM_DEBUG (dbgs () << " LV: Loop does not require scalar epilogue\n " );
1394- return false ;
1411+ auto &CachedResult = *RequiresScalarEpilogue;
1412+ return IsVectorizing ? CachedResult.first : CachedResult.second ;
13951413 }
13961414
13971415 // / Returns true if we're required to use a scalar epilogue for at least
@@ -1415,6 +1433,15 @@ class LoopVectorizationCostModel {
14151433 return ScalarEpilogueStatus == CM_ScalarEpilogueAllowed;
14161434 }
14171435
1436+ // / Update the ScalarEpilogueStatus to a new value, potentially triggering a
1437+ // / recalculation of the scalar epilogue requirements.
1438+ void setScalarEpilogueStatus (ScalarEpilogueLowering Status) {
1439+ bool Changed = ScalarEpilogueStatus != Status;
1440+ ScalarEpilogueStatus = Status;
1441+ if (Changed)
1442+ collectScalarEpilogueRequirements (/* Invalidate=*/ true );
1443+ }
1444+
14181445 // / Returns the TailFoldingStyle that is best for the current loop.
14191446 TailFoldingStyle getTailFoldingStyle (bool IVUpdateMayOverflow = true ) const {
14201447 if (!ChosenTailFoldingStyle)
@@ -1767,6 +1794,9 @@ class LoopVectorizationCostModel {
17671794
17681795 // / All element types found in the loop.
17691796 SmallPtrSet<Type *, 16 > ElementTypesInLoop;
1797+
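1798+ // / Caches whether a scalar epilogue is required: first is the answer when vectorizing, second when not vectorizing. Empty until computed by collectScalarEpilogueRequirements ().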
1799+ std::optional<std::pair<bool ,bool >> RequiresScalarEpilogue;
17701800};
17711801} // end namespace llvm
17721802
@@ -4034,7 +4064,7 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) {
40344064 if (ScalarEpilogueStatus == CM_ScalarEpilogueNotNeededUsePredicate) {
40354065 LLVM_DEBUG (dbgs () << " LV: Cannot fold tail by masking: vectorize with a "
40364066 " scalar epilogue instead.\n " );
4037- ScalarEpilogueStatus = CM_ScalarEpilogueAllowed;
4067+ setScalarEpilogueStatus ( CM_ScalarEpilogueAllowed) ;
40384068 return computeFeasibleMaxVF (MaxTC, UserVF, false );
40394069 }
40404070 return FixedScalableVFPair::getNone ();
@@ -4050,6 +4080,7 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) {
40504080 // Note: There is no need to invalidate any cost modeling decisions here, as
40514081 // none were taken so far.
40524082 InterleaveInfo.invalidateGroupsRequiringScalarEpilogue ();
4083+ collectScalarEpilogueRequirements (/* Invalidate=*/ true );
40534084 }
40544085
40554086 FixedScalableVFPair MaxFactors = computeFeasibleMaxVF (MaxTC, UserVF, true );
@@ -4115,7 +4146,7 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) {
41154146 if (ScalarEpilogueStatus == CM_ScalarEpilogueNotNeededUsePredicate) {
41164147 LLVM_DEBUG (dbgs () << " LV: Cannot fold tail by masking: vectorize with a "
41174148 " scalar epilogue instead.\n " );
4118- ScalarEpilogueStatus = CM_ScalarEpilogueAllowed;
4149+ setScalarEpilogueStatus ( CM_ScalarEpilogueAllowed) ;
41194150 return MaxFactors;
41204151 }
41214152
@@ -6957,6 +6988,7 @@ LoopVectorizationPlanner::planInVPlanNativePath(ElementCount UserVF) {
69576988 if (!OrigLoop->isInnermost ()) {
69586989 // If the user doesn't provide a vectorization factor, determine a
69596990 // reasonable one.
6991+ CM.collectScalarEpilogueRequirements (/* Invalidate=*/ false );
69606992 if (UserVF.isZero ()) {
69616993 VF = determineVPlanVF (TTI, CM);
69626994 LLVM_DEBUG (dbgs () << " LV: VPlan computed VF " << VF << " .\n " );
@@ -7001,6 +7033,7 @@ LoopVectorizationPlanner::planInVPlanNativePath(ElementCount UserVF) {
70017033
70027034void LoopVectorizationPlanner::plan (ElementCount UserVF, unsigned UserIC) {
70037035 assert (OrigLoop->isInnermost () && " Inner loop expected." );
7036+ CM.collectScalarEpilogueRequirements (/* Invalidate=*/ false );
70047037 CM.collectValuesToIgnore ();
70057038 CM.collectElementTypesForWidening ();
70067039
@@ -7015,11 +7048,13 @@ void LoopVectorizationPlanner::plan(ElementCount UserVF, unsigned UserIC) {
70157048 dbgs ()
70167049 << " LV: Invalidate all interleaved groups due to fold-tail by masking "
70177050 " which requires masked-interleaved support.\n " );
7018- if (CM.InterleaveInfo .invalidateGroups ())
7051+ if (CM.InterleaveInfo .invalidateGroups ()) {
70197052 // Invalidating interleave groups also requires invalidating all decisions
70207053 // based on them, which includes widening decisions and uniform and scalar
70217054 // values.
70227055 CM.invalidateCostModelingDecisions ();
7056+ CM.collectScalarEpilogueRequirements (/* Invalidate=*/ true );
7057+ }
70237058 }
70247059
70257060 if (CM.foldTailByMasking ())
0 commit comments