Skip to content
Merged
5 changes: 5 additions & 0 deletions llvm/include/llvm/Analysis/ScalarEvolution.h
Original file line number Diff line number Diff line change
Expand Up @@ -1316,6 +1316,11 @@ class ScalarEvolution {

LoopGuards(ScalarEvolution &SE) : SE(SE) {}

static LoopGuards
collectFromBlock(ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
const BasicBlock *Block, const BasicBlock *Pred,
SmallPtrSet<const BasicBlock *, 8> VisitedBlocks);

public:
/// Collect rewrite map for loop guards for loop \p L, together with flags
/// indicating if NUW and NSW can be preserved during rewriting.
Expand Down
85 changes: 76 additions & 9 deletions llvm/lib/Analysis/ScalarEvolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10648,7 +10648,7 @@ ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
if (const Loop *L = LI.getLoopFor(BB))
return {L->getLoopPredecessor(), L->getHeader()};

return {nullptr, nullptr};
return {nullptr, BB};
}

/// SCEV structural equivalence is usually sufficient for testing whether two
Expand Down Expand Up @@ -15217,7 +15217,16 @@ bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,

ScalarEvolution::LoopGuards
ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
BasicBlock *Header = L->getHeader();
BasicBlock *Pred = L->getLoopPredecessor();
LoopGuards Guards(SE);
return collectFromBlock(SE, Guards, Header, Pred, {});
}

ScalarEvolution::LoopGuards ScalarEvolution::LoopGuards::collectFromBlock(
ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
const BasicBlock *Block, const BasicBlock *Pred,
SmallPtrSet<const BasicBlock *, 8> VisitedBlocks) {
SmallVector<const SCEV *> ExprsToRewrite;
auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
const SCEV *RHS,
Expand Down Expand Up @@ -15556,14 +15565,13 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
}
};

BasicBlock *Header = L->getHeader();
SmallVector<PointerIntPair<Value *, 1, bool>> Terms;
// First, collect information from assumptions dominating the loop.
for (auto &AssumeVH : SE.AC.assumptions()) {
if (!AssumeVH)
continue;
auto *AssumeI = cast<CallInst>(AssumeVH);
if (!SE.DT.dominates(AssumeI, Header))
if (!SE.DT.dominates(AssumeI, Block))
continue;
Terms.emplace_back(AssumeI->getOperand(0), true);
}
Expand All @@ -15574,20 +15582,19 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
if (GuardDecl)
for (const auto *GU : GuardDecl->users())
if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
if (Guard->getFunction() == Header->getParent() &&
SE.DT.dominates(Guard, Header))
if (Guard->getFunction() == Block->getParent() &&
SE.DT.dominates(Guard, Block))
Terms.emplace_back(Guard->getArgOperand(0), true);

