Skip to content

Commit 8d38196

Browse files
committed
fixup! fixup! [SCEV] Collect and merge loop guards through PHI nodes with multiple incoming Values
1 parent b5d72e3 commit 8d38196

File tree

4 files changed

+296
-272
lines changed

4 files changed

+296
-272
lines changed

llvm/include/llvm/Analysis/ScalarEvolution.h

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1316,10 +1316,24 @@ class ScalarEvolution {
13161316

13171317
LoopGuards(ScalarEvolution &SE) : SE(SE) {}
13181318

1319-
static LoopGuards
1319+
/// Recursively collect loop guards in \p Guards, starting from
1320+
/// block \p Block with predecessor \p Pred. The intended starting point
1321+
/// is to collect from a loop header and its predecessor.
1322+
static void
13201323
collectFromBlock(ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
13211324
const BasicBlock *Block, const BasicBlock *Pred,
1322-
SmallPtrSet<const BasicBlock *, 8> VisitedBlocks);
1325+
SmallPtrSet<const BasicBlock *, 8> &VisitedBlocks,
1326+
unsigned Depth = 0);
1327+
1328+
/// Collect loop guards in \p Guards, starting from PHINode \p
1329+
/// Phi, by calling \p collectFromBlock on the incoming blocks of
1330+
/// \p Phi and trying to merge the found constraints into a single
1331+
/// combined one for \p Phi.
1332+
static void
1333+
collectFromPHI(ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
1334+
const PHINode &Phi,
1335+
SmallPtrSet<const BasicBlock *, 8> &VisitedBlocks,
1336+
unsigned Depth);
13231337

13241338
public:
13251339
/// Collect rewrite map for loop guards for loop \p L, together with flags

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 70 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,11 @@ static cl::opt<unsigned> RangeIterThreshold(
222222
cl::desc("Threshold for switching to iteratively computing SCEV ranges"),
223223
cl::init(32));
224224

225+
static cl::opt<unsigned> MaxLoopGuardCollectionDepth(
226+
"scalar-evolution-max-loop-guard-collection-depth", cl::Hidden,
227+
cl::desc("Maximum depth for recrusive loop guard collection"),
228+
cl::init(1));
229+
225230
static cl::opt<bool>
226231
ClassifyExpressions("scalar-evolution-classify-expressions",
227232
cl::Hidden, cl::init(true),
@@ -15220,13 +15225,72 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
1522015225
BasicBlock *Header = L->getHeader();
1522115226
BasicBlock *Pred = L->getLoopPredecessor();
1522215227
LoopGuards Guards(SE);
15223-
return collectFromBlock(SE, Guards, Header, Pred, {});
15228+
SmallPtrSet<const BasicBlock *, 8> VisitedBlocks;
15229+
collectFromBlock(SE, Guards, Header, Pred, VisitedBlocks);
15230+
return Guards;
15231+
}
15232+
15233+
void ScalarEvolution::LoopGuards::collectFromPHI(
15234+
ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
15235+
const PHINode &Phi, SmallPtrSet<const BasicBlock *, 8> &VisitedBlocks,
15236+
unsigned Depth) {
15237+
if (!SE.isSCEVable(Phi.getType()))
15238+
return;
15239+
15240+
using MinMaxPattern = std::pair<const SCEVConstant *, SCEVTypes>;
15241+
auto GetMinMaxConst = [&](unsigned In) -> MinMaxPattern {
15242+
if (!VisitedBlocks.insert(Phi.getIncomingBlock(In)).second)
15243+
return {nullptr, scCouldNotCompute};
15244+
LoopGuards G(SE);
15245+
collectFromBlock(SE, G, Phi.getParent(), Phi.getIncomingBlock(In),
15246+
VisitedBlocks, Depth + 1);
15247+
const SCEV *S = G.RewriteMap[SE.getSCEV(Phi.getIncomingValue(In))];
15248+
auto *SM = dyn_cast_if_present<SCEVMinMaxExpr>(S);
15249+
if (!SM)
15250+
return {nullptr, scCouldNotCompute};
15251+
if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(SM->getOperand(0)))
15252+
return {C0, SM->getSCEVType()};
15253+
if (const SCEVConstant *C1 = dyn_cast<SCEVConstant>(SM->getOperand(1)))
15254+
return {C1, SM->getSCEVType()};
15255+
return {nullptr, scCouldNotCompute};
15256+
};
15257+
auto MergeMinMaxConst = [](MinMaxPattern P1,
15258+
MinMaxPattern P2) -> MinMaxPattern {
15259+
auto [C1, T1] = P1;
15260+
auto [C2, T2] = P2;
15261+
if (!C1 || !C2 || T1 != T2)
15262+
return {nullptr, scCouldNotCompute};
15263+
switch (T1) {
15264+
case scUMaxExpr:
15265+
return {C1->getAPInt().ult(C2->getAPInt()) ? C1 : C2, T1};
15266+
case scSMaxExpr:
15267+
return {C1->getAPInt().slt(C2->getAPInt()) ? C1 : C2, T1};
15268+
case scUMinExpr:
15269+
return {C1->getAPInt().ugt(C2->getAPInt()) ? C1 : C2, T1};
15270+
case scSMinExpr:
15271+
return {C1->getAPInt().sgt(C2->getAPInt()) ? C1 : C2, T1};
15272+
default:
15273+
llvm_unreachable("Trying to merge non-MinMaxExpr SCEVs.");
15274+
}
15275+
};
15276+
auto P = GetMinMaxConst(0);
15277+
for (unsigned int In = 1; In < Phi.getNumIncomingValues(); In++) {
15278+
if (!P.first)
15279+
break;
15280+
P = MergeMinMaxConst(P, GetMinMaxConst(In));
15281+
}
15282+
if (P.first) {
15283+
const SCEV *LHS = SE.getSCEV(const_cast<PHINode *>(&Phi));
15284+
SmallVector<const SCEV *, 2> Ops({P.first, LHS});
15285+
const SCEV *RHS = SE.getMinMaxExpr(P.second, Ops);
15286+
Guards.RewriteMap.insert({LHS, RHS});
15287+
}
1522415288
}
1522515289

