@@ -680,62 +680,70 @@ void VPlanTransforms::addMinimumIterationCheck(
680680 // vector trip count is zero. This check also covers the case where adding one
681681 // to the backedge-taken count overflowed leading to an incorrect trip count
682682 // of zero. In this case we will also jump to the scalar loop.
683- auto P = RequiresScalarEpilogue ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_ULT;
683+ CmpInst::Predicate CmpPred =
684+ RequiresScalarEpilogue ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_ULT;
684685 // If tail is to be folded, vector loop takes care of all iterations.
685- const SCEV *Count = vputils::getSCEVExprForVPValue (Plan.getTripCount (), SE);
686- Type *CountTy = Count->getType ();
687- auto CreateStep = [&]() -> const SCEV * {
688- const SCEV *VFxUF = SE.getElementCount (CountTy, (VF * UF), SCEV::FlagNUW);
689- // Create step with max(MinProTripCount, UF * VF).
690- if (UF * VF.getKnownMinValue () >= MinProfitableTripCount.getKnownMinValue ())
686+ VPValue *TripCountVPV = Plan.getTripCount ();
687+ const SCEV *TripCount = vputils::getSCEVExprForVPValue (TripCountVPV, SE);
688+ Type *TripCountTy = TripCount->getType ();
689+ auto CreateMinTripCount = [&]() -> const SCEV * {
690+ // Create or get max(MinProfitableTripCount, UF * VF) and return it.
691+ const SCEV *VFxUF =
692+ SE.getElementCount (TripCountTy, (VF * UF), SCEV::FlagNUW);
693+ const SCEV *MinProfitableTripCountSCEV =
694+ SE.getElementCount (TripCountTy, MinProfitableTripCount, SCEV::FlagNUW);
695+ const SCEV *Max = SE.getUMaxExpr (MinProfitableTripCountSCEV, VFxUF);
696+ if (!VF.isScalable ())
697+ return Max;
698+
699+ if (UF * VF.getKnownMinValue () >=
700+ MinProfitableTripCount.getKnownMinValue ()) {
701+ // TODO: SCEV should be able to simplify test.
691702 return VFxUF;
703+ }
692704
693- const SCEV *MinProfTC =
694- SE.getElementCount (CountTy, MinProfitableTripCount, SCEV::FlagNUW);
695- if (!VF.isScalable ())
696- return MinProfTC;
697- return SE.getUMaxExpr (MinProfTC, VFxUF);
705+ return Max;
698706 };
699707
700708 VPBasicBlock *EntryVPBB = Plan.getEntry ();
701709 VPBuilder Builder (EntryVPBB);
702- VPValue *CheckMinIters = Plan.getFalse ();
703- const SCEV *Step = CreateStep ();
710+ VPValue *TripCountCheck = Plan.getFalse ();
711+ const SCEV *Step = CreateMinTripCount ();
704712 if (!TailFolded) {
705713 // TODO: Emit unconditional branch to vector preheader instead of
706714 // conditional branch with known condition.
707- const SCEV *TripCountSCEV = SE.applyLoopGuards (Count , OrigLoop);
715+ TripCount = SE.applyLoopGuards (TripCount , OrigLoop);
708716 // Check if the trip count is < the step.
709- if (SE.isKnownPredicate (P, TripCountSCEV , Step)) {
717+ if (SE.isKnownPredicate (CmpPred, TripCount , Step)) {
710718 // TODO: Ensure step is at most the trip count when determining max VF and
711719 // UF, w/o tail folding.
712- CheckMinIters = Plan.getTrue ();
713- } else if (!SE.isKnownPredicate (CmpInst::getInversePredicate (P ),
714- TripCountSCEV , Step)) {
720+ TripCountCheck = Plan.getTrue ();
721+ } else if (!SE.isKnownPredicate (CmpInst::getInversePredicate (CmpPred ),
722+ TripCount , Step)) {
715723 // Generate the minimum iteration check only if we cannot prove the
716724 // check is known to be true, or known to be false.
717- CheckMinIters = Builder.createICmp (P, Plan. getTripCount (),
718- Builder.expandSCEV (Step, SE), DL,
719- " min.iters.check" );
720- } // else step known to be < trip count, use CheckMinIters preset to false.
725+ VPValue *MinTripCountVPV = Builder.createExpandSCEV (Step, SE);
726+ TripCountCheck = Builder.createICmp (
727+ CmpPred, TripCountVPV, MinTripCountVPV, DL, " min.iters.check" );
728+ } // else step known to be < trip count, use TripCountCheck preset to false.
721729 } else if (CheckNeededWithTailFolding) {
722730 // vscale is not necessarily a power-of-2, which means we cannot guarantee
723731 // an overflow to zero when updating induction variables and so an
724732 // additional overflow check is required before entering the vector loop.
725733
726734 // Get the maximum unsigned value for the type.
727- VPValue *MaxUIntTripCount = Plan.getOrAddLiveIn (
728- ConstantInt::get (CountTy , cast<IntegerType>(CountTy )->getMask ()));
729- VPValue *LHS = Builder. createNaryOp (Instruction::Sub,
730- {MaxUIntTripCount, Plan. getTripCount () },
731- DebugLoc::getUnknown ());
735+ VPValue *MaxUIntTripCount = Plan.getOrAddLiveIn (ConstantInt::get (
736+ TripCountTy , cast<IntegerType>(TripCountTy )->getMask ()));
737+ VPValue *DistanceToMax =
738+ Builder. createNaryOp (Instruction::Sub, {MaxUIntTripCount, TripCountVPV },
739+ DebugLoc::getUnknown ());
732740
733741 // Don't execute the vector loop if (UMax - n) < (VF * UF).
734- CheckMinIters = Builder.createICmp (ICmpInst::ICMP_ULT, LHS ,
735- Builder.expandSCEV (Step, SE), DL);
742+ TripCountCheck = Builder.createICmp (ICmpInst::ICMP_ULT, DistanceToMax ,
743+ Builder.createExpandSCEV (Step, SE), DL);
736744 }
737745 VPInstruction *Term =
738- Builder.createNaryOp (VPInstruction::BranchOnCond, {CheckMinIters }, DL);
746+ Builder.createNaryOp (VPInstruction::BranchOnCond, {TripCountCheck }, DL);
739747 if (MinItersBypassWeights) {
740748 MDBuilder MDB (Plan.getContext ());
741749 MDNode *BranchWeights = MDB.createBranchWeights (
0 commit comments