2222#include " llvm/IR/PatternMatch.h"
2323#include " llvm/InitializePasses.h"
2424#include " llvm/Pass.h"
25+ #include " llvm/Support/CommandLine.h"
2526#include " llvm/Support/Debug.h"
2627#include " llvm/Support/MathExtras.h"
2728#include " llvm/Support/raw_ostream.h"
@@ -33,6 +34,11 @@ using namespace llvm;
3334
3435STATISTIC (NumEliminatedCanonicalIV, " Number of canonical IVs we eliminated" );
3536
37+ static cl::opt<bool > EnableEVLIndVarSimplify (
38+ " enable-evl-indvar-simplify" ,
39+ cl::desc (" Enable EVL-based induction variable simplify Pass" ), cl::Hidden,
40+ cl::init(true ));
41+
3642namespace {
3743struct EVLIndVarSimplifyImpl {
3844 ScalarEvolution &SE;
@@ -62,10 +68,9 @@ struct EVLIndVarSimplify : public LoopPass {
6268};
6369} // anonymous namespace
6470
65- static std::optional<uint32_t > getVFFromIndVar (const SCEV *Step,
66- const Function &F) {
71+ static uint32_t getVFFromIndVar (const SCEV *Step, const Function &F) {
6772 if (!Step)
68- return std:: nullopt ;
73+ return 0U ;
6974
7075 // Looking for loops with IV step value in the form of `(<constant VF> x
7176 // vscale)`.
@@ -95,14 +100,18 @@ static std::optional<uint32_t> getVFFromIndVar(const SCEV *Step,
95100 }
96101 }
97102
98- return std:: nullopt ;
103+ return 0U ;
99104}
100105
101106// Remove the original induction variable if it's not used anywhere.
102- static void cleanupOriginalIndVar (PHINode *OrigIndVar, BasicBlock *InitBlock,
103- BasicBlock *BackEdgeBlock) {
104- Value *InitValue = OrigIndVar->getIncomingValueForBlock (InitBlock);
105- Value *RecValue = OrigIndVar->getIncomingValueForBlock (BackEdgeBlock);
107+ static void tryCleanupOriginalIndVar (PHINode *OrigIndVar,
108+ const InductionDescriptor &IVD) {
109+ if (OrigIndVar->getNumIncomingValues () != 2 )
110+ return ;
111+ Value *InitValue = OrigIndVar->getIncomingValue (0 );
112+ Value *RecValue = OrigIndVar->getIncomingValue (1 );
113+ if (InitValue != IVD.getStartValue ())
114+ std::swap (InitValue, RecValue);
106115
107116 // If the only user of OrigIndVar is the one produces RecValue, then we can
108117 // safely remove it.
@@ -117,6 +126,9 @@ static void cleanupOriginalIndVar(PHINode *OrigIndVar, BasicBlock *InitBlock,
117126}
118127
119128bool EVLIndVarSimplifyImpl::run (Loop &L) {
129+ if (!EnableEVLIndVarSimplify)
130+ return false ;
131+
120132 InductionDescriptor IVD;
121133 PHINode *IndVar = L.getInductionVariable (SE);
122134 if (!IndVar || !L.getInductionDescriptor (SE, IVD)) {
@@ -143,23 +155,23 @@ bool EVLIndVarSimplifyImpl::run(Loop &L) {
143155 Value *CanonicalIVFinal = &Bounds->getFinalIVValue ();
144156
145157 const SCEV *StepV = IVD.getStep ();
146- auto VF = getVFFromIndVar (StepV, *L.getHeader ()->getParent ());
158+ uint32_t VF = getVFFromIndVar (StepV, *L.getHeader ()->getParent ());
147159 if (!VF) {
148160 LLVM_DEBUG (dbgs () << " Could not infer VF from IndVar step '" << *StepV
149161 << " '\n " );
150162 return false ;
151163 }
152- LLVM_DEBUG (dbgs () << " Using VF=" << * VF << " for loop " << L.getName ()
164+ LLVM_DEBUG (dbgs () << " Using VF=" << VF << " for loop " << L.getName ()
153165 << " \n " );
154166
155167 // Try to find the EVL-based induction variable.
156168 using namespace PatternMatch ;
157169 BasicBlock *BB = IndVar->getParent ();
158170
159- Value *EVLIndex = nullptr ;
160- Value *RemVL = nullptr , *AVL = nullptr ;
171+ Value *EVLIndVar = nullptr ;
172+ Value *RemTC = nullptr , *TC = nullptr ;
161173 auto IntrinsicMatch = m_Intrinsic<Intrinsic::experimental_get_vector_length>(
162- m_Value (RemVL ), m_SpecificInt (* VF),
174+ m_Value (RemTC ), m_SpecificInt (VF),
163175 /* Scalable=*/ m_SpecificInt (1 ));
164176 for (auto &PN : BB->phis ()) {
165177 if (&PN == IndVar)
@@ -198,19 +210,19 @@ bool EVLIndVarSimplifyImpl::run(Loop &L) {
198210 << " \n " );
199211
200212 // Check 3: Pattern match to find the EVL-based index and total trip count
201- // (AVL ).
213+ // (TC ).
202214 if (match (RecValue,
203215 m_c_Add (m_ZExtOrSelf (IntrinsicMatch), m_Specific (&PN))) &&
204- match (RemVL , m_Sub (m_Value (AVL ), m_Specific (&PN)))) {
205- EVLIndex = RecValue;
216+ match (RemTC , m_Sub (m_Value (TC ), m_Specific (&PN)))) {
217+ EVLIndVar = RecValue;
206218 break ;
207219 }
208220 }
209221
210- if (!EVLIndex || !AVL )
222+ if (!EVLIndVar || !TC )
211223 return false ;
212224
213- LLVM_DEBUG (dbgs () << " Using " << *EVLIndex << " for EVL-based IndVar\n " );
225+ LLVM_DEBUG (dbgs () << " Using " << *EVLIndVar << " for EVL-based IndVar\n " );
214226
215227 // Create an EVL-based comparison and replace the branch to use it as
216228 // predicate.
@@ -220,10 +232,10 @@ bool EVLIndVarSimplifyImpl::run(Loop &L) {
220232 return false ;
221233
222234 IRBuilder<> Builder (OrigLatchCmp);
223- auto *NewPred = Builder.CreateICmp (Pred, EVLIndex, AVL );
235+ auto *NewPred = Builder.CreateICmp (Pred, EVLIndVar, TC );
224236 OrigLatchCmp->replaceAllUsesWith (NewPred);
225237
226- cleanupOriginalIndVar (IndVar, InitBlock, BackEdgeBlock );
238+ tryCleanupOriginalIndVar (IndVar, IVD );
227239
228240 ++NumEliminatedCanonicalIV;
229241
0 commit comments