Commit 139a286

[LSR] Make OptimizeLoopTermCond able to handle some non-cmp conditions
Currently OptimizeLoopTermCond can only convert a cmp instruction to use a post-increment induction variable, which means it can't handle predicated loops where the termination condition comes from get.active.lane.mask. Relax this restriction so that any kind of instruction can be handled, though only when it is the instruction immediately before the branch (or immediately before an extractelement that is itself immediately before the branch).
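For illustration, here is a minimal sketch of the kind of predicated loop this enables (function and value names are hypothetical, condensed from the @lane_mask test added below): the branch condition is an extractelement of a get.active.lane.mask result rather than an icmp, and OptimizeLoopTermCond can now treat it as the termination condition so that it uses the post-increment induction variable.

; Minimal sketch: the loop exit condition comes from an intrinsic call, not a
; compare. Names are illustrative only; see the new test file for the real cases.
define void @sketch(i64 %n) {
entry:
  %vscale = call i64 @llvm.vscale.i64()
  %step = shl i64 %vscale, 2
  %mask.entry = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64 0, i64 %n)
  br label %loop

loop:
  %index = phi i64 [ 0, %entry ], [ %index.next, %loop ]
  %mask = phi <vscale x 4 x i1> [ %mask.entry, %entry ], [ %mask.next, %loop ]
  ; ... predicated loop body using %mask and %index ...
  %index.next = add i64 %index, %step
  ; The lane-mask call, not an icmp, decides whether the loop continues;
  ; OptimizeLoopTermCond can now let it use the post-increment IV value.
  %mask.next = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64 %index.next, i64 %n)
  %cond = extractelement <vscale x 4 x i1> %mask.next, i64 0
  br i1 %cond, label %loop, label %exit

exit:
  ret void
}

declare i64 @llvm.vscale.i64()
declare <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64, i64)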
1 parent dda95d9 commit 139a286

File tree: 2 files changed, +241 −12 lines

llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp

Lines changed: 32 additions & 12 deletions
@@ -2181,8 +2181,8 @@ class LSRInstance {
   SmallSetVector<Instruction *, 4> InsertedNonLCSSAInsts;
 
   void OptimizeShadowIV();
-  bool FindIVUserForCond(ICmpInst *Cond, IVStrideUse *&CondUse);
-  ICmpInst *OptimizeMax(ICmpInst *Cond, IVStrideUse* &CondUse);
+  bool FindIVUserForCond(Instruction *Cond, IVStrideUse *&CondUse);
+  Instruction *OptimizeMax(ICmpInst *Cond, IVStrideUse* &CondUse);
   void OptimizeLoopTermCond();
 
   void ChainInstruction(Instruction *UserInst, Instruction *IVOper,
@@ -2416,7 +2416,7 @@ void LSRInstance::OptimizeShadowIV() {
 
 /// If Cond has an operand that is an expression of an IV, set the IV user and
 /// stride information and return true, otherwise return false.
-bool LSRInstance::FindIVUserForCond(ICmpInst *Cond, IVStrideUse *&CondUse) {
+bool LSRInstance::FindIVUserForCond(Instruction *Cond, IVStrideUse *&CondUse) {
   for (IVStrideUse &U : IU)
     if (U.getUser() == Cond) {
       // NOTE: we could handle setcc instructions with multiple uses here, but
@@ -2476,7 +2476,7 @@ bool LSRInstance::FindIVUserForCond(ICmpInst *Cond, IVStrideUse *&CondUse) {
 /// This function solves this problem by detecting this type of loop and
 /// rewriting their conditions from ICMP_NE back to ICMP_SLT, and deleting
 /// the instructions for the maximum computation.
-ICmpInst *LSRInstance::OptimizeMax(ICmpInst *Cond, IVStrideUse* &CondUse) {
+Instruction *LSRInstance::OptimizeMax(ICmpInst *Cond, IVStrideUse* &CondUse) {
   // Check that the loop matches the pattern we're looking for.
   if (Cond->getPredicate() != CmpInst::ICMP_EQ &&
       Cond->getPredicate() != CmpInst::ICMP_NE)
@@ -2620,15 +2620,34 @@ LSRInstance::OptimizeLoopTermCond() {
     // one register value.
 
     BranchInst *TermBr = dyn_cast<BranchInst>(ExitingBlock->getTerminator());
-    if (!TermBr)
+    if (!TermBr || TermBr->isUnconditional())
       continue;
-    // FIXME: Overly conservative, termination condition could be an 'or' etc..
-    if (TermBr->isUnconditional() || !isa<ICmpInst>(TermBr->getCondition()))
+
+    Instruction *Cond = dyn_cast<Instruction>(TermBr->getCondition());
+    bool CondImmediatelyBeforeTerm = Cond && Cond->getNextNode() == TermBr;
+    // If the argument to TermBr is an extractelement, then the source of that
+    // instruction is what's generated the condition.
+    auto *Extract = dyn_cast_or_null<ExtractElementInst>(Cond);
+    if (Extract) {
+      Cond = dyn_cast<Instruction>(Extract->getVectorOperand());
+      if (Cond && CondImmediatelyBeforeTerm)
+        CondImmediatelyBeforeTerm = Cond->getNextNode() == Extract;
+    }
+    // FIXME: We could do more here, like handling logical operations where one
+    // side is a cmp that uses an induction variable.
+    if (!Cond)
+      continue;
+
+    // If the condition instruction isn't immediately before TermBr then it has
+    // to either be a CmpInst, or be immediately before an extract that's
+    // immediately before TermBr, as currently we can only move or clone a
+    // CmpInst.
+    // FIXME: We should be able to do this when it's safe to do so.
+    if ((!isa<CmpInst>(Cond) || Extract) && !CondImmediatelyBeforeTerm)
       continue;
 
     // Search IVUsesByStride to find Cond's IVUse if there is one.
     IVStrideUse *CondUse = nullptr;
-    ICmpInst *Cond = cast<ICmpInst>(TermBr->getCondition());
     if (!FindIVUserForCond(Cond, CondUse))
       continue;
 
@@ -2638,7 +2657,8 @@ LSRInstance::OptimizeLoopTermCond() {
     // One consequence of doing this now is that it disrupts the count-down
     // optimization. That's not always a bad thing though, because in such
     // cases it may still be worthwhile to avoid a max.
-    Cond = OptimizeMax(Cond, CondUse);
+    if (auto *Cmp = dyn_cast<ICmpInst>(Cond))
+      Cond = OptimizeMax(Cmp, CondUse);
 
     // If this exiting block dominates the latch block, it may also use
     // the post-inc value if it won't be shared with other uses.
@@ -2703,13 +2723,13 @@ LSRInstance::OptimizeLoopTermCond() {
     // It's possible for the setcc instruction to be anywhere in the loop, and
     // possible for it to have multiple users. If it is not immediately before
     // the exiting block branch, move it.
-    if (Cond->getNextNode() != TermBr) {
+    if (!CondImmediatelyBeforeTerm) {
       if (Cond->hasOneUse()) {
         Cond->moveBefore(TermBr->getIterator());
       } else {
         // Clone the terminating condition and insert into the loopend.
-        ICmpInst *OldCond = Cond;
-        Cond = cast<ICmpInst>(Cond->clone());
+        Instruction *OldCond = Cond;
+        Cond = Cond->clone();
         Cond->setName(L->getHeader()->getName() + ".termcond");
         Cond->insertInto(ExitingBlock, TermBr->getIterator());
 
Lines changed: 209 additions & 0 deletions
@@ -0,0 +1,209 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 6
+; RUN: opt -loop-reduce %s -S -o - | FileCheck %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+; Tests where the loop termination condition is not generated by a compare.
+
+; The call to get.active.lane.mask in the loop should use the postincrement
+; value of %index.
+define void @lane_mask(ptr %dst, i64 %n) #0 {
+; CHECK-LABEL: define void @lane_mask(
+; CHECK-SAME: ptr [[DST:%.*]], i64 [[N:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT:  [[ENTRY:.*]]:
+; CHECK-NEXT:    [[VSCALE:%.*]] = tail call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[VSCALEX4:%.*]] = shl i64 [[VSCALE]], 2
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_ENTRY:%.*]] = tail call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64 0, i64 [[N]])
+; CHECK-NEXT:    br label %[[VECTOR_BODY:.*]]
+; CHECK:       [[VECTOR_BODY]]:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, %[[ENTRY]] ], [ [[TMP1:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = phi <vscale x 4 x i1> [ [[ACTIVE_LANE_MASK_ENTRY]], %[[ENTRY]] ], [ [[ACTIVE_LANE_MASK_NEXT:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP0:%.*]] = shl i64 [[INDEX]], 2
+; CHECK-NEXT:    [[SCEVGEP:%.*]] = getelementptr i8, ptr [[DST]], i64 [[TMP0]]
+; CHECK-NEXT:    tail call void @llvm.masked.store.nxv4i32.p0(<vscale x 4 x i32> splat (i32 1), ptr align 4 [[SCEVGEP]], <vscale x 4 x i1> [[ACTIVE_LANE_MASK]])
+; CHECK-NEXT:    [[TMP1]] = add i64 [[INDEX]], [[VSCALEX4]]
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_NEXT]] = tail call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64 [[TMP1]], i64 [[N]])
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <vscale x 4 x i1> [[ACTIVE_LANE_MASK_NEXT]], i64 0
+; CHECK-NEXT:    br i1 [[TMP2]], label %[[VECTOR_BODY]], label %[[FOR_COND_CLEANUP:.*]]
+; CHECK:       [[FOR_COND_CLEANUP]]:
+; CHECK-NEXT:    ret void
+;
+entry:
+  %vscale = tail call i64 @llvm.vscale.i64()
+  %vscalex4 = shl i64 %vscale, 2
+  %active.lane.mask.entry = tail call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64 0, i64 %n)
+  br label %vector.body
+
+vector.body:
+  %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
+  %active.lane.mask = phi <vscale x 4 x i1> [ %active.lane.mask.entry, %entry ], [ %active.lane.mask.next, %vector.body ]
+  %gep = getelementptr inbounds nuw i32, ptr %dst, i64 %index
+  tail call void @llvm.masked.store.nxv4i32.p0(<vscale x 4 x i32> splat (i32 1), ptr %gep, i32 4, <vscale x 4 x i1> %active.lane.mask)
+  %index.next = add i64 %index, %vscalex4
+  %active.lane.mask.next = tail call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64 %index.next, i64 %n)
+  %cond = extractelement <vscale x 4 x i1> %active.lane.mask.next, i64 0
+  br i1 %cond, label %vector.body, label %for.cond.cleanup
+
+for.cond.cleanup:
+  ret void
+}
+
+; The store between the call and the branch should cause get.active.lane.mask to
+; use a preincrement value.
+; FIXME: We could use a postincrement value by moving the call and
+; extractelement to after the store.
+define void @lane_mask_not_last(ptr %dst, i64 %n) #0 {
+; CHECK-LABEL: define void @lane_mask_not_last(
+; CHECK-SAME: ptr [[DST:%.*]], i64 [[N:%.*]]) #[[ATTR0]] {
+; CHECK-NEXT:  [[ENTRY:.*]]:
+; CHECK-NEXT:    [[VSCALE:%.*]] = tail call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[VSCALEX4:%.*]] = shl i64 [[VSCALE]], 2
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_ENTRY:%.*]] = tail call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64 0, i64 [[N]])
+; CHECK-NEXT:    br label %[[VECTOR_BODY:.*]]
+; CHECK:       [[VECTOR_BODY]]:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, %[[ENTRY]] ], [ [[INDEX_NEXT1:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = phi <vscale x 4 x i1> [ [[ACTIVE_LANE_MASK_ENTRY]], %[[ENTRY]] ], [ [[ACTIVE_LANE_MASK_NEXT:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[INDEX_NEXT1]] = add i64 [[INDEX]], [[VSCALEX4]]
+; CHECK-NEXT:    [[INDEX_NEXT:%.*]] = add i64 [[VSCALEX4]], [[INDEX]]
+; CHECK-NEXT:    [[TMP0:%.*]] = shl i64 [[INDEX]], 2
+; CHECK-NEXT:    [[SCEVGEP:%.*]] = getelementptr i8, ptr [[DST]], i64 [[TMP0]]
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_NEXT]] = tail call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64 [[INDEX_NEXT]], i64 [[N]])
+; CHECK-NEXT:    tail call void @llvm.masked.store.nxv4i32.p0(<vscale x 4 x i32> splat (i32 1), ptr align 4 [[SCEVGEP]], <vscale x 4 x i1> [[ACTIVE_LANE_MASK]])
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <vscale x 4 x i1> [[ACTIVE_LANE_MASK_NEXT]], i64 0
+; CHECK-NEXT:    br i1 [[TMP1]], label %[[VECTOR_BODY]], label %[[FOR_COND_CLEANUP:.*]]
+; CHECK:       [[FOR_COND_CLEANUP]]:
+; CHECK-NEXT:    ret void
+;
+entry:
+  %vscale = tail call i64 @llvm.vscale.i64()
+  %vscalex4 = shl i64 %vscale, 2
+  %active.lane.mask.entry = tail call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64 0, i64 %n)
+  br label %vector.body
+
+vector.body:
+  %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
+  %active.lane.mask = phi <vscale x 4 x i1> [ %active.lane.mask.entry, %entry ], [ %active.lane.mask.next, %vector.body ]
+  %gep = getelementptr inbounds nuw i32, ptr %dst, i64 %index
+  %index.next = add i64 %index, %vscalex4
+  %active.lane.mask.next = tail call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64 %index.next, i64 %n)
+  tail call void @llvm.masked.store.nxv4i32.p0(<vscale x 4 x i32> splat (i32 1), ptr %gep, i32 4, <vscale x 4 x i1> %active.lane.mask)
+  %cond = extractelement <vscale x 4 x i1> %active.lane.mask.next, i64 0
+  br i1 %cond, label %vector.body, label %for.cond.cleanup
+
+for.cond.cleanup:
+  ret void
+}
+
+; The call to cmp_fn in the loop should use the postincrement value of %index.
+define void @uses_cmp_fn(ptr %dst, i64 %n) {
+; CHECK-LABEL: define void @uses_cmp_fn(
+; CHECK-SAME: ptr [[DST:%.*]], i64 [[N:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*]]:
+; CHECK-NEXT:    br label %[[VECTOR_BODY:.*]]
+; CHECK:       [[VECTOR_BODY]]:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, %[[ENTRY]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP0:%.*]] = shl i64 [[INDEX]], 2
+; CHECK-NEXT:    [[SCEVGEP:%.*]] = getelementptr i8, ptr [[DST]], i64 [[TMP0]]
+; CHECK-NEXT:    store i32 0, ptr [[SCEVGEP]], align 4
+; CHECK-NEXT:    [[INDEX_NEXT]] = add i64 [[INDEX]], 1
+; CHECK-NEXT:    [[COND:%.*]] = tail call i1 @cmp_fn(i64 [[INDEX_NEXT]])
+; CHECK-NEXT:    br i1 [[COND]], label %[[VECTOR_BODY]], label %[[FOR_COND_CLEANUP:.*]]
+; CHECK:       [[FOR_COND_CLEANUP]]:
+; CHECK-NEXT:    ret void
+;
+entry:
+  br label %vector.body
+
+vector.body:
+  %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
+  %gep = getelementptr inbounds nuw i32, ptr %dst, i64 %index
+  store i32 0, ptr %gep, align 4
+  %index.next = add i64 %index, 1
+  %cond = tail call i1 @cmp_fn(i64 %index.next)
+  br i1 %cond, label %vector.body, label %for.cond.cleanup
+
+for.cond.cleanup:
+  ret void
+}
+
+; The store between the call and the branch should cause cmp_fn to use a
+; preincrement value. We can't move the call after the store as the call could
+; have side effects.
+define void @uses_cmp_fn_not_last(ptr %dst, i64 %n) {
+; CHECK-LABEL: define void @uses_cmp_fn_not_last(
+; CHECK-SAME: ptr [[DST:%.*]], i64 [[N:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*]]:
+; CHECK-NEXT:    br label %[[VECTOR_BODY:.*]]
+; CHECK:       [[VECTOR_BODY]]:
+; CHECK-NEXT:    [[LSR_IV1:%.*]] = phi ptr [ [[SCEVGEP:%.*]], %[[VECTOR_BODY]] ], [ [[DST]], %[[ENTRY]] ]
+; CHECK-NEXT:    [[LSR_IV:%.*]] = phi i64 [ [[LSR_IV_NEXT:%.*]], %[[VECTOR_BODY]] ], [ 1, %[[ENTRY]] ]
+; CHECK-NEXT:    [[COND:%.*]] = tail call i1 @cmp_fn(i64 [[LSR_IV]])
+; CHECK-NEXT:    store i32 0, ptr [[LSR_IV1]], align 4
+; CHECK-NEXT:    [[LSR_IV_NEXT]] = add i64 [[LSR_IV]], 1
+; CHECK-NEXT:    [[SCEVGEP]] = getelementptr i8, ptr [[LSR_IV1]], i64 4
+; CHECK-NEXT:    br i1 [[COND]], label %[[VECTOR_BODY]], label %[[FOR_COND_CLEANUP:.*]]
+; CHECK:       [[FOR_COND_CLEANUP]]:
+; CHECK-NEXT:    ret void
+;
+entry:
+  br label %vector.body
+
+vector.body:
+  %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
+  %gep = getelementptr inbounds nuw i32, ptr %dst, i64 %index
+  %index.next = add i64 %index, 1
+  %cond = tail call i1 @cmp_fn(i64 %index.next)
+  store i32 0, ptr %gep, align 4
+  br i1 %cond, label %vector.body, label %for.cond.cleanup
+
+for.cond.cleanup:
+  ret void
+}
+
+; cmp2 will use a preincrement induction variable as it isn't directly the loop
+; termination condition.
+; FIXME: We could potentially handle this by examining the operands of the 'and'
+; instruction.
+define void @cmp_and(ptr %dst, i64 %n) {
+; CHECK-LABEL: define void @cmp_and(
+; CHECK-SAME: ptr [[DST:%.*]], i64 [[N:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*]]:
+; CHECK-NEXT:    [[TMP0:%.*]] = add i64 [[N]], -1
+; CHECK-NEXT:    br label %[[VECTOR_BODY:.*]]
+; CHECK:       [[VECTOR_BODY]]:
+; CHECK-NEXT:    [[LSR_IV1:%.*]] = phi ptr [ [[SCEVGEP:%.*]], %[[VECTOR_BODY]] ], [ [[DST]], %[[ENTRY]] ]
+; CHECK-NEXT:    [[LSR_IV_NEXT:%.*]] = phi i64 [ [[LSR_IV_NEXT1:%.*]], %[[VECTOR_BODY]] ], [ [[TMP0]], %[[ENTRY]] ]
+; CHECK-NEXT:    [[VAL:%.*]] = load i64, ptr [[LSR_IV1]], align 8
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp ne i64 [[VAL]], [[N]]
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp ne i64 [[LSR_IV_NEXT]], 0
+; CHECK-NEXT:    [[COND:%.*]] = and i1 [[CMP1]], [[CMP2]]
+; CHECK-NEXT:    [[LSR_IV_NEXT1]] = add i64 [[LSR_IV_NEXT]], -1
+; CHECK-NEXT:    [[SCEVGEP]] = getelementptr i8, ptr [[LSR_IV1]], i64 4
+; CHECK-NEXT:    br i1 [[COND]], label %[[VECTOR_BODY]], label %[[FOR_COND_CLEANUP:.*]]
+; CHECK:       [[FOR_COND_CLEANUP]]:
+; CHECK-NEXT:    ret void
+;
+entry:
+  br label %vector.body
+
+vector.body:
+  %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
+  %gep = getelementptr inbounds nuw i32, ptr %dst, i64 %index
+  %val = load i64, ptr %gep, align 8
+  %index.next = add i64 %index, 1
+  %cmp1 = icmp ne i64 %val, %n
+  %cmp2 = icmp ne i64 %index.next, %n
+  %cond = and i1 %cmp1, %cmp2
+  br i1 %cond, label %vector.body, label %for.cond.cleanup

+for.cond.cleanup:
+  ret void
+}
+
+
+declare i64 @llvm.vscale.i64()
+declare <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64, i64)
+declare void @llvm.masked.store.nxv4i32.p0(<vscale x 4 x i32>, ptr captures(none), i32 immarg, <vscale x 4 x i1>)
+declare i1 @cmp_fn(i64)
+
+attributes #0 = { "target-features"="+sve2" }
