Skip to content

Commit 8f04feb

Browse files
committed
(Stash) Put EVLIVSimplify Pass at the end of the vectorizer Pass
In order to simplify the exit condition.
1 parent 73c651f commit 8f04feb

File tree

3 files changed

+24
-26
lines changed

3 files changed

+24
-26
lines changed

llvm/lib/CodeGen/EVLIndVarSimplify.cpp

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,11 @@ bool EVLIndVarSimplifyImpl::run(Loop &L) {
112112
if (!EnableEVLIndVarSimplify)
113113
return false;
114114

115+
BasicBlock *LatchBlock = L.getLoopLatch();
116+
ICmpInst *OrigLatchCmp = L.getLatchCmpInst();
117+
if (!LatchBlock || !OrigLatchCmp)
118+
return false;
119+
115120
InductionDescriptor IVD;
116121
PHINode *IndVar = L.getInductionVariable(SE);
117122
if (!IndVar || !L.getInductionDescriptor(SE, IVD)) {
@@ -153,6 +158,7 @@ bool EVLIndVarSimplifyImpl::run(Loop &L) {
153158

154159
Value *EVLIndVar = nullptr;
155160
Value *RemTC = nullptr;
161+
Value *TC = nullptr;
156162
auto IntrinsicMatch = m_Intrinsic<Intrinsic::experimental_get_vector_length>(
157163
m_Value(RemTC), m_SpecificInt(VF),
158164
/*Scalable=*/m_SpecificInt(1));
@@ -192,43 +198,37 @@ bool EVLIndVarSimplifyImpl::run(Loop &L) {
192198
LLVM_DEBUG(dbgs() << "Found candidate PN of EVL-based IndVar: " << PN
193199
<< "\n");
194200

195-
// Check 3: Pattern match to find the EVL-based index.
201+
// Check 3: Pattern match to find the EVL-based index and total trip count
202+
// (TC).
196203
if (match(RecValue,
197204
m_c_Add(m_ZExtOrSelf(IntrinsicMatch), m_Specific(&PN))) &&
198-
match(RemTC, m_Sub(m_Value(), m_Specific(&PN)))) {
205+
match(RemTC, m_Sub(m_Value(TC), m_Specific(&PN)))) {
199206
EVLIndVar = RecValue;
200207
break;
201208
}
202209
}
203210

204-
if (!EVLIndVar)
205-
return false;
206-
207-
const SCEV *BTC = SE.getBackedgeTakenCount(&L);
208-
LLVM_DEBUG(dbgs() << "BTC: " << *BTC << "\n");
209-
if (isa<SCEVCouldNotCompute>(BTC))
211+
if (!EVLIndVar || !TC)
210212
return false;
211213

212-
const SCEV *VFV = SE.getConstant(BTC->getType(), VF);
213-
VFV = SE.getMulExpr(VFV, SE.getVScale(VFV->getType()));
214-
const SCEV *ExitValV = SE.getMulExpr(BTC, VFV);
215-
LLVM_DEBUG(dbgs() << "ExitVal: " << *ExitValV << "\n");
214+
LLVM_DEBUG(dbgs() << "Using " << *EVLIndVar << " for EVL-based IndVar\n");
216215

217216
// Create an EVL-based comparison and replace the branch to use it as
218217
// predicate.
219-
ICmpInst *OrigLatchCmp = L.getLatchCmpInst();
220-
const DataLayout &DL = L.getHeader()->getDataLayout();
221-
SCEVExpander Expander(SE, DL, "evl.iv.exitcondition");
222-
if (!Expander.isSafeToExpandAt(ExitValV, OrigLatchCmp))
223-
return false;
224218

225-
LLVM_DEBUG(dbgs() << "Using " << *EVLIndVar << " for EVL-based IndVar\n");
219+
// Loop::getLatchCmpInst check at the beginning of this function has ensured
220+
// that latch block ends in a conditional branch.
221+
auto *LatchBranch = cast<BranchInst>(LatchBlock->getTerminator());
222+
assert(LatchBranch->getNumSuccessors() == 2);
223+
ICmpInst::Predicate Pred;
224+
if (LatchBranch->getSuccessor(0) == L.getHeader())
225+
Pred = ICmpInst::ICMP_ULT;
226+
else
227+
Pred = ICmpInst::ICMP_UGE;
226228

227-
Value *ExitVal =
228-
Expander.expandCodeFor(ExitValV, EVLIndVar->getType(), OrigLatchCmp);
229229
IRBuilder<> Builder(OrigLatchCmp);
230-
auto *NewPred = Builder.CreateICmp(ICmpInst::ICMP_UGT, EVLIndVar, ExitVal);
231-
OrigLatchCmp->replaceAllUsesWith(NewPred);
230+
auto *NewLatchCmp = Builder.CreateICmp(Pred, EVLIndVar, TC);
231+
OrigLatchCmp->replaceAllUsesWith(NewLatchCmp);
232232

233233
// llvm::RecursivelyDeleteDeadPHINode only deletes cycles whose values are
234234
// not used outside the cycles. However, in this case the now-RAUW-ed

llvm/lib/Passes/PassBuilderPipelines.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1272,6 +1272,8 @@ void PassBuilder::addVectorPasses(OptimizationLevel Level,
12721272
FPM.addPass(LoopVectorizePass(
12731273
LoopVectorizeOptions(!PTO.LoopInterleaving, !PTO.LoopVectorization)));
12741274

1275+
FPM.addPass(createFunctionToLoopPassAdaptor(EVLIndVarSimplifyPass()));
1276+
12751277
FPM.addPass(InferAlignmentPass());
12761278
if (IsFullLTO) {
12771279
// The vectorizer may have significantly shortened a loop body; unroll

llvm/lib/Target/RISCV/RISCVTargetMachine.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
#include "TargetInfo/RISCVTargetInfo.h"
2020
#include "llvm/ADT/STLExtras.h"
2121
#include "llvm/Analysis/TargetTransformInfo.h"
22-
#include "llvm/CodeGen/EVLIndVarSimplify.h"
2322
#include "llvm/CodeGen/GlobalISel/CSEInfo.h"
2423
#include "llvm/CodeGen/GlobalISel/IRTranslator.h"
2524
#include "llvm/CodeGen/GlobalISel/InstructionSelect.h"
@@ -467,9 +466,6 @@ void RISCVPassConfig::addIRPasses() {
467466
}
468467

469468
TargetPassConfig::addIRPasses();
470-
471-
if (getOptLevel() != CodeGenOptLevel::None)
472-
addPass(createEVLIndVarSimplifyPass());
473469
}
474470

475471
bool RISCVPassConfig::addPreISel() {

0 commit comments

Comments
 (0)