// Third, collect conditions from dominating branches. Starting at the loop
// predecessor, climb up the predecessor chain, as long as there are
// predecessors that can be found that have unique successors leading to the
// original header.
// TODO: share this logic with isLoopEntryGuardedByCond.
for (std::pair<const BasicBlock *, const BasicBlock *> Pair(
L->getLoopPredecessor(), Header);
Pair.first;
std::pair<const BasicBlock *, const BasicBlock *> Pair(Pred, Block);
for (; Pair.first;
Pair = SE.getPredecessorWithUniqueSuccessorForBB(Pair.first)) {

VisitedBlocks.insert(Pair.second);
const BranchInst *LoopEntryPredicate =
dyn_cast<BranchInst>(Pair.first->getTerminator());
if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional())
Expand All @@ -15596,6 +15603,66 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
Terms.emplace_back(LoopEntryPredicate->getCondition(),
LoopEntryPredicate->getSuccessor(0) == Pair.second);
}
// Finally, if we stopped climbing the predecessor chain because
// there wasn't a unique one to continue, try to collect conditions
// for PHINodes by recursively following all of their incoming
// blocks and try to merge the found conditions to build a new one
// for the Phi.
if (Pair.second->hasNPredecessorsOrMore(2)) {
for (auto &Phi : Pair.second->phis()) {
if (!SE.isSCEVable(Phi.getType()))
continue;

using MinMaxPattern = std::pair<const SCEVConstant *, SCEVTypes>;
auto GetMinMaxConst = [&SE, &VisitedBlocks, &Pair,
&Phi](unsigned int In) -> MinMaxPattern {
LoopGuards G(SE);
if (VisitedBlocks.insert(Phi.getIncomingBlock(In)).second)
collectFromBlock(SE, G, Pair.second, Phi.getIncomingBlock(In),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably only allow a single level of recursion to start with, i.e. don't allow multiple predecessors after recursing here the first time

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, added MaxLoopGuardCollectionDepth and defaulted it to 1

VisitedBlocks);
const SCEV *S = G.RewriteMap[SE.getSCEV(Phi.getIncomingValue(In))];
auto *SM = dyn_cast_if_present<SCEVMinMaxExpr>(S);
if (!SM)
return {nullptr, scCouldNotCompute};
if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(SM->getOperand(0)))
return {C0, SM->getSCEVType()};
if (const SCEVConstant *C1 = dyn_cast<SCEVConstant>(SM->getOperand(1)))
return {C1, SM->getSCEVType()};
return {nullptr, scCouldNotCompute};
};
auto MergeMinMaxConst = [](MinMaxPattern P1,
MinMaxPattern P2) -> MinMaxPattern {
auto [C1, T1] = P1;
auto [C2, T2] = P2;
if (!C1 || !C2 || T1 != T2)
return {nullptr, scCouldNotCompute};
switch (T1) {
case scUMaxExpr:
return {C1->getAPInt().ult(C2->getAPInt()) ? C1 : C2, T1};
case scSMaxExpr:
return {C1->getAPInt().slt(C2->getAPInt()) ? C1 : C2, T1};
case scUMinExpr:
return {C1->getAPInt().ugt(C2->getAPInt()) ? C1 : C2, T1};
case scSMinExpr:
return {C1->getAPInt().sgt(C2->getAPInt()) ? C1 : C2, T1};
default:
llvm_unreachable("Trying to merge non-MinMaxExpr SCEVs.");
}
};
auto P = GetMinMaxConst(0);
for (unsigned int In = 1; In < Phi.getNumIncomingValues(); In++) {
if (!P.first)
break;
P = MergeMinMaxConst(P, GetMinMaxConst(In));
}
if (P.first) {
const SCEV *LHS = SE.getSCEV(const_cast<PHINode *>(&Phi));
SmallVector<const SCEV *, 2> Ops({P.first, LHS});
const SCEV *RHS = SE.getMinMaxExpr(P.second, Ops);
Guards.RewriteMap.insert({LHS, RHS});
}
}
}

// Now apply the information from the collected conditions to
// Guards.RewriteMap. Conditions are processed in reverse order, so the
Expand Down
82 changes: 82 additions & 0 deletions llvm/test/Analysis/ScalarEvolution/trip-count.ll
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,85 @@ for.body:
exit:
ret void
}

define void @epilogue(i64 %count) {
; CHECK-LABEL: 'epilogue'
; CHECK-NEXT: Determining loop execution counts for: @epilogue
; CHECK-NEXT: Loop %epilogue: backedge-taken count is (-1 + %count.epilogue)
; CHECK-NEXT: Loop %epilogue: constant max backedge-taken count is i64 6
; CHECK-NEXT: Loop %epilogue: symbolic max backedge-taken count is (-1 + %count.epilogue)
; CHECK-NEXT: Loop %epilogue: Trip multiple is 1
; CHECK-NEXT: Loop %while.body: backedge-taken count is ((-8 + %count) /u 8)
; CHECK-NEXT: Loop %while.body: constant max backedge-taken count is i64 2305843009213693951
; CHECK-NEXT: Loop %while.body: symbolic max backedge-taken count is ((-8 + %count) /u 8)
; CHECK-NEXT: Loop %while.body: Trip multiple is 1
entry:
%cmp = icmp ugt i64 %count, 7
br i1 %cmp, label %while.body, label %epilogue.preheader

while.body:
%iv = phi i64 [ %sub, %while.body ], [ %count, %entry ]
%sub = add i64 %iv, -8
%exitcond.not = icmp ugt i64 %sub, 7
br i1 %exitcond.not, label %while.body, label %while.loopexit

while.loopexit:
%sub.exit = phi i64 [ %sub, %while.body ]
br label %epilogue.preheader

epilogue.preheader:
%count.epilogue = phi i64 [ %count, %entry ], [ %sub.exit, %while.loopexit ]
%epilogue.cmp = icmp eq i64 %count.epilogue, 0
br i1 %epilogue.cmp, label %exit, label %epilogue

epilogue:
%iv.epilogue = phi i64 [ %dec, %epilogue ], [ %count.epilogue, %epilogue.preheader ]
%dec = add i64 %iv.epilogue, -1
%exitcond.epilogue = icmp eq i64 %dec, 0
br i1 %exitcond.epilogue, label %exit, label %epilogue

exit:
ret void

}

