@@ -112,6 +112,11 @@ bool EVLIndVarSimplifyImpl::run(Loop &L) {
112112 if (!EnableEVLIndVarSimplify)
113113 return false ;
114114
115+ BasicBlock *LatchBlock = L.getLoopLatch ();
116+ ICmpInst *OrigLatchCmp = L.getLatchCmpInst ();
117+ if (!LatchBlock || !OrigLatchCmp)
118+ return false ;
119+
115120 InductionDescriptor IVD;
116121 PHINode *IndVar = L.getInductionVariable (SE);
117122 if (!IndVar || !L.getInductionDescriptor (SE, IVD)) {
@@ -153,6 +158,7 @@ bool EVLIndVarSimplifyImpl::run(Loop &L) {
153158
154159 Value *EVLIndVar = nullptr ;
155160 Value *RemTC = nullptr ;
161+ Value *TC = nullptr ;
156162 auto IntrinsicMatch = m_Intrinsic<Intrinsic::experimental_get_vector_length>(
157163 m_Value (RemTC), m_SpecificInt (VF),
158164 /* Scalable=*/ m_SpecificInt (1 ));
@@ -192,43 +198,37 @@ bool EVLIndVarSimplifyImpl::run(Loop &L) {
192198 LLVM_DEBUG (dbgs () << " Found candidate PN of EVL-based IndVar: " << PN
193199 << " \n " );
194200
195- // Check 3: Pattern match to find the EVL-based index.
201+ // Check 3: Pattern match to find the EVL-based index and total trip count
202+ // (TC).
196203 if (match (RecValue,
197204 m_c_Add (m_ZExtOrSelf (IntrinsicMatch), m_Specific (&PN))) &&
198- match (RemTC, m_Sub (m_Value (), m_Specific (&PN)))) {
205+ match (RemTC, m_Sub (m_Value (TC ), m_Specific (&PN)))) {
199206 EVLIndVar = RecValue;
200207 break ;
201208 }
202209 }
203210
204- if (!EVLIndVar)
205- return false ;
206-
207- const SCEV *BTC = SE.getBackedgeTakenCount (&L);
208- LLVM_DEBUG (dbgs () << " BTC: " << *BTC << " \n " );
209- if (isa<SCEVCouldNotCompute>(BTC))
211+ if (!EVLIndVar || !TC)
210212 return false ;
211213
212- const SCEV *VFV = SE.getConstant (BTC->getType (), VF);
213- VFV = SE.getMulExpr (VFV, SE.getVScale (VFV->getType ()));
214- const SCEV *ExitValV = SE.getMulExpr (BTC, VFV);
215- LLVM_DEBUG (dbgs () << " ExitVal: " << *ExitValV << " \n " );
214+ LLVM_DEBUG (dbgs () << " Using " << *EVLIndVar << " for EVL-based IndVar\n " );
216215
217216 // Create an EVL-based comparison and replace the branch to use it as
218217 // predicate.
219- ICmpInst *OrigLatchCmp = L.getLatchCmpInst ();
220- const DataLayout &DL = L.getHeader ()->getDataLayout ();
221- SCEVExpander Expander (SE, DL, " evl.iv.exitcondition" );
222- if (!Expander.isSafeToExpandAt (ExitValV, OrigLatchCmp))
223- return false ;
224218
225- LLVM_DEBUG (dbgs () << " Using " << *EVLIndVar << " for EVL-based IndVar\n " );
219+ // Loop::getLatchCmpInst check at the beginning of this function has ensured
220+ // that latch block ends in a conditional branch.
221+ auto *LatchBranch = cast<BranchInst>(LatchBlock->getTerminator ());
222+ assert (LatchBranch->getNumSuccessors () == 2 );
223+ ICmpInst::Predicate Pred;
224+ if (LatchBranch->getSuccessor (0 ) == L.getHeader ())
225+ Pred = ICmpInst::ICMP_ULT;
226+ else
227+ Pred = ICmpInst::ICMP_UGE;
226228
227- Value *ExitVal =
228- Expander.expandCodeFor (ExitValV, EVLIndVar->getType (), OrigLatchCmp);
229229 IRBuilder<> Builder (OrigLatchCmp);
230- auto *NewPred = Builder.CreateICmp (ICmpInst::ICMP_UGT , EVLIndVar, ExitVal );
231- OrigLatchCmp->replaceAllUsesWith (NewPred );
230+ auto *NewLatchCmp = Builder.CreateICmp (Pred , EVLIndVar, TC );
231+ OrigLatchCmp->replaceAllUsesWith (NewLatchCmp );
232232
233233 // llvm::RecursivelyDeleteDeadPHINode only deletes cycles whose values are
234234 // not used outside the cycles. However, in this case the now-RAUW-ed
0 commit comments