Skip to content

Commit cd42158

Browse files
[SimpleLoopUnswitch] Record loops from unswitching non-trivial conditions
Track newly-cloned loops coming from unswitching non-trivial invariant conditions, so as to prevent conditions in such cloned blocks from being unswitched again. While this should optimistically suffice, ensure the outer loop basic block size is taken into account as well when estimating the cost for unswitching non-trivial conditions. Fixes: #138509.
1 parent eaf911b commit cd42158

File tree

2 files changed

+84
-26
lines changed

2 files changed

+84
-26
lines changed

llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp

Lines changed: 35 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2142,34 +2142,33 @@ void visitDomSubTree(DominatorTree &DT, BasicBlock *BB, CallableT Callable) {
21422142
void postUnswitch(Loop &L, LPMUpdater &U, StringRef LoopName,
21432143
bool CurrentLoopValid, bool PartiallyInvariant,
21442144
bool InjectedCondition, ArrayRef<Loop *> NewLoops) {
2145-
// If we did a non-trivial unswitch, we have added new (cloned) loops.
2146-
if (!NewLoops.empty())
2145+
auto RecordLoopAsUnswitched = [&](Loop *TargetLoop, StringRef Tag) {
2146+
auto &Ctx = TargetLoop->getHeader()->getContext();
2147+
const auto &DisableMDName = (Twine(Tag) + ".disable").str();
2148+
MDNode *DisableMD = MDNode::get(Ctx, MDString::get(Ctx, DisableMDName));
2149+
MDNode *NewLoopID = makePostTransformationMetadata(
2150+
Ctx, TargetLoop->getLoopID(), {Tag}, {DisableMD});
2151+
TargetLoop->setLoopID(NewLoopID);
2152+
};
2153+
2154+
// If we performed a non-trivial unswitch, we have added new cloned loops.
2155+
// Mark such newly-created loops as visited.
2156+
if (!NewLoops.empty()) {
2157+
for (Loop *NL : NewLoops)
2158+
RecordLoopAsUnswitched(NL, "llvm.loop.unswitch.nontrivial");
21472159
U.addSiblingLoops(NewLoops);
2160+
}
21482161

21492162
// If the current loop remains valid, we should revisit it to catch any
21502163
// other unswitch opportunities. Otherwise, we need to mark it as deleted.
21512164
if (CurrentLoopValid) {
21522165
if (PartiallyInvariant) {
21532166
// Mark the new loop as partially unswitched, to avoid unswitching on
21542167
// the same condition again.
2155-
auto &Context = L.getHeader()->getContext();
2156-
MDNode *DisableUnswitchMD = MDNode::get(
2157-
Context,
2158-
MDString::get(Context, "llvm.loop.unswitch.partial.disable"));
2159-
MDNode *NewLoopID = makePostTransformationMetadata(
2160-
Context, L.getLoopID(), {"llvm.loop.unswitch.partial"},
2161-
{DisableUnswitchMD});
2162-
L.setLoopID(NewLoopID);
2168+
RecordLoopAsUnswitched(&L, "llvm.loop.unswitch.partial");
21632169
} else if (InjectedCondition) {
21642170
// Do the same for injection of invariant conditions.
2165-
auto &Context = L.getHeader()->getContext();
2166-
MDNode *DisableUnswitchMD = MDNode::get(
2167-
Context,
2168-
MDString::get(Context, "llvm.loop.unswitch.injection.disable"));
2169-
MDNode *NewLoopID = makePostTransformationMetadata(
2170-
Context, L.getLoopID(), {"llvm.loop.unswitch.injection"},
2171-
{DisableUnswitchMD});
2172-
L.setLoopID(NewLoopID);
2171+
RecordLoopAsUnswitched(&L, "llvm.loop.unswitch.injection");
21732172
} else
21742173
U.revisitCurrentLoop();
21752174
} else
@@ -2806,9 +2805,9 @@ static BranchInst *turnGuardIntoBranch(IntrinsicInst *GI, Loop &L,
28062805
}
28072806

28082807
/// Cost multiplier is a way to limit potentially exponential behavior
2809-
/// of loop-unswitch. Cost is multipied in proportion of 2^number of unswitch
2810-
/// candidates available. Also accounting for the number of "sibling" loops with
2811-
/// the idea to account for previous unswitches that already happened on this
2808+
/// of loop-unswitch. Cost is multiplied in proportion of 2^number of unswitch
2809+
/// candidates available. Also consider the number of "sibling" loops with
2810+
/// the idea of accounting for previous unswitches that already happened on this
28122811
/// cluster of loops. There was an attempt to keep this formula simple,
28132812
/// just enough to limit the worst case behavior. Even if it is not that simple
28142813
/// now it is still not an attempt to provide a detailed heuristic size
@@ -2839,7 +2838,14 @@ static int CalculateUnswitchCostMultiplier(
28392838
return 1;
28402839
}
28412840

2841+
// When dealing with nested loops, the basic block size of the outer loop may
2842+
// increase significantly during unswitching non-trivial conditions. The final
2843+
// cost may be adjusted taking this into account.
28422844
auto *ParentL = L.getParentLoop();
2845+
int ParentSizeMultiplier = 1;
2846+
if (ParentL)
2847+
ParentSizeMultiplier = std::max((int)ParentL->getNumBlocks(), 1);
2848+
28432849
int SiblingsCount = (ParentL ? ParentL->getSubLoopsVector().size()
28442850
: std::distance(LI.begin(), LI.end()));
28452851
// Count amount of clones that all the candidates might cause during
@@ -2887,11 +2893,13 @@ static int CalculateUnswitchCostMultiplier(
28872893
SiblingsMultiplier > UnswitchThreshold)
28882894
CostMultiplier = UnswitchThreshold;
28892895
else
2890-
CostMultiplier = std::min(SiblingsMultiplier * (1 << ClonesPower),
2891-
(int)UnswitchThreshold);
2896+
CostMultiplier =
2897+
std::min(SiblingsMultiplier * ParentSizeMultiplier * (1 << ClonesPower),
2898+
(int)UnswitchThreshold);
28922899

28932900
LLVM_DEBUG(dbgs() << " Computed multiplier " << CostMultiplier
2894-
<< " (siblings " << SiblingsMultiplier << " * clones "
2901+
<< " (siblings " << SiblingsMultiplier << "* parent size "
2902+
<< ParentSizeMultiplier << " * clones "
28952903
<< (1 << ClonesPower) << ")"
28962904
<< " for unswitch candidate: " << TI << "\n");
28972905
return CostMultiplier;
@@ -3504,8 +3512,9 @@ static bool unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI,
35043512
SmallVector<NonTrivialUnswitchCandidate, 4> UnswitchCandidates;
35053513
IVConditionInfo PartialIVInfo;
35063514
Instruction *PartialIVCondBranch = nullptr;
3507-
collectUnswitchCandidates(UnswitchCandidates, PartialIVInfo,
3508-
PartialIVCondBranch, L, LI, AA, MSSAU);
3515+
if (!findOptionMDForLoop(&L, "llvm.loop.unswitch.nontrivial.disable"))
3516+
collectUnswitchCandidates(UnswitchCandidates, PartialIVInfo,
3517+
PartialIVCondBranch, L, LI, AA, MSSAU);
35093518
if (!findOptionMDForLoop(&L, "llvm.loop.unswitch.injection.disable"))
35103519
collectUnswitchCandidatesWithInjections(UnswitchCandidates, PartialIVInfo,
35113520
PartialIVCondBranch, L, DT, LI, AA,
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
2+
; RUN: opt -S -passes="loop-mssa(loop-simplifycfg,licm,loop-rotate,simple-loop-unswitch<nontrivial>)" < %s | FileCheck %s
3+
4+
@a = global i32 0, align 4
5+
@b = global i32 0, align 4
6+
@c = global i32 0, align 4
7+
@d = global i32 0, align 4
8+
9+
define i32 @main() {
10+
entry:
11+
br label %outer.loop.header
12+
13+
outer.loop.header: ; preds = %outer.loop.latch, %entry
14+
br i1 false, label %exit, label %outer.loop.body
15+
16+
outer.loop.body: ; preds = %inner.loop.header, %outer.loop.header
17+
store i32 1, ptr @c, align 4
18+
%cmp = icmp sgt i32 0, -1
19+
br i1 %cmp, label %outer.loop.latch, label %exit
20+
21+
inner.loop.header: ; preds = %outer.loop.latch, %inner.loop.body
22+
%a_val = load i32, ptr @a, align 4
23+
%c_val = load i32, ptr @c, align 4
24+
%mul = mul nsw i32 %c_val, %a_val
25+
store i32 %mul, ptr @b, align 4
26+
%cmp2 = icmp sgt i32 %mul, -1
27+
br i1 %cmp2, label %inner.loop.body, label %outer.loop.body
28+
29+
inner.loop.body: ; preds = %inner.loop.header
30+
%mul2 = mul nsw i32 %c_val, 3
31+
store i32 %mul2, ptr @c, align 4
32+
store i32 %c_val, ptr @d, align 4
33+
%mul3 = mul nsw i32 %c_val, %a_val
34+
%cmp3 = icmp sgt i32 %mul3, -1
35+
br i1 %cmp3, label %inner.loop.header, label %exit
36+
37+
outer.loop.latch: ; preds = %outer.loop.body
38+
%d_val = load i32, ptr @d, align 4
39+
store i32 %d_val, ptr @b, align 4
40+
%cmp4 = icmp eq i32 %d_val, 0
41+
br i1 %cmp4, label %inner.loop.header, label %outer.loop.header
42+
43+
exit: ; preds = %inner.loop.body, %outer.loop.body, %outer.loop.header
44+
ret i32 0
45+
}
46+
47+
; CHECK: [[LOOP0:.*]] = distinct !{[[LOOP0]], [[META1:![0-9]+]]}
48+
; CHECK: [[META1]] = !{!"llvm.loop.unswitch.nontrivial.disable"}
49+
; CHECK: [[LOOP2:.*]] = distinct !{[[LOOP2]], [[META1]]}

0 commit comments

Comments
 (0)