@@ -222,6 +222,10 @@ static cl::opt<unsigned> RangeIterThreshold(
222222 cl::desc("Threshold for switching to iteratively computing SCEV ranges"),
223223 cl::init(32));
224224
225+ static cl::opt<unsigned> MaxLoopGuardCollectionDepth(
226+ "scalar-evolution-max-loop-guard-collection-depth", cl::Hidden,
227+ cl::desc("Maximum depth for recrusive loop guard collection"), cl::init(1));
228+
225229static cl::opt<bool>
226230ClassifyExpressions("scalar-evolution-classify-expressions",
227231 cl::Hidden, cl::init(true),
@@ -10666,7 +10670,7 @@ ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
1066610670 if (const Loop *L = LI.getLoopFor(BB))
1066710671 return {L->getLoopPredecessor(), L->getHeader()};
1066810672
10669- return {nullptr, nullptr };
10673+ return {nullptr, BB };
1067010674}
1067110675
1067210676/// SCEV structural equivalence is usually sufficient for testing whether two
@@ -15245,7 +15249,81 @@ bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
1524515249
1524615250ScalarEvolution::LoopGuards
1524715251ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
15252+ BasicBlock *Header = L->getHeader();
15253+ BasicBlock *Pred = L->getLoopPredecessor();
1524815254 LoopGuards Guards(SE);
15255+ SmallPtrSet<const BasicBlock *, 8> VisitedBlocks;
15256+ collectFromBlock(SE, Guards, Header, Pred, VisitedBlocks);
15257+ return Guards;
15258+ }
15259+
15260+ void ScalarEvolution::LoopGuards::collectFromPHI(
15261+ ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
15262+ const PHINode &Phi, SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks,
15263+ SmallDenseMap<const BasicBlock *, LoopGuards> &IncomingGuards,
15264+ unsigned Depth) {
15265+ if (!SE.isSCEVable(Phi.getType()))
15266+ return;
15267+
15268+ using MinMaxPattern = std::pair<const SCEVConstant *, SCEVTypes>;
15269+ auto GetMinMaxConst = [&](unsigned IncomingIdx) -> MinMaxPattern {
15270+ const BasicBlock *InBlock = Phi.getIncomingBlock(IncomingIdx);
15271+ if (!VisitedBlocks.insert(InBlock).second)
15272+ return {nullptr, scCouldNotCompute};
15273+ auto [G, Inserted] = IncomingGuards.try_emplace(InBlock, LoopGuards(SE));
15274+ if (Inserted)
15275+ collectFromBlock(SE, G->second, Phi.getParent(), InBlock, VisitedBlocks,
15276+ Depth + 1);
15277+ auto S = G->second.RewriteMap.find(
15278+ SE.getSCEV(Phi.getIncomingValue(IncomingIdx)));
15279+ if (S == G->second.RewriteMap.end())
15280+ return {nullptr, scCouldNotCompute};
15281+ auto *SM = dyn_cast_if_present<SCEVMinMaxExpr>(S->second);
15282+ if (!SM)
15283+ return {nullptr, scCouldNotCompute};
15284+ if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(SM->getOperand(0)))
15285+ return {C0, SM->getSCEVType()};
15286+ if (const SCEVConstant *C1 = dyn_cast<SCEVConstant>(SM->getOperand(1)))
15287+ return {C1, SM->getSCEVType()};
15288+ return {nullptr, scCouldNotCompute};
15289+ };
15290+ auto MergeMinMaxConst = [](MinMaxPattern P1,
15291+ MinMaxPattern P2) -> MinMaxPattern {
15292+ auto [C1, T1] = P1;
15293+ auto [C2, T2] = P2;
15294+ if (!C1 || !C2 || T1 != T2)
15295+ return {nullptr, scCouldNotCompute};
15296+ switch (T1) {
15297+ case scUMaxExpr:
15298+ return {C1->getAPInt().ult(C2->getAPInt()) ? C1 : C2, T1};
15299+ case scSMaxExpr:
15300+ return {C1->getAPInt().slt(C2->getAPInt()) ? C1 : C2, T1};
15301+ case scUMinExpr:
15302+ return {C1->getAPInt().ugt(C2->getAPInt()) ? C1 : C2, T1};
15303+ case scSMinExpr:
15304+ return {C1->getAPInt().sgt(C2->getAPInt()) ? C1 : C2, T1};
15305+ default:
15306+ llvm_unreachable("Trying to merge non-MinMaxExpr SCEVs.");
15307+ }
15308+ };
15309+ auto P = GetMinMaxConst(0);
15310+ for (unsigned int In = 1; In < Phi.getNumIncomingValues(); In++) {
15311+ if (!P.first)
15312+ break;
15313+ P = MergeMinMaxConst(P, GetMinMaxConst(In));
15314+ }
15315+ if (P.first) {
15316+ const SCEV *LHS = SE.getSCEV(const_cast<PHINode *>(&Phi));
15317+ SmallVector<const SCEV *, 2> Ops({P.first, LHS});
15318+ const SCEV *RHS = SE.getMinMaxExpr(P.second, Ops);
15319+ Guards.RewriteMap.insert({LHS, RHS});
15320+ }
15321+ }
15322+
15323+ void ScalarEvolution::LoopGuards::collectFromBlock(
15324+ ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
15325+ const BasicBlock *Block, const BasicBlock *Pred,
15326+ SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks, unsigned Depth) {
1524915327 SmallVector<const SCEV *> ExprsToRewrite;
1525015328 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
1525115329 const SCEV *RHS,
@@ -15584,14 +15662,13 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
1558415662 }
1558515663 };
1558615664
15587- BasicBlock *Header = L->getHeader();
1558815665 SmallVector<PointerIntPair<Value *, 1, bool>> Terms;
1558915666 // First, collect information from assumptions dominating the loop.
1559015667 for (auto &AssumeVH : SE.AC.assumptions()) {
1559115668 if (!AssumeVH)
1559215669 continue;
1559315670 auto *AssumeI = cast<CallInst>(AssumeVH);
15594- if (!SE.DT.dominates(AssumeI, Header ))
15671+ if (!SE.DT.dominates(AssumeI, Block ))
1559515672 continue;
1559615673 Terms.emplace_back(AssumeI->getOperand(0), true);
1559715674 }
@@ -15602,27 +15679,42 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
1560215679 if (GuardDecl)
1560315680 for (const auto *GU : GuardDecl->users())
1560415681 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
15605- if (Guard->getFunction() == Header ->getParent() &&
15606- SE.DT.dominates(Guard, Header ))
15682+ if (Guard->getFunction() == Block ->getParent() &&
15683+ SE.DT.dominates(Guard, Block ))
1560715684 Terms.emplace_back(Guard->getArgOperand(0), true);
1560815685
1560915686 // Third, collect conditions from dominating branches. Starting at the loop
1561015687 // predecessor, climb up the predecessor chain, as long as there are
1561115688 // predecessors that can be found that have unique successors leading to the
1561215689 // original header.
1561315690 // TODO: share this logic with isLoopEntryGuardedByCond.
15614- for (std::pair<const BasicBlock *, const BasicBlock *> Pair(
15615- L->getLoopPredecessor(), Header);
15616- Pair.first;
15691+ std::pair<const BasicBlock *, const BasicBlock *> Pair(Pred, Block);
15692+ for (; Pair.first;
1561715693 Pair = SE.getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
15618-
15694+ VisitedBlocks.insert(Pair.second);
1561915695 const BranchInst *LoopEntryPredicate =
1562015696 dyn_cast<BranchInst>(Pair.first->getTerminator());
1562115697 if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional())
1562215698 continue;
1562315699
1562415700 Terms.emplace_back(LoopEntryPredicate->getCondition(),
1562515701 LoopEntryPredicate->getSuccessor(0) == Pair.second);
15702+
15703+ // If we are recursively collecting guards stop after 2
15704+ // predecessors to limit compile-time impact for now.
15705+ if (Depth > 0 && Terms.size() == 2)
15706+ break;
15707+ }
15708+ // Finally, if we stopped climbing the predecessor chain because
15709+ // there wasn't a unique one to continue, try to collect conditions
15710+ // for PHINodes by recursively following all of their incoming
15711+ // blocks and try to merge the found conditions to build a new one
15712+ // for the Phi.
15713+ if (Pair.second->hasNPredecessorsOrMore(2) &&
15714+ Depth < MaxLoopGuardCollectionDepth) {
15715+ SmallDenseMap<const BasicBlock *, LoopGuards> IncomingGuards;
15716+ for (auto &Phi : Pair.second->phis())
15717+ collectFromPHI(SE, Guards, Phi, VisitedBlocks, IncomingGuards, Depth);
1562615718 }
1562715719
1562815720 // Now apply the information from the collected conditions to
@@ -15679,7 +15771,6 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
1567915771 Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)});
1568015772 }
1568115773 }
15682- return Guards;
1568315774}
1568415775
1568515776const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
0 commit comments