15226-
ScalarEvolution::LoopGuards ScalarEvolution::LoopGuards::collectFromBlock(
15290+
void ScalarEvolution::LoopGuards::collectFromBlock(
1522715291
ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
1522815292
const BasicBlock *Block, const BasicBlock *Pred,
15229-
SmallPtrSet<const BasicBlock *, 8> VisitedBlocks) {
15293+
SmallPtrSet<const BasicBlock *, 8> &VisitedBlocks, unsigned Depth) {
1523015294
SmallVector<const SCEV *> ExprsToRewrite;
1523115295
auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
1523215296
const SCEV *RHS,
@@ -15608,59 +15672,10 @@ ScalarEvolution::LoopGuards ScalarEvolution::LoopGuards::collectFromBlock(
1560815672
// for PHINodes by recursively following all of their incoming
1560915673
// blocks and try to merge the found conditions to build a new one
1561015674
// for the Phi.
15611-
if (Pair.second->hasNPredecessorsOrMore(2)) {
15675+
if (Pair.second->hasNPredecessorsOrMore(2) &&
15676+
Depth < MaxLoopGuardCollectionDepth) {
1561215677
for (auto &Phi : Pair.second->phis()) {
15613-
if (!SE.isSCEVable(Phi.getType()))
15614-
continue;
15615-
15616-
using MinMaxPattern = std::pair<const SCEVConstant *, SCEVTypes>;
15617-
auto GetMinMaxConst = [&SE, &VisitedBlocks, &Pair,
15618-
&Phi](unsigned int In) -> MinMaxPattern {
15619-
LoopGuards G(SE);
15620-
if (VisitedBlocks.insert(Phi.getIncomingBlock(In)).second)
15621-
collectFromBlock(SE, G, Pair.second, Phi.getIncomingBlock(In),
15622-
VisitedBlocks);
15623-
const SCEV *S = G.RewriteMap[SE.getSCEV(Phi.getIncomingValue(In))];
15624-
auto *SM = dyn_cast_if_present<SCEVMinMaxExpr>(S);
15625-
if (!SM)
15626-
return {nullptr, scCouldNotCompute};
15627-
if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(SM->getOperand(0)))
15628-
return {C0, SM->getSCEVType()};
15629-
if (const SCEVConstant *C1 = dyn_cast<SCEVConstant>(SM->getOperand(1)))
15630-
return {C1, SM->getSCEVType()};
15631-
return {nullptr, scCouldNotCompute};
15632-
};
15633-
auto MergeMinMaxConst = [](MinMaxPattern P1,
15634-
MinMaxPattern P2) -> MinMaxPattern {
15635-
auto [C1, T1] = P1;
15636-
auto [C2, T2] = P2;
15637-
if (!C1 || !C2 || T1 != T2)
15638-
return {nullptr, scCouldNotCompute};
15639-
switch (T1) {
15640-
case scUMaxExpr:
15641-
return {C1->getAPInt().ult(C2->getAPInt()) ? C1 : C2, T1};
15642-
case scSMaxExpr:
15643-
return {C1->getAPInt().slt(C2->getAPInt()) ? C1 : C2, T1};
15644-
case scUMinExpr:
15645-
return {C1->getAPInt().ugt(C2->getAPInt()) ? C1 : C2, T1};
15646-
case scSMinExpr:
15647-
return {C1->getAPInt().sgt(C2->getAPInt()) ? C1 : C2, T1};
15648-
default:
15649-
llvm_unreachable("Trying to merge non-MinMaxExpr SCEVs.");
15650-
}
15651-
};
15652-
auto P = GetMinMaxConst(0);
15653-
for (unsigned int In = 1; In < Phi.getNumIncomingValues(); In++) {
15654-
if (!P.first)
15655-
break;
15656-
P = MergeMinMaxConst(P, GetMinMaxConst(In));
15657-
}
15658-
if (P.first) {
15659-
const SCEV *LHS = SE.getSCEV(const_cast<PHINode *>(&Phi));
15660-
SmallVector<const SCEV *, 2> Ops({P.first, LHS});
15661-
const SCEV *RHS = SE.getMinMaxExpr(P.second, Ops);
15662-
Guards.RewriteMap.insert({LHS, RHS});
15663-
}
15678+
collectFromPHI(SE, Guards, Phi, VisitedBlocks, Depth);
1566415679
}
1566515680
}
1566615681

@@ -15718,7 +15733,6 @@ ScalarEvolution::LoopGuards ScalarEvolution::LoopGuards::collectFromBlock(
1571815733
Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)});
1571915734
}
1572015735
}
15721-
return Guards;
1572215736
}
1572315737

