
Commit 7b8fd8f

[LLVM][SCEV] Look through common vscale multiplicand when simplifying compares. (#141798)
My use case is simplifying the control flow generated by LoopVectorize when vectorising loops whose trip count is a function of the runtime vector length. This can be problematic because:

* CSE is a pre-LoopVectorize transform, and so it is common for an IR function to include several calls to llvm.vscale(). (NOTE: Code generation will typically remove the duplicates.)
* Pre-LoopVectorize instcombines will rewrite some multiplies as shifts.

Together these lead to a mismatch between the VL-based maths of the scalar loop and that created for the vector loop, which prevents some obvious simplifications. SCEV does not suffer these issues because it effectively performs CSE during construction and represents shifts as multiplies.
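To ground that last claim, here is a rough, unit-test-style sketch (not part of this commit) of the canonicalisation SCEV performs. It assumes the runWithSE and getInstructionByName helpers from llvm/unittests/Analysis/ScalarEvolutionTest.cpp plus the usual parse boilerplate (Err, C) visible in the new test below; the function @f and its value names are made up for illustration.

// Sketch: %shl and %mul should map to the identical, uniqued SCEV
// (4 * vscale), despite the duplicate llvm.vscale calls and the
// shift/multiply mix in the IR.
std::unique_ptr<Module> M = parseAssemblyString(
    "declare i64 @llvm.vscale.i64()        "
    "define void @f() {                    "
    "entry:                                "
    "  %vs1 = call i64 @llvm.vscale.i64()  "
    "  %shl = shl nuw nsw i64 %vs1, 2      "
    "  %vs2 = call i64 @llvm.vscale.i64()  "
    "  %mul = mul nuw nsw i64 %vs2, 4      "
    "  ret void                            "
    "}                                     ",
    Err, C);
runWithSE(*M, "f", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
  // SCEV's internal uniquing acts as CSE, and the shift is represented
  // as a multiply, so the two instructions compare pointer-equal.
  EXPECT_EQ(SE.getSCEV(&getInstructionByName(F, "shl")),
            SE.getSCEV(&getInstructionByName(F, "mul")));
});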
Parent: 3c862b4

4 files changed: 194 additions (+), 49 deletions (-)

llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h

Lines changed: 16 additions & 6 deletions
@@ -183,7 +183,8 @@ m_scev_PtrToInt(const Op0_t &Op0) {
 }
 
 /// Match a binary SCEV.
-template <typename SCEVTy, typename Op0_t, typename Op1_t>
+template <typename SCEVTy, typename Op0_t, typename Op1_t,
+          bool Commutable = false>
 struct SCEVBinaryExpr_match {
   Op0_t Op0;
   Op1_t Op1;
@@ -192,15 +193,18 @@ struct SCEVBinaryExpr_match {
 
   bool match(const SCEV *S) const {
     auto *E = dyn_cast<SCEVTy>(S);
-    return E && E->getNumOperands() == 2 && Op0.match(E->getOperand(0)) &&
-           Op1.match(E->getOperand(1));
+    return E && E->getNumOperands() == 2 &&
+           ((Op0.match(E->getOperand(0)) && Op1.match(E->getOperand(1))) ||
+            (Commutable && Op0.match(E->getOperand(1)) &&
+             Op1.match(E->getOperand(0))));
   }
 };
 
-template <typename SCEVTy, typename Op0_t, typename Op1_t>
-inline SCEVBinaryExpr_match<SCEVTy, Op0_t, Op1_t>
+template <typename SCEVTy, typename Op0_t, typename Op1_t,
+          bool Commutable = false>
+inline SCEVBinaryExpr_match<SCEVTy, Op0_t, Op1_t, Commutable>
 m_scev_Binary(const Op0_t &Op0, const Op1_t &Op1) {
-  return SCEVBinaryExpr_match<SCEVTy, Op0_t, Op1_t>(Op0, Op1);
+  return SCEVBinaryExpr_match<SCEVTy, Op0_t, Op1_t, Commutable>(Op0, Op1);
 }
 
 template <typename Op0_t, typename Op1_t>
@@ -215,6 +219,12 @@ m_scev_Mul(const Op0_t &Op0, const Op1_t &Op1) {
   return m_scev_Binary<SCEVMulExpr>(Op0, Op1);
 }
 
+template <typename Op0_t, typename Op1_t>
+inline SCEVBinaryExpr_match<SCEVMulExpr, Op0_t, Op1_t, true>
+m_scev_c_Mul(const Op0_t &Op0, const Op1_t &Op1) {
+  return m_scev_Binary<SCEVMulExpr, Op0_t, Op1_t, true>(Op0, Op1);
+}
+
 template <typename Op0_t, typename Op1_t>
 inline SCEVBinaryExpr_match<SCEVUDivExpr, Op0_t, Op1_t>
 m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1) {
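For orientation, a minimal sketch (not from this patch) of how the new commutative matcher can be used; stripVScaleFactor is a hypothetical helper name, and the pattern mirrors the m_c_* commutative matchers of IR-level PatternMatch:

#include "llvm/Analysis/ScalarEvolutionPatternMatch.h"

using namespace llvm;
using namespace llvm::SCEVPatternMatch;

// Hypothetical helper: returns X for either (X * vscale) or (vscale * X).
// With the plain m_scev_Mul this would take two match calls, one per
// operand order; the Commutable variant tries both orders itself.
static const SCEV *stripVScaleFactor(const SCEV *S) {
  const SCEV *X;
  if (match(S, m_scev_c_Mul(m_SCEV(X), m_SCEVVScale())))
    return X;
  return S; // no vscale multiplicand found
}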

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 20 additions & 1 deletion
@@ -10810,6 +10810,25 @@ bool ScalarEvolution::SimplifyICmpOperands(CmpPredicate &Pred, const SCEV *&LHS,
   if (Depth >= 3)
     return false;
 
+  const SCEV *NewLHS, *NewRHS;
+  if (match(LHS, m_scev_c_Mul(m_SCEV(NewLHS), m_SCEVVScale())) &&
+      match(RHS, m_scev_c_Mul(m_SCEV(NewRHS), m_SCEVVScale()))) {
+    const SCEVMulExpr *LMul = cast<SCEVMulExpr>(LHS);
+    const SCEVMulExpr *RMul = cast<SCEVMulExpr>(RHS);
+
+    // (X * vscale) pred (Y * vscale) ==> X pred Y
+    // when both multiples are NSW.
+    // (X * vscale) uicmp/eq/ne (Y * vscale) ==> X uicmp/eq/ne Y
+    // when both multiples are NUW.
+    if ((LMul->hasNoSignedWrap() && RMul->hasNoSignedWrap()) ||
+        (LMul->hasNoUnsignedWrap() && RMul->hasNoUnsignedWrap() &&
+         !ICmpInst::isSigned(Pred))) {
+      LHS = NewLHS;
+      RHS = NewRHS;
+      Changed = true;
+    }
+  }
+
   // Canonicalize a constant to the right side.
   if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
     // Check for both operands constant.
@@ -10984,7 +11003,7 @@ bool ScalarEvolution::SimplifyICmpOperands(CmpPredicate &Pred, const SCEV *&LHS,
   // Recursively simplify until we either hit a recursion limit or nothing
   // changes.
   if (Changed)
-    return SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1);
+    (void)SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1);
 
   return Changed;
 }
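The wrap-flag guard above is what makes the rewrite sound. As a self-contained illustration (not from the patch) of why it is needed, here is the signed case in plain C++ with 8-bit arithmetic and a made-up factor of 4 standing in for vscale: X slt Y does not survive the multiplication once the product can wrap.

#include <cstdint>
#include <cstdio>

int main() {
  int8_t X = 10, Y = 50;          // X slt Y holds: 10 < 50
  int8_t VScale = 4;              // stand-in for the runtime vscale
  int8_t XV = int8_t(X * VScale); // 40: no wrap
  int8_t YV = int8_t(Y * VScale); // 200 wraps to -56 as signed i8
  std::printf("X slt Y:   %d\n", X < Y);   // prints 1
  std::printf("XV slt YV: %d\n", XV < YV); // prints 0: the wrap flipped it
  return 0;
}

Hence the patch only drops the common vscale factor when both multiplies are NSW (any predicate), or when both are NUW and the predicate is unsigned or an equality.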

llvm/test/Transforms/LoopVectorize/AArch64/sve-vscale-based-trip-counts.ll

Lines changed: 21 additions & 42 deletions
@@ -9,14 +9,14 @@ define void @vscale_mul_4(ptr noalias noundef readonly captures(none) %a, ptr no
 ; CHECK-NEXT:  [[ENTRY:.*]]:
 ; CHECK-NEXT:    [[TMP0:%.*]] = tail call i64 @llvm.vscale.i64()
 ; CHECK-NEXT:    [[TMP1:%.*]] = shl nuw nsw i64 [[TMP0]], 2
-; CHECK-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-NEXT:    [[TMP5:%.*]] = mul nuw i64 [[TMP4]], 4
-; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[TMP1]], [[TMP5]]
+; CHECK-NEXT:    [[TMP10:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP3:%.*]] = mul nuw i64 [[TMP10]], 4
+; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[TMP1]], [[TMP3]]
 ; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 [[TMP1]], [[N_MOD_VF]]
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 4 x float>, ptr [[A]], align 4
 ; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <vscale x 4 x float>, ptr [[B]], align 4
-; CHECK-NEXT:    [[TMP10:%.*]] = fmul <vscale x 4 x float> [[WIDE_LOAD]], [[WIDE_LOAD1]]
-; CHECK-NEXT:    store <vscale x 4 x float> [[TMP10]], ptr [[B]], align 4
+; CHECK-NEXT:    [[TMP4:%.*]] = fmul <vscale x 4 x float> [[WIDE_LOAD]], [[WIDE_LOAD1]]
+; CHECK-NEXT:    store <vscale x 4 x float> [[TMP4]], ptr [[B]], align 4
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[TMP1]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[CMP_N]], label %[[FOR_COND_CLEANUP:.*]], label %[[FOR_BODY:.*]]
 ; CHECK:       [[FOR_COND_CLEANUP]]:
@@ -121,36 +121,29 @@ define void @vscale_mul_12(ptr noalias noundef readonly captures(none) %a, ptr n
 ; CHECK-NEXT:  [[ENTRY:.*]]:
 ; CHECK-NEXT:    [[TMP0:%.*]] = tail call i64 @llvm.vscale.i64()
 ; CHECK-NEXT:    [[MUL1:%.*]] = mul nuw nsw i64 [[TMP0]], 12
-; CHECK-NEXT:    [[TMP2:%.*]] = shl nuw nsw i64 [[TMP0]], 2
-; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[MUL1]], [[TMP2]]
-; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]]
-; CHECK:       [[VECTOR_PH]]:
 ; CHECK-NEXT:    [[TMP3:%.*]] = call i64 @llvm.vscale.i64()
 ; CHECK-NEXT:    [[TMP4:%.*]] = mul nuw i64 [[TMP3]], 4
 ; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[MUL1]], [[TMP4]]
 ; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 [[MUL1]], [[N_MOD_VF]]
 ; CHECK-NEXT:    br label %[[VECTOR_BODY:.*]]
 ; CHECK:       [[VECTOR_BODY]]:
-; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, %[[ENTRY]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds nuw float, ptr [[A]], i64 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 4 x float>, ptr [[TMP7]], align 4
-; CHECK-NEXT:    [[TMP9:%.*]] = getelementptr inbounds nuw float, ptr [[B]], i64 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <vscale x 4 x float>, ptr [[TMP9]], align 4
-; CHECK-NEXT:    [[TMP11:%.*]] = fmul <vscale x 4 x float> [[WIDE_LOAD]], [[WIDE_LOAD1]]
-; CHECK-NEXT:    store <vscale x 4 x float> [[TMP11]], ptr [[TMP9]], align 4
+; CHECK-NEXT:    [[WIDE_LOAD2:%.*]] = load <vscale x 4 x float>, ptr [[TMP7]], align 4
+; CHECK-NEXT:    [[TMP12:%.*]] = getelementptr inbounds nuw float, ptr [[B]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_LOAD4:%.*]] = load <vscale x 4 x float>, ptr [[TMP12]], align 4
+; CHECK-NEXT:    [[TMP25:%.*]] = fmul <vscale x 4 x float> [[WIDE_LOAD2]], [[WIDE_LOAD4]]
+; CHECK-NEXT:    store <vscale x 4 x float> [[TMP25]], ptr [[TMP12]], align 4
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP4]]
-; CHECK-NEXT:    [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[TMP12]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
+; CHECK-NEXT:    [[TMP22:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP22]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
 ; CHECK:       [[MIDDLE_BLOCK]]:
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[MUL1]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[CMP_N]], label %[[FOR_COND_CLEANUP:.*]], label %[[SCALAR_PH]]
-; CHECK:       [[SCALAR_PH]]:
-; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i64 [ [[N_VEC]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
-; CHECK-NEXT:    br label %[[FOR_BODY:.*]]
+; CHECK-NEXT:    br i1 [[CMP_N]], label %[[FOR_COND_CLEANUP:.*]], label %[[FOR_BODY:.*]]
 ; CHECK:       [[FOR_COND_CLEANUP]]:
 ; CHECK-NEXT:    ret void
 ; CHECK:       [[FOR_BODY]]:
-; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[INDVARS_IV_NEXT:%.*]], %[[FOR_BODY]] ]
+; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[INDVARS_IV_NEXT:%.*]], %[[FOR_BODY]] ], [ [[N_VEC]], %[[MIDDLE_BLOCK]] ]
 ; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds nuw float, ptr [[A]], i64 [[INDVARS_IV]]
 ; CHECK-NEXT:    [[TMP13:%.*]] = load float, ptr [[ARRAYIDX]], align 4
 ; CHECK-NEXT:    [[ARRAYIDX4:%.*]] = getelementptr inbounds nuw float, ptr [[B]], i64 [[INDVARS_IV]]
@@ -188,17 +181,13 @@ define void @vscale_mul_31(ptr noalias noundef readonly captures(none) %a, ptr n
 ; CHECK-NEXT:  [[ENTRY:.*]]:
 ; CHECK-NEXT:    [[TMP0:%.*]] = tail call i64 @llvm.vscale.i64()
 ; CHECK-NEXT:    [[MUL1:%.*]] = mul nuw nsw i64 [[TMP0]], 31
-; CHECK-NEXT:    [[TMP2:%.*]] = shl nuw nsw i64 [[TMP0]], 3
-; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[MUL1]], [[TMP2]]
-; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]]
-; CHECK:       [[VECTOR_PH]]:
 ; CHECK-NEXT:    [[TMP3:%.*]] = call i64 @llvm.vscale.i64()
 ; CHECK-NEXT:    [[TMP4:%.*]] = mul nuw i64 [[TMP3]], 8
 ; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[MUL1]], [[TMP4]]
 ; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 [[MUL1]], [[N_MOD_VF]]
 ; CHECK-NEXT:    br label %[[VECTOR_BODY:.*]]
 ; CHECK:       [[VECTOR_BODY]]:
-; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, %[[ENTRY]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds nuw float, ptr [[A]], i64 [[INDEX]]
 ; CHECK-NEXT:    [[TMP9:%.*]] = call i64 @llvm.vscale.i64()
 ; CHECK-NEXT:    [[TMP10:%.*]] = shl nuw i64 [[TMP9]], 2
@@ -220,14 +209,11 @@ define void @vscale_mul_31(ptr noalias noundef readonly captures(none) %a, ptr n
 ; CHECK-NEXT:    br i1 [[TMP22]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP6:![0-9]+]]
 ; CHECK:       [[MIDDLE_BLOCK]]:
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[MUL1]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[CMP_N]], label %[[FOR_COND_CLEANUP:.*]], label %[[SCALAR_PH]]
-; CHECK:       [[SCALAR_PH]]:
-; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i64 [ [[N_VEC]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
-; CHECK-NEXT:    br label %[[FOR_BODY:.*]]
+; CHECK-NEXT:    br i1 [[CMP_N]], label %[[FOR_COND_CLEANUP:.*]], label %[[FOR_BODY:.*]]
 ; CHECK:       [[FOR_COND_CLEANUP]]:
 ; CHECK-NEXT:    ret void
 ; CHECK:       [[FOR_BODY]]:
-; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[INDVARS_IV_NEXT:%.*]], %[[FOR_BODY]] ]
+; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[INDVARS_IV_NEXT:%.*]], %[[FOR_BODY]] ], [ [[N_VEC]], %[[MIDDLE_BLOCK]] ]
 ; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds nuw float, ptr [[A]], i64 [[INDVARS_IV]]
 ; CHECK-NEXT:    [[TMP23:%.*]] = load float, ptr [[ARRAYIDX]], align 4
 ; CHECK-NEXT:    [[ARRAYIDX4:%.*]] = getelementptr inbounds nuw float, ptr [[B]], i64 [[INDVARS_IV]]
@@ -265,17 +251,13 @@ define void @vscale_mul_64(ptr noalias noundef readonly captures(none) %a, ptr n
 ; CHECK-NEXT:  [[ENTRY:.*]]:
 ; CHECK-NEXT:    [[TMP0:%.*]] = tail call i64 @llvm.vscale.i64()
 ; CHECK-NEXT:    [[MUL1:%.*]] = mul nuw nsw i64 [[TMP0]], 64
-; CHECK-NEXT:    [[TMP2:%.*]] = shl nuw nsw i64 [[TMP0]], 3
-; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[MUL1]], [[TMP2]]
-; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]]
-; CHECK:       [[VECTOR_PH]]:
 ; CHECK-NEXT:    [[TMP3:%.*]] = call i64 @llvm.vscale.i64()
 ; CHECK-NEXT:    [[TMP4:%.*]] = mul nuw i64 [[TMP3]], 8
 ; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[MUL1]], [[TMP4]]
 ; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 [[MUL1]], [[N_MOD_VF]]
 ; CHECK-NEXT:    br label %[[VECTOR_BODY:.*]]
 ; CHECK:       [[VECTOR_BODY]]:
-; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, %[[ENTRY]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds nuw float, ptr [[A]], i64 [[INDEX]]
 ; CHECK-NEXT:    [[TMP9:%.*]] = call i64 @llvm.vscale.i64()
 ; CHECK-NEXT:    [[TMP10:%.*]] = shl nuw i64 [[TMP9]], 2
@@ -297,14 +279,11 @@ define void @vscale_mul_64(ptr noalias noundef readonly captures(none) %a, ptr n
 ; CHECK-NEXT:    br i1 [[TMP22]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
 ; CHECK:       [[MIDDLE_BLOCK]]:
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[MUL1]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[CMP_N]], label %[[FOR_COND_CLEANUP:.*]], label %[[SCALAR_PH]]
-; CHECK:       [[SCALAR_PH]]:
-; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i64 [ [[N_VEC]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
-; CHECK-NEXT:    br label %[[FOR_BODY:.*]]
+; CHECK-NEXT:    br i1 [[CMP_N]], label %[[FOR_COND_CLEANUP:.*]], label %[[FOR_BODY:.*]]
 ; CHECK:       [[FOR_COND_CLEANUP]]:
 ; CHECK-NEXT:    ret void
 ; CHECK:       [[FOR_BODY]]:
-; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[INDVARS_IV_NEXT:%.*]], %[[FOR_BODY]] ]
+; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[INDVARS_IV_NEXT:%.*]], %[[FOR_BODY]] ], [ [[N_VEC]], %[[MIDDLE_BLOCK]] ]
 ; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds nuw float, ptr [[A]], i64 [[INDVARS_IV]]
 ; CHECK-NEXT:    [[TMP23:%.*]] = load float, ptr [[ARRAYIDX]], align 4
 ; CHECK-NEXT:    [[ARRAYIDX4:%.*]] = getelementptr inbounds nuw float, ptr [[B]], i64 [[INDVARS_IV]]

llvm/unittests/Analysis/ScalarEvolutionTest.cpp

Lines changed: 137 additions & 0 deletions
@@ -1768,4 +1768,141 @@ TEST_F(ScalarEvolutionsTest, ComplexityComparatorIsStrictWeakOrdering3) {
   SE.getSCEV(Or1);
 }
 
+TEST_F(ScalarEvolutionsTest, SimplifyICmpOperands) {
+  LLVMContext C;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M =
+      parseAssemblyString("define i32 @foo(ptr %loc, i32 %a, i32 %b) {"
+                          "entry: "
+                          "  ret i32 %a "
+                          "} ",
+                          Err, C);
+
+  ASSERT_TRUE(M && "Could not parse module?");
+  ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
+
+  // Remove common factor when there's no signed wrapping.
+  runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+    const SCEV *A = SE.getSCEV(getArgByName(F, "a"));
+    const SCEV *B = SE.getSCEV(getArgByName(F, "b"));
+    const SCEV *VS = SE.getVScale(A->getType());
+    const SCEV *VSxA = SE.getMulExpr(VS, A, SCEV::FlagNSW);
+    const SCEV *VSxB = SE.getMulExpr(VS, B, SCEV::FlagNSW);
+
+    {
+      CmpPredicate NewPred = ICmpInst::ICMP_SLT;
+      const SCEV *NewLHS = VSxA;
+      const SCEV *NewRHS = VSxB;
+      EXPECT_TRUE(SE.SimplifyICmpOperands(NewPred, NewLHS, NewRHS));
+      EXPECT_EQ(NewPred, ICmpInst::ICMP_SLT);
+      EXPECT_EQ(NewLHS, A);
+      EXPECT_EQ(NewRHS, B);
+    }
+
+    {
+      CmpPredicate NewPred = ICmpInst::ICMP_ULT;
+      const SCEV *NewLHS = VSxA;
+      const SCEV *NewRHS = VSxB;
+      EXPECT_TRUE(SE.SimplifyICmpOperands(NewPred, NewLHS, NewRHS));
+      EXPECT_EQ(NewPred, ICmpInst::ICMP_ULT);
+      EXPECT_EQ(NewLHS, A);
+      EXPECT_EQ(NewRHS, B);
+    }
+
+    {
+      CmpPredicate NewPred = ICmpInst::ICMP_EQ;
+      const SCEV *NewLHS = VSxA;
+      const SCEV *NewRHS = VSxB;
+      EXPECT_TRUE(SE.SimplifyICmpOperands(NewPred, NewLHS, NewRHS));
+      EXPECT_EQ(NewPred, ICmpInst::ICMP_EQ);
+      EXPECT_EQ(NewLHS, A);
+      EXPECT_EQ(NewRHS, B);
+    }
+
+    // Verify the common factor's position doesn't impede simplification.
+    {
+      const SCEV *C = SE.getConstant(A->getType(), 100);
+      const SCEV *CxVS = SE.getMulExpr(C, VS, SCEV::FlagNSW);
+
+      // Verify common factor is available at different indices.
+      ASSERT_TRUE(isa<SCEVVScale>(cast<SCEVMulExpr>(VSxA)->getOperand(0)) !=
+                  isa<SCEVVScale>(cast<SCEVMulExpr>(CxVS)->getOperand(0)));
+
+      CmpPredicate NewPred = ICmpInst::ICMP_SLT;
+      const SCEV *NewLHS = VSxA;
+      const SCEV *NewRHS = CxVS;
+      EXPECT_TRUE(SE.SimplifyICmpOperands(NewPred, NewLHS, NewRHS));
+      EXPECT_EQ(NewPred, ICmpInst::ICMP_SLT);
+      EXPECT_EQ(NewLHS, A);
+      EXPECT_EQ(NewRHS, C);
+    }
+  });
+
+  // Remove common factor when there's no unsigned wrapping.
+  runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+    const SCEV *A = SE.getSCEV(getArgByName(F, "a"));
+    const SCEV *B = SE.getSCEV(getArgByName(F, "b"));
+    const SCEV *VS = SE.getVScale(A->getType());
+    const SCEV *VSxA = SE.getMulExpr(VS, A, SCEV::FlagNUW);
+    const SCEV *VSxB = SE.getMulExpr(VS, B, SCEV::FlagNUW);
+
+    {
+      CmpPredicate NewPred = ICmpInst::ICMP_SLT;
+      const SCEV *NewLHS = VSxA;
+      const SCEV *NewRHS = VSxB;
+      EXPECT_FALSE(SE.SimplifyICmpOperands(NewPred, NewLHS, NewRHS));
+    }
+
+    {
+      CmpPredicate NewPred = ICmpInst::ICMP_ULT;
+      const SCEV *NewLHS = VSxA;
+      const SCEV *NewRHS = VSxB;
+      EXPECT_TRUE(SE.SimplifyICmpOperands(NewPred, NewLHS, NewRHS));
+      EXPECT_EQ(NewPred, ICmpInst::ICMP_ULT);
+      EXPECT_EQ(NewLHS, A);
+      EXPECT_EQ(NewRHS, B);
+    }
+
+    {
+      CmpPredicate NewPred = ICmpInst::ICMP_EQ;
+      const SCEV *NewLHS = VSxA;
+      const SCEV *NewRHS = VSxB;
+      EXPECT_TRUE(SE.SimplifyICmpOperands(NewPred, NewLHS, NewRHS));
+      EXPECT_EQ(NewPred, ICmpInst::ICMP_EQ);
+      EXPECT_EQ(NewLHS, A);
+      EXPECT_EQ(NewRHS, B);
+    }
+  });
+
+  // Do not remove common factor due to wrap flag mismatch.
+  runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+    const SCEV *A = SE.getSCEV(getArgByName(F, "a"));
+    const SCEV *B = SE.getSCEV(getArgByName(F, "b"));
+    const SCEV *VS = SE.getVScale(A->getType());
+    const SCEV *VSxA = SE.getMulExpr(VS, A, SCEV::FlagNSW);
+    const SCEV *VSxB = SE.getMulExpr(VS, B, SCEV::FlagNUW);
+
+    {
+      CmpPredicate NewPred = ICmpInst::ICMP_SLT;
+      const SCEV *NewLHS = VSxA;
+      const SCEV *NewRHS = VSxB;
+      EXPECT_FALSE(SE.SimplifyICmpOperands(NewPred, NewLHS, NewRHS));
+    }
+
+    {
+      CmpPredicate NewPred = ICmpInst::ICMP_ULT;
+      const SCEV *NewLHS = VSxA;
+      const SCEV *NewRHS = VSxB;
+      EXPECT_FALSE(SE.SimplifyICmpOperands(NewPred, NewLHS, NewRHS));
+    }
+
+    {
+      CmpPredicate NewPred = ICmpInst::ICMP_EQ;
+      const SCEV *NewLHS = VSxA;
+      const SCEV *NewRHS = VSxB;
+      EXPECT_FALSE(SE.SimplifyICmpOperands(NewPred, NewLHS, NewRHS));
+    }
+  });
+}
+
 } // end namespace llvm