define void @epilogue2(i64 %count) {
; CHECK-LABEL: 'epilogue2'
; CHECK-NEXT: Determining loop execution counts for: @epilogue2
; CHECK-NEXT: Loop %epilogue: backedge-taken count is (-1 + %count.epilogue)
; CHECK-NEXT: Loop %epilogue: constant max backedge-taken count is i64 8
; CHECK-NEXT: Loop %epilogue: symbolic max backedge-taken count is (-1 + %count.epilogue)
; CHECK-NEXT: Loop %epilogue: Trip multiple is 1
; CHECK-NEXT: Loop %while.body: backedge-taken count is ((-8 + %count) /u 8)
; CHECK-NEXT: Loop %while.body: constant max backedge-taken count is i64 2305843009213693951
; CHECK-NEXT: Loop %while.body: symbolic max backedge-taken count is ((-8 + %count) /u 8)
; CHECK-NEXT: Loop %while.body: Trip multiple is 1
entry:
%cmp = icmp ugt i64 %count, 9
br i1 %cmp, label %while.body, label %epilogue.preheader

while.body:
%iv = phi i64 [ %sub, %while.body ], [ %count, %entry ]
%sub = add i64 %iv, -8
%exitcond.not = icmp ugt i64 %sub, 7
br i1 %exitcond.not, label %while.body, label %while.loopexit

while.loopexit:
%sub.exit = phi i64 [ %sub, %while.body ]
br label %epilogue.preheader

epilogue.preheader:
%count.epilogue = phi i64 [ %count, %entry ], [ %sub.exit, %while.loopexit ]
%epilogue.cmp = icmp eq i64 %count.epilogue, 0
br i1 %epilogue.cmp, label %exit, label %epilogue

epilogue:
%iv.epilogue = phi i64 [ %dec, %epilogue ], [ %count.epilogue, %epilogue.preheader ]
%dec = add i64 %iv.epilogue, -1
%exitcond.epilogue = icmp eq i64 %dec, 0
br i1 %exitcond.epilogue, label %exit, label %epilogue

exit:
ret void

}
2 changes: 1 addition & 1 deletion llvm/test/Transforms/PhaseOrdering/X86/pr38280.ll
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ define void @apply_delta(ptr nocapture noundef %dst, ptr nocapture noundef reado
; CHECK-NEXT: [[INCDEC_PTR]] = getelementptr inbounds i8, ptr [[DST_ADDR_130]], i64 1
; CHECK-NEXT: [[INCDEC_PTR8]] = getelementptr inbounds i8, ptr [[SRC_ADDR_129]], i64 1
; CHECK-NEXT: [[TOBOOL_NOT:%.*]] = icmp eq i64 [[DEC]], 0
; CHECK-NEXT: br i1 [[TOBOOL_NOT]], label [[WHILE_END9]], label [[WHILE_BODY4]], !llvm.loop [[LOOP0:![0-9]+]]
; CHECK-NEXT: br i1 [[TOBOOL_NOT]], label [[WHILE_END9]], label [[WHILE_BODY4]]
; CHECK: while.end9:
; CHECK-NEXT: ret void
;
Expand Down
Loading