@@ -548,6 +548,11 @@ class InnerLoopVectorizer {
548548 Value *VectorTripCount, BasicBlock *MiddleBlock,
549549 VPTransformState &State);
550550
/// Fix up users outside the loop of the induction \p OrigPhi (described by
/// \p II) that receive its value via the uncountable early exit, by adding
/// incoming values from \p VectorEarlyExitBB to the affected exit phis.
551+ void fixupEarlyExitIVUsers (PHINode *OrigPhi, const InductionDescriptor &II,
552+ BasicBlock *VectorEarlyExitBB,
553+ BasicBlock *MiddleBlock, VPlan &Plan,
554+ VPTransformState &State);
555+
551556 // / Iteratively sink the scalarized operands of a predicated instruction into
552557 // / the block that was created for it.
553558 void sinkScalarOperands (Instruction *PredInst);
@@ -2775,6 +2780,23 @@ BasicBlock *InnerLoopVectorizer::createVectorizedLoopSkeleton(
27752780 return LoopVectorPreHeader;
27762781}
27772782
2783+ static bool isValueIncomingFromBlock (BasicBlock *ExitingBB, Value *V,
2784+ Instruction *UI) {
2785+ PHINode *PHI = dyn_cast<PHINode>(UI);
2786+ assert (PHI && " Expected LCSSA form" );
2787+
2788+ // If this loop has an uncountable early exit then there could be
2789+ // different users of OrigPhi with either:
2790+ // 1. Multiple users, because each exiting block (countable or
2791+ // uncountable) jumps to the same exit block, or ..
2792+ // 2. A single user with an incoming value from a countable or
2793+ // uncountable exiting block.
2794+ // In both cases there is no guarantee this came from a countable exiting
2795+ // block, i.e. the latch.
2796+ int Index = PHI->getBasicBlockIndex (ExitingBB);
2797+ return Index != -1 && PHI->getIncomingValue (Index) == V;
2798+ }
2799+
27782800// Fix up external users of the induction variable. At this point, we are
27792801// in LCSSA form, with all external PHIs that use the IV having one input value,
27802802// coming from the remainder loop. We need those PHIs to also have a correct
@@ -2790,19 +2812,20 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi,
27902812 // We allow both, but they, obviously, have different values.
27912813
27922814 DenseMap<Value *, Value *> MissingVals;
2815+ BasicBlock *OrigLoopLatch = OrigLoop->getLoopLatch ();
27932816
27942817 Value *EndValue = cast<PHINode>(OrigPhi->getIncomingValueForBlock (
27952818 OrigLoop->getLoopPreheader ()))
27962819 ->getIncomingValueForBlock (MiddleBlock);
27972820
27982821 // An external user of the last iteration's value should see the value that
27992822 // the remainder loop uses to initialize its own IV.
2800- Value *PostInc = OrigPhi->getIncomingValueForBlock (OrigLoop-> getLoopLatch () );
2823+ Value *PostInc = OrigPhi->getIncomingValueForBlock (OrigLoopLatch );
28012824 for (User *U : PostInc->users ()) {
28022825 Instruction *UI = cast<Instruction>(U);
28032826 if (!OrigLoop->contains (UI)) {
2804- assert (isa<PHINode>(UI) && " Expected LCSSA form " );
2805- MissingVals[UI ] = EndValue;
2827+ if ( isValueIncomingFromBlock (OrigLoopLatch, PostInc, UI))
2828+ MissingVals[cast<PHINode>(UI) ] = EndValue;
28062829 }
28072830 }
28082831
@@ -2812,7 +2835,9 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi,
28122835 for (User *U : OrigPhi->users ()) {
28132836 auto *UI = cast<Instruction>(U);
28142837 if (!OrigLoop->contains (UI)) {
2815- assert (isa<PHINode>(UI) && " Expected LCSSA form" );
2838+ if (!isValueIncomingFromBlock (OrigLoopLatch, OrigPhi, UI))
2839+ continue ;
2840+
28162841 IRBuilder<> B (MiddleBlock->getTerminator ());
28172842
28182843 // Fast-math-flags propagate from the original induction instruction.
@@ -2842,18 +2867,6 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi,
28422867 }
28432868 }
28442869
2845- assert ((MissingVals.empty () ||
2846- all_of (MissingVals,
2847- [MiddleBlock, this ](const std::pair<Value *, Value *> &P) {
2848- return all_of (
2849- predecessors (cast<Instruction>(P.first )->getParent ()),
2850- [MiddleBlock, this ](BasicBlock *Pred) {
2851- return Pred == MiddleBlock ||
2852- Pred == OrigLoop->getLoopLatch ();
2853- });
2854- })) &&
2855- " Expected escaping values from latch/middle.block only" );
2856-
28572870 for (auto &I : MissingVals) {
28582871 PHINode *PHI = cast<PHINode>(I.first );
28592872 // One corner case we have to handle is two IVs "chasing" each-other,
@@ -2866,6 +2879,102 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi,
28662879 }
28672880}
28682881
2882+ void InnerLoopVectorizer::fixupEarlyExitIVUsers (PHINode *OrigPhi,
2883+ const InductionDescriptor &II,
2884+ BasicBlock *VectorEarlyExitBB,
2885+ BasicBlock *MiddleBlock,
2886+ VPlan &Plan,
2887+ VPTransformState &State) {
2888+ // There are two kinds of external IV usages - those that use the value
2889+ // computed in the last iteration (the PHI) and those that use the penultimate
2890+ // value (the value that feeds into the phi from the loop latch).
2891+ // We allow both, but they, obviously, have different values.
2892+ DenseMap<Value *, Value *> MissingVals;
2893+ BasicBlock *OrigLoopLatch = OrigLoop->getLoopLatch ();
2894+ BasicBlock *EarlyExitingBB = Legal->getUncountableEarlyExitingBlock ();
2895+ Value *PostInc = OrigPhi->getIncomingValueForBlock (OrigLoopLatch);
2896+
2897+ // Obtain the canonical IV, since we have to use the most recent value
2898+ // before exiting the loop early. This is unlike fixupIVUsers, which has
2899+ // the luxury of using the end value in the middle block.
2900+ VPBasicBlock *EntryVPBB = Plan.getVectorLoopRegion ()->getEntryBasicBlock ();
2901+ // NOTE: We cannot call Plan.getCanonicalIV() here because the original
2902+ // recipe created whilst building plans is no longer valid.
2903+ VPHeaderPHIRecipe *CanonicalIVR =
2904+ cast<VPHeaderPHIRecipe>(&*EntryVPBB->begin ());
2905+ Value *CanonicalIV = State.get (CanonicalIVR->getVPSingleValue (), true );
2906+
2907+ // Search for the mask that drove us to exit early.
2908+ VPBasicBlock *EarlyExitVPBB = Plan.getVectorLoopRegion ()->getEarlyExit ();
2909+ VPBasicBlock *MiddleSplitVPBB =
2910+ cast<VPBasicBlock>(EarlyExitVPBB->getSinglePredecessor ());
2911+ VPInstruction *BranchOnCond =
2912+ cast<VPInstruction>(MiddleSplitVPBB->getTerminator ());
2913+ assert (BranchOnCond->getOpcode () == VPInstruction::BranchOnCond &&
2914+ " Expected middle.split block terminator to be a branch-on-cond" );
2915+ VPInstruction *ScalarEarlyExitCond =
2916+ cast<VPInstruction>(BranchOnCond->getOperand (0 ));
2917+ assert (
2918+ ScalarEarlyExitCond->getOpcode () == VPInstruction::AnyOf &&
2919+ " Expected middle.split block terminator branch condition to be any-of" );
2920+ VPValue *VectorEarlyExitCond = ScalarEarlyExitCond->getOperand (0 );
2921+ // Finally get the mask that led us into the early exit block.
2922+ Value *EarlyExitMask = State.get (VectorEarlyExitCond);
2923+
2924+ // Calculate the IV step.
2925+ VPValue *StepVPV = Plan.getSCEVExpansion (II.getStep ());
2926+ assert (StepVPV && " step must have been expanded during VPlan execution" );
2927+ Value *Step = StepVPV->isLiveIn () ? StepVPV->getLiveInIRValue ()
2928+ : State.get (StepVPV, VPLane (0 ));
2929+
2930+ auto FixUpPhi = [&](Instruction *UI, bool PostInc) -> Value * {
2931+ IRBuilder<> B (VectorEarlyExitBB->getTerminator ());
2932+ assert (isa<PHINode>(UI) && " Expected LCSSA form" );
2933+
2934+ // Fast-math-flags propagate from the original induction instruction.
2935+ if (isa_and_nonnull<FPMathOperator>(II.getInductionBinOp ()))
2936+ B.setFastMathFlags (II.getInductionBinOp ()->getFastMathFlags ());
2937+
2938+ Type *CtzType = CanonicalIV->getType ();
2939+ Value *Ctz = B.CreateCountTrailingZeroElems (CtzType, EarlyExitMask);
2940+ Ctz = B.CreateAdd (Ctz, cast<PHINode>(CanonicalIV));
2941+ if (PostInc)
2942+ Ctz = B.CreateAdd (Ctz, ConstantInt::get (CtzType, 1 ));
2943+
2944+ Value *Escape = emitTransformedIndex (B, Ctz, II.getStartValue (), Step,
2945+ II.getKind (), II.getInductionBinOp ());
2946+ Escape->setName (" ind.early.escape" );
2947+ return Escape;
2948+ };
2949+
2950+ for (User *U : PostInc->users ()) {
2951+ auto *UI = cast<Instruction>(U);
2952+ if (!OrigLoop->contains (UI)) {
2953+ if (isValueIncomingFromBlock (EarlyExitingBB, PostInc, UI))
2954+ MissingVals[UI] = FixUpPhi (UI, true );
2955+ }
2956+ }
2957+
2958+ for (User *U : OrigPhi->users ()) {
2959+ auto *UI = cast<Instruction>(U);
2960+ if (!OrigLoop->contains (UI)) {
2961+ if (isValueIncomingFromBlock (EarlyExitingBB, OrigPhi, UI))
2962+ MissingVals[UI] = FixUpPhi (UI, false );
2963+ }
2964+ }
2965+
2966+ for (auto &I : MissingVals) {
2967+ PHINode *PHI = cast<PHINode>(I.first );
2968+ // One corner case we have to handle is two IVs "chasing" each-other,
2969+ // that is %IV2 = phi [...], [ %IV1, %latch ]
2970+ // In this case, if IV1 has an external use, we need to avoid adding both
2971+ // "last value of IV1" and "penultimate value of IV2". So, verify that we
2972+ // don't already have an incoming value for the middle block.
2973+ if (PHI->getBasicBlockIndex (VectorEarlyExitBB) == -1 )
2974+ PHI->addIncoming (I.second , VectorEarlyExitBB);
2975+ }
2976+ }
2977+
28692978namespace {
28702979
28712980struct CSEDenseMapInfo {
@@ -2985,6 +3094,20 @@ void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State) {
29853094 PSE.getSE ()->forgetLoop (OrigLoop);
29863095 PSE.getSE ()->forgetBlockAndLoopDispositions ();
29873096
3097+ // When dealing with uncountable early exits we create middle.split blocks
3098+ // between the vector loop region and the exit block. These blocks need
3099+ // adding to any outer loop.
3100+ VPRegionBlock *VectorRegion = State.Plan ->getVectorLoopRegion ();
3101+ Loop *OuterLoop = OrigLoop->getParentLoop ();
3102+ if (Legal->hasUncountableEarlyExit () && OuterLoop) {
3103+ BasicBlock *OrigEarlyExitBB = Legal->getUncountableEarlyExitBlock ();
3104+ if (Loop *EEL = LI->getLoopFor (OrigEarlyExitBB)) {
3105+ BasicBlock *VectorEarlyExitBB =
3106+ State.CFG .VPBB2IRBB [VectorRegion->getEarlyExit ()];
3107+ EEL->addBasicBlockToLoop (VectorEarlyExitBB, *LI);
3108+ }
3109+ }
3110+
29883111 // After vectorization, the exit blocks of the original loop will have
29893112 // additional predecessors. Invalidate SCEVs for the exit phis in case SE
29903113 // looked through single-entry phis.
@@ -3012,15 +3135,23 @@ void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State) {
30123135 getOrCreateVectorTripCount (nullptr ), LoopMiddleBlock, State);
30133136 }
30143137
3138+ if (Legal->hasUncountableEarlyExit ()) {
3139+ VPBasicBlock *VectorEarlyExitVPBB =
3140+ cast<VPBasicBlock>(VectorRegion->getEarlyExit ());
3141+ BasicBlock *VectorEarlyExitBB = State.CFG .VPBB2IRBB [VectorEarlyExitVPBB];
3142+ for (const auto &Entry : Legal->getInductionVars ())
3143+ fixupEarlyExitIVUsers (Entry.first , Entry.second , VectorEarlyExitBB,
3144+ LoopMiddleBlock, Plan, State);
3145+ }
3146+
30153147 // Don't apply optimizations below when no vector region remains, as they all
30163148 // require a vector loop at the moment.
3017- if (!State. Plan -> getVectorLoopRegion () )
3149+ if (!VectorRegion )
30183150 return ;
30193151
30203152 for (Instruction *PI : PredicatedInstructions)
30213153 sinkScalarOperands (&*PI);
30223154
3023- VPRegionBlock *VectorRegion = State.Plan ->getVectorLoopRegion ();
30243155 VPBasicBlock *HeaderVPBB = VectorRegion->getEntryBasicBlock ();
30253156 BasicBlock *HeaderBB = State.CFG .VPBB2IRBB [HeaderVPBB];
30263157
@@ -8948,6 +9079,10 @@ static void addScalarResumePhis(VPRecipeBuilder &Builder, VPlan &Plan) {
89489079 continue ;
89499080 }
89509081
9082+ assert (!Plan.getVectorLoopRegion ()->getEarlyExit () &&
9083+ " Cannot handle "
9084+ " first-order recurrences with uncountable early exits" );
9085+
89519086 // The backedge value provides the value to resume coming out of a loop,
89529087 // which for FORs is a vector whose last element needs to be extracted. The
89539088 // start value provides the value if the loop is bypassed.
@@ -9056,9 +9191,8 @@ collectUsersInExitBlocks(Loop *OrigLoop, VPRecipeBuilder &Builder,
90569191 // Exit values for inductions are computed and updated outside of VPlan
90579192 // and independent of induction recipes.
90589193 // TODO: Compute induction exit values in VPlan.
9059- if (isOptimizableIVOrUse (V) &&
9060- ExitVPBB->getSinglePredecessor () == MiddleVPBB)
9061- continue ;
9194+ if (isOptimizableIVOrUse (V))
9195+ V = VPValue::getNull ();
90629196 ExitUsersToFix.insert (ExitIRI);
90639197 ExitIRI->addOperand (V);
90649198 }
@@ -9085,18 +9219,30 @@ addUsersInExitBlocks(VPlan &Plan,
90859219 for (const auto &[Idx, Op] : enumerate(ExitIRI->operands ())) {
90869220 // Pass live-in values used by exit phis directly through to their users
90879221 // in the exit block.
9088- if (Op->isLiveIn ())
9222+ if (Op->isLiveIn () || Op-> isNull () )
90899223 continue ;
90909224
90919225 // Extract the exit value according to the predecessor it arrives from:
90929226 // the first active lane under the early-exit mask for the vector early-exit block, otherwise the last lane for the middle block.
9093- if (ExitIRI->getParent ()->getSinglePredecessor () != MiddleVPBB)
9094- return false ;
9095-
90969227 LLVMContext &Ctx = ExitIRI->getInstruction ().getContext ();
9097- VPValue *Ext = B.createNaryOp (VPInstruction::ExtractFromEnd,
9098- {Op, Plan.getOrAddLiveIn (ConstantInt::get (
9099- IntegerType::get (Ctx, 32 ), 1 ))});
9228+ VPValue *Ext;
9229+ VPBasicBlock *PredVPBB =
9230+ cast<VPBasicBlock>(ExitIRI->getParent ()->getPredecessors ()[Idx]);
9231+ if (PredVPBB != MiddleVPBB) {
9232+ VPBasicBlock *VectorEarlyExitVPBB =
9233+ Plan.getVectorLoopRegion ()->getEarlyExit ();
9234+ VPBuilder B2 (VectorEarlyExitVPBB,
9235+ VectorEarlyExitVPBB->getFirstNonPhi ());
9236+ assert (ExitIRI->getParent ()->getNumPredecessors () <= 2 );
9237+ VPValue *EarlyExitMask =
9238+ Plan.getVectorLoopRegion ()->getVectorEarlyExitCond ();
9239+ Ext = B2.createNaryOp (VPInstruction::ExtractFirstActive,
9240+ {Op, EarlyExitMask});
9241+ } else {
9242+ Ext = B.createNaryOp (VPInstruction::ExtractFromEnd,
9243+ {Op, Plan.getOrAddLiveIn (ConstantInt::get (
9244+ IntegerType::get (Ctx, 32 ), 1 ))});
9245+ }
91009246 ExitIRI->setOperand (Idx, Ext);
91019247 }
91029248 }
0 commit comments