1572415738
const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
; RUN: opt < %s -disable-output "-passes=print<scalar-evolution>" -scalar-evolution-max-iterations=0 -scalar-evolution-classify-expressions=0 2>&1 | FileCheck %s
2+
3+
define void @epilogue(i64 %count) {
4+
; CHECK-LABEL: 'epilogue'
5+
; CHECK-NEXT: Determining loop execution counts for: @epilogue
6+
; CHECK-NEXT: Loop %epilogue: backedge-taken count is (-1 + %count.epilogue)
7+
; CHECK-NEXT: Loop %epilogue: constant max backedge-taken count is i64 6
8+
; CHECK-NEXT: Loop %epilogue: symbolic max backedge-taken count is (-1 + %count.epilogue)
9+
; CHECK-NEXT: Loop %epilogue: Trip multiple is 1
10+
; CHECK-NEXT: Loop %while.body: backedge-taken count is ((-8 + %count) /u 8)
11+
; CHECK-NEXT: Loop %while.body: constant max backedge-taken count is i64 2305843009213693951
12+
; CHECK-NEXT: Loop %while.body: symbolic max backedge-taken count is ((-8 + %count) /u 8)
13+
; CHECK-NEXT: Loop %while.body: Trip multiple is 1
14+
entry:
15+
%cmp = icmp ugt i64 %count, 7
16+
br i1 %cmp, label %while.body, label %epilogue.preheader
17+
18+
while.body:
19+
%iv = phi i64 [ %sub, %while.body ], [ %count, %entry ]
20+
%sub = add i64 %iv, -8
21+
%exitcond.not = icmp ugt i64 %sub, 7
22+
br i1 %exitcond.not, label %while.body, label %while.loopexit
23+
24+
while.loopexit:
25+
%sub.exit = phi i64 [ %sub, %while.body ]
26+
br label %epilogue.preheader
27+
28+
epilogue.preheader:
29+
%count.epilogue = phi i64 [ %count, %entry ], [ %sub.exit, %while.loopexit ]
30+
%epilogue.cmp = icmp eq i64 %count.epilogue, 0
31+
br i1 %epilogue.cmp, label %exit, label %epilogue
32+
33+
epilogue:
34+
%iv.epilogue = phi i64 [ %dec, %epilogue ], [ %count.epilogue, %epilogue.preheader ]
35+
%dec = add i64 %iv.epilogue, -1
36+
%exitcond.epilogue = icmp eq i64 %dec, 0
37+
br i1 %exitcond.epilogue, label %exit, label %epilogue
38+
39+
exit:
40+
ret void
41+
}
42+
43+
define void @epilogue2(i64 %count) {
44+
; CHECK-LABEL: 'epilogue2'
45+
; CHECK-NEXT: Determining loop execution counts for: @epilogue2
46+
; CHECK-NEXT: Loop %epilogue: backedge-taken count is (-1 + %count.epilogue)
47+
; CHECK-NEXT: Loop %epilogue: constant max backedge-taken count is i64 8
48+
; CHECK-NEXT: Loop %epilogue: symbolic max backedge-taken count is (-1 + %count.epilogue)
49+
; CHECK-NEXT: Loop %epilogue: Trip multiple is 1
50+
; CHECK-NEXT: Loop %while.body: backedge-taken count is ((-8 + %count) /u 8)
51+
; CHECK-NEXT: Loop %while.body: constant max backedge-taken count is i64 2305843009213693951
52+
; CHECK-NEXT: Loop %while.body: symbolic max backedge-taken count is ((-8 + %count) /u 8)
53+
; CHECK-NEXT: Loop %while.body: Trip multiple is 1
54+
entry:
55+
%cmp = icmp ugt i64 %count, 9
56+
br i1 %cmp, label %while.body, label %epilogue.preheader
57+
58+
while.body:
59+
%iv = phi i64 [ %sub, %while.body ], [ %count, %entry ]
60+
%sub = add i64 %iv, -8
61+
%exitcond.not = icmp ugt i64 %sub, 7
62+
br i1 %exitcond.not, label %while.body, label %while.loopexit
63+
64+
while.loopexit:
65+
%sub.exit = phi i64 [ %sub, %while.body ]
66+
br label %epilogue.preheader
67+
68+
epilogue.preheader:
69+
%count.epilogue = phi i64 [ %count, %entry ], [ %sub.exit, %while.loopexit ]
70+
%epilogue.cmp = icmp eq i64 %count.epilogue, 0
71+
br i1 %epilogue.cmp, label %exit, label %epilogue
72+
73+
epilogue:
74+
%iv.epilogue = phi i64 [ %dec, %epilogue ], [ %count.epilogue, %epilogue.preheader ]
75+
%dec = add i64 %iv.epilogue, -1
76+
%exitcond.epilogue = icmp eq i64 %dec, 0
77+
br i1 %exitcond.epilogue, label %exit, label %epilogue
78+
79+
exit:
80+
ret void
81+
}
82+
83+
define void @slt(i16 %a, i16 %b, i1 %c) {
84+
; CHECK-LABEL: 'slt'
85+
; CHECK-NEXT: Determining loop execution counts for: @slt
86+
; CHECK-NEXT: Loop %loop: backedge-taken count is (63 + (-1 * %count))
87+
; CHECK-NEXT: Loop %loop: constant max backedge-taken count is i16 -32704
88+
; CHECK-NEXT: Loop %loop: symbolic max backedge-taken count is (63 + (-1 * %count))
89+
; CHECK-NEXT: Loop %loop: Trip multiple is 1
90+
entry:
91+
br i1 %c, label %b1, label %b2
92+
93+
b1:
94+
%cmp1 = icmp slt i16 %a, 8
95+
br i1 %cmp1, label %preheader, label %exit
96+
97+
b2:
98+
%cmp2 = icmp slt i16 %b, 8
99+
br i1 %cmp2, label %preheader, label %exit
100+
101+
preheader:
102+
%count = phi i16 [ %a, %b1 ], [ %b, %b2 ]
103+
br label %loop
104+
105+
loop:
106+
%iv = phi i16 [ %iv.next, %loop ], [ %count, %preheader ]
107+
%iv.next = add i16 %iv, 1
108+
%exitcond = icmp slt i16 %iv.next, 64
109+
br i1 %exitcond, label %loop, label %exit
110+
111+
exit:
112+
ret void
113+
}
114+
115+
define void @ult(i16 %a, i16 %b, i1 %c) {
116+
; CHECK-LABEL: 'ult'
117+
; CHECK-NEXT: Determining loop execution counts for: @ult
118+
; CHECK-NEXT: Loop %loop: backedge-taken count is (-1 + %count)
119+
; CHECK-NEXT: Loop %loop: constant max backedge-taken count is i16 -2
120+
; CHECK-NEXT: Loop %loop: symbolic max backedge-taken count is (-1 + %count)
121+
; CHECK-NEXT: Loop %loop: Trip multiple is 1
122+
entry:
123+
br i1 %c, label %b1, label %b2
124+
125+
b1:
126+
%cmp1 = icmp ult i16 %a, 8
127+
br i1 %cmp1, label %exit, label %preheader
128+
129+
b2:
130+
%cmp2 = icmp ult i16 %b, 8
131+
br i1 %cmp2, label %exit, label %preheader
132+
133+
preheader:
134+
%count = phi i16 [ %a, %b1 ], [ %b, %b2 ]
135+
br label %loop
136+
137+
loop:
138+
%iv = phi i16 [ %iv.next, %loop ], [ %count, %preheader ]
139+
%iv.next = add i16 %iv, -1
140+
%exitcond = icmp eq i16 %iv.next, 0
141+
br i1 %exitcond, label %exit, label %loop
142+
143+
exit:
144+
ret void
145+
}
146+
147+
define void @sgt(i16 %a, i16 %b, i1 %c) {
148+
; CHECK-LABEL: 'sgt'
149+
; CHECK-NEXT: Determining loop execution counts for: @sgt
150+
; CHECK-NEXT: Loop %loop: backedge-taken count is %count
151+
; CHECK-NEXT: Loop %loop: constant max backedge-taken count is i16 32767
152+
; CHECK-NEXT: Loop %loop: symbolic max backedge-taken count is %count
153+
; CHECK-NEXT: Loop %loop: Trip multiple is 1
154+
entry:
155+
br i1 %c, label %b1, label %b2
156+
157+
b1:
158+
%cmp1 = icmp sgt i16 %a, 8
159+
br i1 %cmp1, label %preheader, label %exit
160+
161+
b2:
162+
%cmp2 = icmp sgt i16 %b, 8
163+
br i1 %cmp2, label %preheader, label %exit
164+
165+
preheader:
166+
%count = phi i16 [ %a, %b1 ], [ %b, %b2 ]
167+
br label %loop
168+
169+
loop:
170+
%iv = phi i16 [ %iv.next, %loop ], [ %count, %preheader ]
171+
%iv.next = add i16 %iv, -1
172+
%exitcond = icmp slt i16 %iv.next, 0
173+
br i1 %exitcond, label %exit, label %loop
174+
175+
exit:
176+
ret void
177+
}
178+
179+
180+
define void @mixed(i16 %a, i16 %b, i1 %c) {
181+
; CHECK-LABEL: 'mixed'
182+
; CHECK-NEXT: Determining loop execution counts for: @mixed
183+
; CHECK-NEXT: Loop %loop: backedge-taken count is (-1 + (-1 * %count) + (64 smax (1 + %count)))
184+
; CHECK-NEXT: Loop %loop: constant max backedge-taken count is i16 -32704
185+
; CHECK-NEXT: Loop %loop: symbolic max backedge-taken count is (-1 + (-1 * %count) + (64 smax (1 + %count)))
186+
; CHECK-NEXT: Loop %loop: Trip multiple is 1
187+
entry:
188+
br i1 %c, label %b1, label %b2
189+
190+
b1:
191+
%cmp1 = icmp slt i16 %a, 8
192+
br i1 %cmp1, label %preheader, label %exit
193+
194+
b2:
195+
%cmp2 = icmp ult i16 %b, 8
196+
br i1 %cmp2, label %preheader, label %exit
197+
198+
preheader:
199+
%count = phi i16 [ %a, %b1 ], [ %b, %b2 ]
200+
br label %loop
201+
202+
loop:
203+
%iv = phi i16 [ %iv.next, %loop ], [ %count, %preheader ]
204+
%iv.next = add i16 %iv, 1
205+
%exitcond = icmp slt i16 %iv.next, 64
206+
br i1 %exitcond, label %loop, label %exit
207+
208+
exit:
209+
ret void
210+
}

0 commit comments

Comments
 (0)