Commit 0dd5c04
[LLVM][SCEV] Look through common vscale multiplicand when simplifying compares.
My use case is simplifying the control flow generated by LoopVectorize when vectorising loops whose trip count is a function of the runtime vector length. This can be problematic because:

* CSE is a pre-LoopVectorize transform, so it's common for an IR function to include several calls to llvm.vscale(). (NOTE: Code generation will typically remove the duplicates.)
* Pre-LoopVectorize instcombines will rewrite some multiplies as shifts.

This leads to a mismatch between the VL-based maths of the scalar loop and that created for the vector loop, which prevents some obvious simplifications. SCEV does not suffer these issues because it effectively performs CSE during construction, and shifts are represented as multiplies.
1 parent 8f37668 commit 0dd5c04
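To make the commit message concrete, here is a minimal sketch (not part of the commit) of why SCEV sidesteps both problems. It reuses helpers that appear in the unit-test diff below (runWithSE, getVScale, getMulExpr); the module M and the function name "foo" are assumed:

// Two syntactically different IR computations of the vector length,
//   %a = shl nuw nsw i64 %vscale0, 2   ; instcombine'd multiply
//   %b = mul nuw i64 %vscale1, 4       ; duplicate vscale call left by CSE
// construct the same uniqued SCEV, because SCEV has no shift node and
// deduplicates the vscale expression during construction.
runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
  Type *I64 = Type::getInt64Ty(F.getContext());
  const SCEV *VS = SE.getVScale(I64);
  const SCEV *Four = SE.getConstant(I64, 4);
  // Whether the source IR said "shl ..., 2" or "mul ..., 4", SCEV
  // construction yields this one multiply: (4 * vscale).
  const SCEV *VL = SE.getMulExpr(Four, VS, SCEV::FlagNUW);
  (void)VL;
});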

File tree: 4 files changed (+194, -49 lines)

llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h

Lines changed: 16 additions & 6 deletions
@@ -183,7 +183,8 @@ m_scev_PtrToInt(const Op0_t &Op0) {
 }
 
 /// Match a binary SCEV.
-template <typename SCEVTy, typename Op0_t, typename Op1_t>
+template <typename SCEVTy, typename Op0_t, typename Op1_t,
+          bool Commutable = false>
 struct SCEVBinaryExpr_match {
   Op0_t Op0;
   Op1_t Op1;
@@ -192,15 +193,18 @@ struct SCEVBinaryExpr_match {
 
   bool match(const SCEV *S) const {
     auto *E = dyn_cast<SCEVTy>(S);
-    return E && E->getNumOperands() == 2 && Op0.match(E->getOperand(0)) &&
-           Op1.match(E->getOperand(1));
+    return E && E->getNumOperands() == 2 &&
+           ((Op0.match(E->getOperand(0)) && Op1.match(E->getOperand(1))) ||
+            (Commutable && Op0.match(E->getOperand(1)) &&
+             Op1.match(E->getOperand(0))));
   }
 };
 
-template <typename SCEVTy, typename Op0_t, typename Op1_t>
-inline SCEVBinaryExpr_match<SCEVTy, Op0_t, Op1_t>
+template <typename SCEVTy, typename Op0_t, typename Op1_t,
+          bool Commutable = false>
+inline SCEVBinaryExpr_match<SCEVTy, Op0_t, Op1_t, Commutable>
 m_scev_Binary(const Op0_t &Op0, const Op1_t &Op1) {
-  return SCEVBinaryExpr_match<SCEVTy, Op0_t, Op1_t>(Op0, Op1);
+  return SCEVBinaryExpr_match<SCEVTy, Op0_t, Op1_t, Commutable>(Op0, Op1);
 }
 
 template <typename Op0_t, typename Op1_t>
@@ -215,6 +219,12 @@ m_scev_Mul(const Op0_t &Op0, const Op1_t &Op1) {
   return m_scev_Binary<SCEVMulExpr>(Op0, Op1);
 }
 
+template <typename Op0_t, typename Op1_t>
+inline SCEVBinaryExpr_match<SCEVMulExpr, Op0_t, Op1_t, true>
+m_scev_c_Mul(const Op0_t &Op0, const Op1_t &Op1) {
+  return m_scev_Binary<SCEVMulExpr, Op0_t, Op1_t, true>(Op0, Op1);
+}
+
 template <typename Op0_t, typename Op1_t>
 inline SCEVBinaryExpr_match<SCEVUDivExpr, Op0_t, Op1_t>
 m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1) {
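A hedged sketch of what the new commutative matcher buys its callers; S stands for an arbitrary SCEV assumed to be in scope:

// m_scev_Mul(m_SCEV(X), m_SCEVVScale()) only matches when vscale is
// operand 1. The commutative variant also accepts the swapped operand
// order that SCEV's canonical operand sorting can produce.
const SCEV *X;
if (match(S, m_scev_c_Mul(m_SCEV(X), m_SCEVVScale()))) {
  // X is bound to the non-vscale operand, whichever side it was on.
}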

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 20 additions & 1 deletion
@@ -10785,6 +10785,25 @@ bool ScalarEvolution::SimplifyICmpOperands(CmpPredicate &Pred, const SCEV *&LHS,
   if (Depth >= 3)
     return false;
 
+  const SCEV *NewLHS, *NewRHS;
+  if (match(LHS, m_scev_c_Mul(m_SCEV(NewLHS), m_SCEVVScale())) &&
+      match(RHS, m_scev_c_Mul(m_SCEV(NewRHS), m_SCEVVScale()))) {
+    const SCEVMulExpr *LMul = cast<SCEVMulExpr>(LHS);
+    const SCEVMulExpr *RMul = cast<SCEVMulExpr>(RHS);
+
+    // (X * vscale) uicmp/eq/ne (Y * vscale) ==> X uicmp/eq/ne Y
+    // when neither multiply wraps.
+    // (X * vscale) sicmp (Y * vscale) ==> X sicmp Y
+    // when neither multiply changes sign.
+    if ((LMul->hasNoSignedWrap() && RMul->hasNoSignedWrap()) ||
+        (LMul->hasNoUnsignedWrap() && RMul->hasNoUnsignedWrap() &&
+         !ICmpInst::isSigned(Pred))) {
+      LHS = NewLHS;
+      RHS = NewRHS;
+      Changed = true;
+    }
+  }
+
   // Canonicalize a constant to the right side.
   if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
     // Check for both operands constant.
@@ -10959,7 +10978,7 @@ bool ScalarEvolution::SimplifyICmpOperands(CmpPredicate &Pred, const SCEV *&LHS,
   // Recursively simplify until we either hit a recursion limit or nothing
   // changes.
   if (Changed)
-    return SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1);
+    return SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1), true;
 
   return Changed;
 }
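The wrap-flag guard is what makes stripping the common multiplicand sound. Here is a self-contained worked counterexample, with illustrative values not taken from the commit: in i8 arithmetic with vscale == 4, an unsigned compare of wrapping products disagrees with the compare of the factors.

#include <cassert>
#include <cstdint>

int main() {
  // i8 operands with an assumed vscale of 4.
  uint8_t X = 64, Y = 1, VScale = 4;
  uint8_t XV = uint8_t(X * VScale); // 64 * 4 == 256, wraps to 0
  uint8_t YV = uint8_t(Y * VScale); // 4
  assert(XV < YV);  // (X * vscale) u< (Y * vscale) holds after wrapping...
  assert(!(X < Y)); // ...but X u< Y does not, so the fold requires nuw.
  return 0;
}

The second hunk is a separate bookkeeping fix: "return SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1), true;" uses the comma operator so that this invocation still reports a change to its caller even when the recursive call finds nothing further to simplify.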

llvm/test/Transforms/LoopVectorize/AArch64/sve-vscale-based-trip-counts.ll

Lines changed: 21 additions & 42 deletions
@@ -9,14 +9,14 @@ define void @vscale_mul_4(ptr noalias noundef readonly captures(none) %a, ptr no
 ; CHECK-NEXT: [[ENTRY:.*]]:
 ; CHECK-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.vscale.i64()
 ; CHECK-NEXT: [[TMP1:%.*]] = shl nuw nsw i64 [[TMP0]], 2
-; CHECK-NEXT: [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-NEXT: [[TMP5:%.*]] = mul nuw i64 [[TMP4]], 4
-; CHECK-NEXT: [[N_MOD_VF:%.*]] = urem i64 [[TMP1]], [[TMP5]]
+; CHECK-NEXT: [[TMP10:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT: [[TMP3:%.*]] = mul nuw i64 [[TMP10]], 4
+; CHECK-NEXT: [[N_MOD_VF:%.*]] = urem i64 [[TMP1]], [[TMP3]]
 ; CHECK-NEXT: [[N_VEC:%.*]] = sub i64 [[TMP1]], [[N_MOD_VF]]
 ; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <vscale x 4 x float>, ptr [[A]], align 4
 ; CHECK-NEXT: [[WIDE_LOAD1:%.*]] = load <vscale x 4 x float>, ptr [[B]], align 4
-; CHECK-NEXT: [[TMP10:%.*]] = fmul <vscale x 4 x float> [[WIDE_LOAD]], [[WIDE_LOAD1]]
-; CHECK-NEXT: store <vscale x 4 x float> [[TMP10]], ptr [[B]], align 4
+; CHECK-NEXT: [[TMP4:%.*]] = fmul <vscale x 4 x float> [[WIDE_LOAD]], [[WIDE_LOAD1]]
+; CHECK-NEXT: store <vscale x 4 x float> [[TMP4]], ptr [[B]], align 4
 ; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[TMP1]], [[N_VEC]]
 ; CHECK-NEXT: br i1 [[CMP_N]], label %[[FOR_COND_CLEANUP:.*]], label %[[FOR_BODY:.*]]
 ; CHECK: [[FOR_COND_CLEANUP]]:
@@ -124,36 +124,29 @@ define void @vscale_mul_12(ptr noalias noundef readonly captures(none) %a, ptr n
 ; CHECK-NEXT: [[ENTRY:.*]]:
 ; CHECK-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.vscale.i64()
 ; CHECK-NEXT: [[MUL1:%.*]] = mul nuw nsw i64 [[TMP0]], 12
-; CHECK-NEXT: [[TMP2:%.*]] = shl nuw nsw i64 [[TMP0]], 2
-; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[MUL1]], [[TMP2]]
-; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]]
-; CHECK: [[VECTOR_PH]]:
 ; CHECK-NEXT: [[TMP3:%.*]] = call i64 @llvm.vscale.i64()
 ; CHECK-NEXT: [[TMP4:%.*]] = mul nuw i64 [[TMP3]], 4
 ; CHECK-NEXT: [[N_MOD_VF:%.*]] = urem i64 [[MUL1]], [[TMP4]]
 ; CHECK-NEXT: [[N_VEC:%.*]] = sub i64 [[MUL1]], [[N_MOD_VF]]
 ; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
 ; CHECK: [[VECTOR_BODY]]:
-; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[ENTRY]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
 ; CHECK-NEXT: [[TMP7:%.*]] = getelementptr inbounds nuw float, ptr [[A]], i64 [[INDEX]]
-; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <vscale x 4 x float>, ptr [[TMP7]], align 4
-; CHECK-NEXT: [[TMP9:%.*]] = getelementptr inbounds nuw float, ptr [[B]], i64 [[INDEX]]
-; CHECK-NEXT: [[WIDE_LOAD1:%.*]] = load <vscale x 4 x float>, ptr [[TMP9]], align 4
-; CHECK-NEXT: [[TMP11:%.*]] = fmul <vscale x 4 x float> [[WIDE_LOAD]], [[WIDE_LOAD1]]
-; CHECK-NEXT: store <vscale x 4 x float> [[TMP11]], ptr [[TMP9]], align 4
+; CHECK-NEXT: [[WIDE_LOAD2:%.*]] = load <vscale x 4 x float>, ptr [[TMP7]], align 4
+; CHECK-NEXT: [[TMP12:%.*]] = getelementptr inbounds nuw float, ptr [[B]], i64 [[INDEX]]
+; CHECK-NEXT: [[WIDE_LOAD4:%.*]] = load <vscale x 4 x float>, ptr [[TMP12]], align 4
+; CHECK-NEXT: [[TMP25:%.*]] = fmul <vscale x 4 x float> [[WIDE_LOAD2]], [[WIDE_LOAD4]]
+; CHECK-NEXT: store <vscale x 4 x float> [[TMP25]], ptr [[TMP12]], align 4
 ; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP4]]
-; CHECK-NEXT: [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT: br i1 [[TMP12]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
+; CHECK-NEXT: [[TMP22:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT: br i1 [[TMP22]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
 ; CHECK: [[MIDDLE_BLOCK]]:
 ; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[MUL1]], [[N_VEC]]
-; CHECK-NEXT: br i1 [[CMP_N]], label %[[FOR_COND_CLEANUP:.*]], label %[[SCALAR_PH]]
-; CHECK: [[SCALAR_PH]]:
-; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i64 [ [[N_VEC]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
-; CHECK-NEXT: br label %[[FOR_BODY:.*]]
+; CHECK-NEXT: br i1 [[CMP_N]], label %[[FOR_COND_CLEANUP:.*]], label %[[FOR_BODY:.*]]
 ; CHECK: [[FOR_COND_CLEANUP]]:
 ; CHECK-NEXT: ret void
 ; CHECK: [[FOR_BODY]]:
-; CHECK-NEXT: [[INDVARS_IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[INDVARS_IV_NEXT:%.*]], %[[FOR_BODY]] ]
+; CHECK-NEXT: [[INDVARS_IV:%.*]] = phi i64 [ [[INDVARS_IV_NEXT:%.*]], %[[FOR_BODY]] ], [ [[N_VEC]], %[[MIDDLE_BLOCK]] ]
 ; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds nuw float, ptr [[A]], i64 [[INDVARS_IV]]
 ; CHECK-NEXT: [[TMP13:%.*]] = load float, ptr [[ARRAYIDX]], align 4
 ; CHECK-NEXT: [[ARRAYIDX4:%.*]] = getelementptr inbounds nuw float, ptr [[B]], i64 [[INDVARS_IV]]
@@ -191,17 +184,13 @@ define void @vscale_mul_31(ptr noalias noundef readonly captures(none) %a, ptr n
 ; CHECK-NEXT: [[ENTRY:.*]]:
 ; CHECK-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.vscale.i64()
 ; CHECK-NEXT: [[MUL1:%.*]] = mul nuw nsw i64 [[TMP0]], 31
-; CHECK-NEXT: [[TMP2:%.*]] = shl nuw nsw i64 [[TMP0]], 3
-; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[MUL1]], [[TMP2]]
-; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]]
-; CHECK: [[VECTOR_PH]]:
 ; CHECK-NEXT: [[TMP3:%.*]] = call i64 @llvm.vscale.i64()
 ; CHECK-NEXT: [[TMP4:%.*]] = mul nuw i64 [[TMP3]], 8
 ; CHECK-NEXT: [[N_MOD_VF:%.*]] = urem i64 [[MUL1]], [[TMP4]]
 ; CHECK-NEXT: [[N_VEC:%.*]] = sub i64 [[MUL1]], [[N_MOD_VF]]
 ; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
 ; CHECK: [[VECTOR_BODY]]:
-; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[ENTRY]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
 ; CHECK-NEXT: [[TMP7:%.*]] = getelementptr inbounds nuw float, ptr [[A]], i64 [[INDEX]]
 ; CHECK-NEXT: [[TMP9:%.*]] = call i64 @llvm.vscale.i64()
 ; CHECK-NEXT: [[TMP10:%.*]] = shl nuw i64 [[TMP9]], 2
@@ -226,14 +215,11 @@ define void @vscale_mul_31(ptr noalias noundef readonly captures(none) %a, ptr n
 ; CHECK-NEXT: br i1 [[TMP22]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP6:![0-9]+]]
 ; CHECK: [[MIDDLE_BLOCK]]:
 ; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[MUL1]], [[N_VEC]]
-; CHECK-NEXT: br i1 [[CMP_N]], label %[[FOR_COND_CLEANUP:.*]], label %[[SCALAR_PH]]
-; CHECK: [[SCALAR_PH]]:
-; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i64 [ [[N_VEC]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
-; CHECK-NEXT: br label %[[FOR_BODY:.*]]
+; CHECK-NEXT: br i1 [[CMP_N]], label %[[FOR_COND_CLEANUP:.*]], label %[[FOR_BODY:.*]]
 ; CHECK: [[FOR_COND_CLEANUP]]:
 ; CHECK-NEXT: ret void
 ; CHECK: [[FOR_BODY]]:
-; CHECK-NEXT: [[INDVARS_IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[INDVARS_IV_NEXT:%.*]], %[[FOR_BODY]] ]
+; CHECK-NEXT: [[INDVARS_IV:%.*]] = phi i64 [ [[INDVARS_IV_NEXT:%.*]], %[[FOR_BODY]] ], [ [[N_VEC]], %[[MIDDLE_BLOCK]] ]
 ; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds nuw float, ptr [[A]], i64 [[INDVARS_IV]]
 ; CHECK-NEXT: [[TMP23:%.*]] = load float, ptr [[ARRAYIDX]], align 4
 ; CHECK-NEXT: [[ARRAYIDX4:%.*]] = getelementptr inbounds nuw float, ptr [[B]], i64 [[INDVARS_IV]]
@@ -271,17 +257,13 @@ define void @vscale_mul_64(ptr noalias noundef readonly captures(none) %a, ptr n
 ; CHECK-NEXT: [[ENTRY:.*]]:
 ; CHECK-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.vscale.i64()
 ; CHECK-NEXT: [[MUL1:%.*]] = mul nuw nsw i64 [[TMP0]], 64
-; CHECK-NEXT: [[TMP2:%.*]] = shl nuw nsw i64 [[TMP0]], 3
-; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[MUL1]], [[TMP2]]
-; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]]
-; CHECK: [[VECTOR_PH]]:
 ; CHECK-NEXT: [[TMP3:%.*]] = call i64 @llvm.vscale.i64()
 ; CHECK-NEXT: [[TMP4:%.*]] = mul nuw i64 [[TMP3]], 8
 ; CHECK-NEXT: [[N_MOD_VF:%.*]] = urem i64 [[MUL1]], [[TMP4]]
 ; CHECK-NEXT: [[N_VEC:%.*]] = sub i64 [[MUL1]], [[N_MOD_VF]]
 ; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
 ; CHECK: [[VECTOR_BODY]]:
-; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[ENTRY]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
 ; CHECK-NEXT: [[TMP7:%.*]] = getelementptr inbounds nuw float, ptr [[A]], i64 [[INDEX]]
 ; CHECK-NEXT: [[TMP9:%.*]] = call i64 @llvm.vscale.i64()
 ; CHECK-NEXT: [[TMP10:%.*]] = shl nuw i64 [[TMP9]], 2
@@ -306,14 +288,11 @@ define void @vscale_mul_64(ptr noalias noundef readonly captures(none) %a, ptr n
 ; CHECK-NEXT: br i1 [[TMP22]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
 ; CHECK: [[MIDDLE_BLOCK]]:
 ; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[MUL1]], [[N_VEC]]
-; CHECK-NEXT: br i1 [[CMP_N]], label %[[FOR_COND_CLEANUP:.*]], label %[[SCALAR_PH]]
-; CHECK: [[SCALAR_PH]]:
-; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i64 [ [[N_VEC]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
-; CHECK-NEXT: br label %[[FOR_BODY:.*]]
+; CHECK-NEXT: br i1 [[CMP_N]], label %[[FOR_COND_CLEANUP:.*]], label %[[FOR_BODY:.*]]
 ; CHECK: [[FOR_COND_CLEANUP]]:
 ; CHECK-NEXT: ret void
 ; CHECK: [[FOR_BODY]]:
-; CHECK-NEXT: [[INDVARS_IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[INDVARS_IV_NEXT:%.*]], %[[FOR_BODY]] ]
+; CHECK-NEXT: [[INDVARS_IV:%.*]] = phi i64 [ [[INDVARS_IV_NEXT:%.*]], %[[FOR_BODY]] ], [ [[N_VEC]], %[[MIDDLE_BLOCK]] ]
 ; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds nuw float, ptr [[A]], i64 [[INDVARS_IV]]
 ; CHECK-NEXT: [[TMP23:%.*]] = load float, ptr [[ARRAYIDX]], align 4
 ; CHECK-NEXT: [[ARRAYIDX4:%.*]] = getelementptr inbounds nuw float, ptr [[B]], i64 [[INDVARS_IV]]
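For context, a hedged C++ reconstruction of the shape of loop these tests exercise; the function body is an assumption modelled on @vscale_mul_4, whose trip count is vscale * 4 (on SVE, svcntw() returns the number of 32-bit lanes, i.e. 4 * vscale):

#include <arm_sve.h>
#include <cstdint>

// The trip count is a function of the runtime vector length, so with this
// patch the vectoriser can prove the minimum-iteration check redundant and
// drop the separate scalar preheader, as the updated CHECK lines show.
void vscale_mul_4(const float *__restrict a, float *__restrict b) {
  uint64_t n = svcntw(); // == 4 * vscale
  for (uint64_t i = 0; i < n; ++i)
    b[i] = a[i] * b[i];
}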

llvm/unittests/Analysis/ScalarEvolutionTest.cpp

Lines changed: 137 additions & 0 deletions
@@ -1768,4 +1768,141 @@ TEST_F(ScalarEvolutionsTest, ComplexityComparatorIsStrictWeakOrdering3) {
   SE.getSCEV(Or1);
 }
 
+TEST_F(ScalarEvolutionsTest, SimplifyICmpOperands) {
+  LLVMContext C;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M =
+      parseAssemblyString("define i32 @foo(ptr %loc, i32 %a, i32 %b) {"
+                          "entry: "
+                          "  ret i32 %a "
+                          "} ",
+                          Err, C);
+
+  ASSERT_TRUE(M && "Could not parse module?");
+  ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
+
+  // Remove common factor when there's no signed wrapping.
+  runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+    const SCEV *A = SE.getSCEV(getArgByName(F, "a"));
+    const SCEV *B = SE.getSCEV(getArgByName(F, "b"));
+    const SCEV *VS = SE.getVScale(A->getType());
+    const SCEV *VSxA = SE.getMulExpr(VS, A, SCEV::FlagNSW);
+    const SCEV *VSxB = SE.getMulExpr(VS, B, SCEV::FlagNSW);
+
+    {
+      CmpPredicate NewPred = ICmpInst::ICMP_SLT;
+      const SCEV *NewLHS = VSxA;
+      const SCEV *NewRHS = VSxB;
+      EXPECT_TRUE(SE.SimplifyICmpOperands(NewPred, NewLHS, NewRHS));
+      EXPECT_EQ(NewPred, ICmpInst::ICMP_SLT);
+      EXPECT_EQ(NewLHS, A);
+      EXPECT_EQ(NewRHS, B);
+    }
+
+    {
+      CmpPredicate NewPred = ICmpInst::ICMP_ULT;
+      const SCEV *NewLHS = VSxA;
+      const SCEV *NewRHS = VSxB;
+      EXPECT_TRUE(SE.SimplifyICmpOperands(NewPred, NewLHS, NewRHS));
+      EXPECT_EQ(NewPred, ICmpInst::ICMP_ULT);
+      EXPECT_EQ(NewLHS, A);
+      EXPECT_EQ(NewRHS, B);
+    }
+
+    {
+      CmpPredicate NewPred = ICmpInst::ICMP_EQ;
+      const SCEV *NewLHS = VSxA;
+      const SCEV *NewRHS = VSxB;
+      EXPECT_TRUE(SE.SimplifyICmpOperands(NewPred, NewLHS, NewRHS));
+      EXPECT_EQ(NewPred, ICmpInst::ICMP_EQ);
+      EXPECT_EQ(NewLHS, A);
+      EXPECT_EQ(NewRHS, B);
+    }
+
+    // Verify the common factor's position doesn't impede simplification.
+    {
+      const SCEV *C = SE.getConstant(A->getType(), 100);
+      const SCEV *CxVS = SE.getMulExpr(C, VS, SCEV::FlagNSW);
+
+      // Verify common factor is available at different indices.
+      ASSERT_TRUE(isa<SCEVVScale>(cast<SCEVMulExpr>(VSxA)->getOperand(0)) !=
+                  isa<SCEVVScale>(cast<SCEVMulExpr>(CxVS)->getOperand(0)));
+
+      CmpPredicate NewPred = ICmpInst::ICMP_SLT;
+      const SCEV *NewLHS = VSxA;
+      const SCEV *NewRHS = CxVS;
+      EXPECT_TRUE(SE.SimplifyICmpOperands(NewPred, NewLHS, NewRHS));
+      EXPECT_EQ(NewPred, ICmpInst::ICMP_SLT);
+      EXPECT_EQ(NewLHS, A);
+      EXPECT_EQ(NewRHS, C);
+    }
+  });
+
+  // Remove common factor when there's no unsigned wrapping.
+  runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+    const SCEV *A = SE.getSCEV(getArgByName(F, "a"));
+    const SCEV *B = SE.getSCEV(getArgByName(F, "b"));
+    const SCEV *VS = SE.getVScale(A->getType());
+    const SCEV *VSxA = SE.getMulExpr(VS, A, SCEV::FlagNUW);
+    const SCEV *VSxB = SE.getMulExpr(VS, B, SCEV::FlagNUW);
+
+    {
+      CmpPredicate NewPred = ICmpInst::ICMP_SLT;
+      const SCEV *NewLHS = VSxA;
+      const SCEV *NewRHS = VSxB;
+      EXPECT_FALSE(SE.SimplifyICmpOperands(NewPred, NewLHS, NewRHS));
+    }
+
+    {
+      CmpPredicate NewPred = ICmpInst::ICMP_ULT;
+      const SCEV *NewLHS = VSxA;
+      const SCEV *NewRHS = VSxB;
+      EXPECT_TRUE(SE.SimplifyICmpOperands(NewPred, NewLHS, NewRHS));
+      EXPECT_EQ(NewPred, ICmpInst::ICMP_ULT);
+      EXPECT_EQ(NewLHS, A);
+      EXPECT_EQ(NewRHS, B);
+    }
+
+    {
+      CmpPredicate NewPred = ICmpInst::ICMP_EQ;
+      const SCEV *NewLHS = VSxA;
+      const SCEV *NewRHS = VSxB;
+      EXPECT_TRUE(SE.SimplifyICmpOperands(NewPred, NewLHS, NewRHS));
+      EXPECT_EQ(NewPred, ICmpInst::ICMP_EQ);
+      EXPECT_EQ(NewLHS, A);
+      EXPECT_EQ(NewRHS, B);
+    }
+  });
+
+  // Do not remove common factor due to wrap flag mismatch.
+  runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+    const SCEV *A = SE.getSCEV(getArgByName(F, "a"));
+    const SCEV *B = SE.getSCEV(getArgByName(F, "b"));
+    const SCEV *VS = SE.getVScale(A->getType());
+    const SCEV *VSxA = SE.getMulExpr(VS, A, SCEV::FlagNSW);
+    const SCEV *VSxB = SE.getMulExpr(VS, B, SCEV::FlagNUW);
+
+    {
+      CmpPredicate NewPred = ICmpInst::ICMP_SLT;
+      const SCEV *NewLHS = VSxA;
+      const SCEV *NewRHS = VSxB;
+      EXPECT_FALSE(SE.SimplifyICmpOperands(NewPred, NewLHS, NewRHS));
+    }
+
+    {
+      CmpPredicate NewPred = ICmpInst::ICMP_ULT;
+      const SCEV *NewLHS = VSxA;
+      const SCEV *NewRHS = VSxB;
+      EXPECT_FALSE(SE.SimplifyICmpOperands(NewPred, NewLHS, NewRHS));
+    }
+
+    {
+      CmpPredicate NewPred = ICmpInst::ICMP_EQ;
+      const SCEV *NewLHS = VSxA;
+      const SCEV *NewRHS = VSxB;
+      EXPECT_FALSE(SE.SimplifyICmpOperands(NewPred, NewLHS, NewRHS));
+    }
+  });
+}
+
 } // end namespace llvm
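Finally, a hedged sketch of how an analysis client consumes SimplifyICmpOperands; SE, TripCount and VecTripCount are hypothetical stand-ins, and the by-reference arguments are only meaningful when the call reports a change:

// Hypothetical client code: the two values stand in for a loop's scalar
// and vector trip counts, as in the LoopVectorize tests above.
CmpPredicate Pred = ICmpInst::ICMP_EQ;
const SCEV *LHS = SE.getSCEV(TripCount);    // e.g. (12 * vscale)<nuw><nsw>
const SCEV *RHS = SE.getSCEV(VecTripCount); // e.g. (4 * vscale)<nuw>
if (SE.SimplifyICmpOperands(Pred, LHS, RHS)) {
  // Pred/LHS/RHS now describe the simplified compare, e.g. 12 == 4 with
  // the common vscale multiplicand stripped.
}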
