@@ -737,67 +737,76 @@ static VPWidenInductionRecipe *getOptimizableIVOf(VPValue *VPV) {
737737 return IsWideIVInc () ? WideIV : nullptr ;
738738}
739739
740- void VPlanTransforms::optimizeInductionExitUsers (
741- VPlan &Plan, DenseMap<VPValue *, VPValue *> &EndValues) {
740+ // / Attempts to optimize the induction variable exit values for users in the
741+ // / exit block coming from the latch in the original scalar loop.
742+ static VPValue *
743+ optimizeLatchExitInductionUser (VPlan &Plan, VPTypeAnalysis &TypeInfo,
744+ VPBlockBase *PredVPBB, VPValue *Op,
745+ DenseMap<VPValue *, VPValue *> &EndValues) {
742746 using namespace VPlanPatternMatch ;
743- SmallVector<VPIRBasicBlock *> ExitVPBBs (Plan.getExitBlocks ());
744- if (ExitVPBBs.size () != 1 )
745- return ;
746747
747- VPIRBasicBlock *ExitVPBB = ExitVPBBs[0 ];
748- VPBlockBase *PredVPBB = ExitVPBB->getSinglePredecessor ();
749- if (!PredVPBB)
750- return ;
751- assert (PredVPBB == Plan.getMiddleBlock () &&
752- " predecessor must be the middle block" );
753-
754- VPTypeAnalysis TypeInfo (Plan.getCanonicalIV ()->getScalarType ());
755- VPBuilder B (Plan.getMiddleBlock ()->getTerminator ());
756- for (VPRecipeBase &R : *ExitVPBB) {
757- auto *ExitIRI = cast<VPIRInstruction>(&R);
758- if (!isa<PHINode>(ExitIRI->getInstruction ()))
759- break ;
748+ VPValue *Incoming;
749+ if (!match (Op, m_VPInstruction<VPInstruction::ExtractFromEnd>(
750+ m_VPValue (Incoming), m_SpecificInt (1 ))))
751+ return nullptr ;
760752
761- VPValue *Incoming;
762- if (!match (ExitIRI->getOperand (0 ),
763- m_VPInstruction<VPInstruction::ExtractFromEnd>(
764- m_VPValue (Incoming), m_SpecificInt (1 ))))
765- continue ;
753+ auto *WideIV = getOptimizableIVOf (Incoming);
754+ if (!WideIV)
755+ return nullptr ;
766756
767- auto *WideIV = getOptimizableIVOf (Incoming);
768- if (!WideIV)
769- continue ;
770- VPValue *EndValue = EndValues.lookup (WideIV);
771- assert (EndValue && " end value must have been pre-computed" );
757+ VPValue *EndValue = EndValues.lookup (WideIV);
758+ assert (EndValue && " end value must have been pre-computed" );
759+
760+ // `getOptimizableIVOf()` always returns the pre-incremented IV, so if it
761+ // changed it means the exit is using the incremented value, so we don't
762+ // need to subtract the step.
763+ if (Incoming != WideIV)
764+ return EndValue;
765+
766+ // Otherwise, subtract the step from the EndValue.
767+ VPBuilder B (cast<VPBasicBlock>(PredVPBB)->getTerminator ());
768+ VPValue *Step = WideIV->getStepValue ();
769+ Type *ScalarTy = TypeInfo.inferScalarType (WideIV);
770+ if (ScalarTy->isIntegerTy ())
771+ return B.createNaryOp (Instruction::Sub, {EndValue, Step}, {}, " ind.escape" );
772+ if (ScalarTy->isPointerTy ()) {
773+ auto *Zero = Plan.getOrAddLiveIn (
774+ ConstantInt::get (Step->getLiveInIRValue ()->getType (), 0 ));
775+ return B.createPtrAdd (EndValue,
776+ B.createNaryOp (Instruction::Sub, {Zero, Step}), {},
777+ " ind.escape" );
778+ }
779+ if (ScalarTy->isFloatingPointTy ()) {
780+ const auto &ID = WideIV->getInductionDescriptor ();
781+ return B.createNaryOp (
782+ ID.getInductionBinOp ()->getOpcode () == Instruction::FAdd
783+ ? Instruction::FSub
784+ : Instruction::FAdd,
785+ {EndValue, Step}, {ID.getInductionBinOp ()->getFastMathFlags ()});
786+ }
787+ llvm_unreachable (" all possible induction types must be handled" );
788+ return nullptr ;
789+ }
772790
773- if (Incoming != WideIV) {
774- ExitIRI->setOperand (0 , EndValue);
775- continue ;
776- }
791+ void VPlanTransforms::optimizeInductionExitUsers (
792+ VPlan &Plan, DenseMap<VPValue *, VPValue *> &EndValues) {
793+ VPBlockBase *MiddleVPBB = Plan.getMiddleBlock ();
794+ VPTypeAnalysis TypeInfo (Plan.getCanonicalIV ()->getScalarType ());
795+ for (VPIRBasicBlock *ExitVPBB : Plan.getExitBlocks ()) {
796+ for (VPRecipeBase &R : *ExitVPBB) {
797+ auto *ExitIRI = cast<VPIRInstruction>(&R);
798+ if (!isa<PHINode>(ExitIRI->getInstruction ()))
799+ break ;
777800
778- VPValue *Escape = nullptr ;
779- VPValue *Step = WideIV->getStepValue ();
780- Type *ScalarTy = TypeInfo.inferScalarType (WideIV);
781- if (ScalarTy->isIntegerTy ()) {
782- Escape =
783- B.createNaryOp (Instruction::Sub, {EndValue, Step}, {}, " ind.escape" );
784- } else if (ScalarTy->isPointerTy ()) {
785- auto *Zero = Plan.getOrAddLiveIn (
786- ConstantInt::get (Step->getLiveInIRValue ()->getType (), 0 ));
787- Escape = B.createPtrAdd (EndValue,
788- B.createNaryOp (Instruction::Sub, {Zero, Step}),
789- {}, " ind.escape" );
790- } else if (ScalarTy->isFloatingPointTy ()) {
791- const auto &ID = WideIV->getInductionDescriptor ();
792- Escape = B.createNaryOp (
793- ID.getInductionBinOp ()->getOpcode () == Instruction::FAdd
794- ? Instruction::FSub
795- : Instruction::FAdd,
796- {EndValue, Step}, {ID.getInductionBinOp ()->getFastMathFlags ()});
797- } else {
798- llvm_unreachable (" all possible induction types must be handled" );
801+ for (auto [Idx, PredVPBB] : enumerate(ExitVPBB->getPredecessors ())) {
802+ if (PredVPBB == MiddleVPBB)
803+ if (VPValue *Escape = optimizeLatchExitInductionUser (
804+ Plan, TypeInfo, PredVPBB, ExitIRI->getOperand (Idx),
805+ EndValues))
806+ ExitIRI->setOperand (Idx, Escape);
807+ // TODO: Optimize early exit induction users in follow-on patch.
808+ }
799809 }
800- ExitIRI->setOperand (0 , Escape);
801810 }
802811}
803812
0 commit comments