@@ -10648,7 +10648,7 @@ ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
1064810648 if (const Loop *L = LI.getLoopFor(BB))
1064910649 return {L->getLoopPredecessor(), L->getHeader()};
1065010650
10651- return {nullptr, nullptr };
10651+ return {nullptr, BB };
1065210652}
1065310653
1065410654/// SCEV structural equivalence is usually sufficient for testing whether two
@@ -15217,7 +15217,16 @@ bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
1521715217
1521815218ScalarEvolution::LoopGuards
1521915219ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
15220+ BasicBlock *Header = L->getHeader();
15221+ BasicBlock *Pred = L->getLoopPredecessor();
1522015222 LoopGuards Guards(SE);
15223+ return collectFromBlock(SE, Guards, Header, Pred, {});
15224+ }
15225+
15226+ ScalarEvolution::LoopGuards ScalarEvolution::LoopGuards::collectFromBlock(
15227+ ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
15228+ const BasicBlock *Block, const BasicBlock *Pred,
15229+ SmallPtrSet<const BasicBlock *, 8> VisitedBlocks) {
1522115230 SmallVector<const SCEV *> ExprsToRewrite;
1522215231 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
1522315232 const SCEV *RHS,
@@ -15556,14 +15565,13 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
1555615565 }
1555715566 };
1555815567
15559- BasicBlock *Header = L->getHeader();
1556015568 SmallVector<PointerIntPair<Value *, 1, bool>> Terms;
1556115569 // First, collect information from assumptions dominating the loop.
1556215570 for (auto &AssumeVH : SE.AC.assumptions()) {
1556315571 if (!AssumeVH)
1556415572 continue;
1556515573 auto *AssumeI = cast<CallInst>(AssumeVH);
15566- if (!SE.DT.dominates(AssumeI, Header ))
15574+ if (!SE.DT.dominates(AssumeI, Block ))
1556715575 continue;
1556815576 Terms.emplace_back(AssumeI->getOperand(0), true);
1556915577 }
@@ -15574,20 +15582,19 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
1557415582 if (GuardDecl)
1557515583 for (const auto *GU : GuardDecl->users())
1557615584 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
15577- if (Guard->getFunction() == Header ->getParent() &&
15578- SE.DT.dominates(Guard, Header ))
15585+ if (Guard->getFunction() == Block ->getParent() &&
15586+ SE.DT.dominates(Guard, Block ))
1557915587 Terms.emplace_back(Guard->getArgOperand(0), true);
1558015588
1558115589 // Third, collect conditions from dominating branches. Starting at the loop
1558215590 // predecessor, climb up the predecessor chain, as long as there are
1558315591 // predecessors that can be found that have unique successors leading to the
1558415592 // original header.
1558515593 // TODO: share this logic with isLoopEntryGuardedByCond.
15586- for (std::pair<const BasicBlock *, const BasicBlock *> Pair(
15587- L->getLoopPredecessor(), Header);
15588- Pair.first;
15594+ std::pair<const BasicBlock *, const BasicBlock *> Pair(Pred, Block);
15595+ for (; Pair.first;
1558915596 Pair = SE.getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
15590-
15597+ VisitedBlocks.insert(Pair.second);
1559115598 const BranchInst *LoopEntryPredicate =
1559215599 dyn_cast<BranchInst>(Pair.first->getTerminator());
1559315600 if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional())
@@ -15596,6 +15603,66 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
1559615603 Terms.emplace_back(LoopEntryPredicate->getCondition(),
1559715604 LoopEntryPredicate->getSuccessor(0) == Pair.second);
1559815605 }
15606+ // Finally, if we stopped climbing the predecessor chain because
15607+ // there wasn't a unique one to continue, try to collect conditions
15608+ // for PHINodes by recursively following all of their incoming
15609+ // blocks and try to merge the found conditions to build a new one
15610+ // for the Phi.
15611+ if (Pair.second->hasNPredecessorsOrMore(2)) {
15612+ for (auto &Phi : Pair.second->phis()) {
15613+ if (!SE.isSCEVable(Phi.getType()))
15614+ continue;
15615+
15616+ using MinMaxPattern = std::pair<const SCEVConstant *, SCEVTypes>;
15617+ auto GetMinMaxConst = [&SE, &VisitedBlocks, &Pair,
15618+ &Phi](unsigned int In) -> MinMaxPattern {
15619+ LoopGuards G(SE);
15620+ if (VisitedBlocks.insert(Phi.getIncomingBlock(In)).second)
15621+ collectFromBlock(SE, G, Pair.second, Phi.getIncomingBlock(In),
15622+ VisitedBlocks);
15623+ const SCEV *S = G.RewriteMap[SE.getSCEV(Phi.getIncomingValue(In))];
15624+ auto *SM = dyn_cast_if_present<SCEVMinMaxExpr>(S);
15625+ if (!SM)
15626+ return {nullptr, scCouldNotCompute};
15627+ if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(SM->getOperand(0)))
15628+ return {C0, SM->getSCEVType()};
15629+ if (const SCEVConstant *C1 = dyn_cast<SCEVConstant>(SM->getOperand(1)))
15630+ return {C1, SM->getSCEVType()};
15631+ return {nullptr, scCouldNotCompute};
15632+ };
15633+ auto MergeMinMaxConst = [](MinMaxPattern P1,
15634+ MinMaxPattern P2) -> MinMaxPattern {
15635+ auto [C1, T1] = P1;
15636+ auto [C2, T2] = P2;
15637+ if (!C1 || !C2 || T1 != T2)
15638+ return {nullptr, scCouldNotCompute};
15639+ switch (T1) {
15640+ case scUMaxExpr:
15641+ return {C1->getAPInt().ult(C2->getAPInt()) ? C1 : C2, T1};
15642+ case scSMaxExpr:
15643+ return {C1->getAPInt().slt(C2->getAPInt()) ? C1 : C2, T1};
15644+ case scUMinExpr:
15645+ return {C1->getAPInt().ugt(C2->getAPInt()) ? C1 : C2, T1};
15646+ case scSMinExpr:
15647+ return {C1->getAPInt().sgt(C2->getAPInt()) ? C1 : C2, T1};
15648+ default:
15649+ llvm_unreachable("Trying to merge non-MinMaxExpr SCEVs.");
15650+ }
15651+ };
15652+ auto P = GetMinMaxConst(0);
15653+ for (unsigned int In = 1; In < Phi.getNumIncomingValues(); In++) {
15654+ if (!P.first)
15655+ break;
15656+ P = MergeMinMaxConst(P, GetMinMaxConst(In));
15657+ }
15658+ if (P.first) {
15659+ const SCEV *LHS = SE.getSCEV(const_cast<PHINode *>(&Phi));
15660+ SmallVector<const SCEV *, 2> Ops({P.first, LHS});
15661+ const SCEV *RHS = SE.getMinMaxExpr(P.second, Ops);
15662+ Guards.RewriteMap.insert({LHS, RHS});
15663+ }
15664+ }
15665+ }
1559915666
1560015667 // Now apply the information from the collected conditions to
1560115668 // Guards.RewriteMap. Conditions are processed in reverse order, so the
0 commit comments