Skip to content

Commit a783763

Browse files
nikicakiramenai
authored andcommitted
[SCEV] Handle more adds in computeConstantDifference() (#101339)
Currently it only deals with the case where we're subtracting adds with at most one non-constant operand. This patch extends it to cancel out common operands for the subtraction of arbitrary add expressions. The background here is that I want to replace a getMinusSCEV() call in LAA with computeConstantDifference(): https://github.com/llvm/llvm-project/blob/93fecc2577ece0329f3bbe2719bbc5b4b9b30010/llvm/lib/Analysis/LoopAccessAnalysis.cpp#L1602-L1603 This particular call is very expensive in some cases (e.g. lencod with LTO) and computeConstantDifference() could achieve this much more cheaply, because it does not need to construct new SCEV expressions. However, the current computeConstantDifference() implementation is too weak for this and misses many basic cases. This is a step towards making it more powerful while still keeping it pretty fast.
1 parent 605aa52 commit a783763

File tree

2 files changed

+32
-27
lines changed

2 files changed

+32
-27
lines changed

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -11947,8 +11947,9 @@ ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
1194711947
// fairly deep in the call stack (i.e. is called many times).
1194811948

1194911949
// X - X = 0.
11950+
unsigned BW = getTypeSizeInBits(More->getType());
1195011951
if (More == Less)
11951-
return APInt(getTypeSizeInBits(More->getType()), 0);
11952+
return APInt(BW, 0);
1195211953

1195311954
if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
1195411955
const auto *LAR = cast<SCEVAddRecExpr>(Less);
@@ -11971,33 +11972,31 @@ ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
1197111972
// fall through
1197211973
}
1197311974

11974-
if (isa<SCEVConstant>(Less) && isa<SCEVConstant>(More)) {
11975-
const auto &M = cast<SCEVConstant>(More)->getAPInt();
11976-
const auto &L = cast<SCEVConstant>(Less)->getAPInt();
11977-
return M - L;
11978-
}
11979-
11980-
SCEV::NoWrapFlags Flags;
11981-
const SCEV *LLess = nullptr, *RLess = nullptr;
11982-
const SCEV *LMore = nullptr, *RMore = nullptr;
11983-
const SCEVConstant *C1 = nullptr, *C2 = nullptr;
11984-
// Compare (X + C1) vs X.
11985-
if (splitBinaryAdd(Less, LLess, RLess, Flags))
11986-
if ((C1 = dyn_cast<SCEVConstant>(LLess)))
11987-
if (RLess == More)
11988-
return -(C1->getAPInt());
11989-
11990-
// Compare X vs (X + C2).
11991-
if (splitBinaryAdd(More, LMore, RMore, Flags))
11992-
if ((C2 = dyn_cast<SCEVConstant>(LMore)))
11993-
if (RMore == Less)
11994-
return C2->getAPInt();
11975+
// Try to cancel out common factors in two add expressions.
11976+
SmallDenseMap<const SCEV *, int, 8> Multiplicity;
11977+
APInt Diff(BW, 0);
11978+
auto Add = [&](const SCEV *S, int Mul) {
11979+
if (auto *C = dyn_cast<SCEVConstant>(S))
11980+
Diff += C->getAPInt() * Mul;
11981+
else
11982+
Multiplicity[S] += Mul;
11983+
};
11984+
auto Decompose = [&](const SCEV *S, int Mul) {
11985+
if (isa<SCEVAddExpr>(S)) {
11986+
for (const SCEV *Op : S->operands())
11987+
Add(Op, Mul);
11988+
} else
11989+
Add(S, Mul);
11990+
};
11991+
Decompose(More, 1);
11992+
Decompose(Less, -1);
1199511993

11996-
// Compare (X + C1) vs (X + C2).
11997-
if (C1 && C2 && RLess == RMore)
11998-
return C2->getAPInt() - C1->getAPInt();
11994+
// Check whether all the non-constants cancel out.
11995+
for (const auto [_, Mul] : Multiplicity)
11996+
if (Mul != 0)
11997+
return std::nullopt;
1199911998

12000-
return std::nullopt;
11999+
return Diff;
1200112000
}
1200212001

1200312002
bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(

llvm/unittests/Analysis/ScalarEvolutionTest.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1117,10 +1117,12 @@ TEST_F(ScalarEvolutionsTest, SCEVComputeConstantDifference) {
11171117
LLVMContext C;
11181118
SMDiagnostic Err;
11191119
std::unique_ptr<Module> M = parseAssemblyString(
1120-
"define void @foo(i32 %sz, i32 %pp) { "
1120+
"define void @foo(i32 %sz, i32 %pp, i32 %x) { "
11211121
"entry: "
11221122
" %v0 = add i32 %pp, 0 "
11231123
" %v3 = add i32 %pp, 3 "
1124+
" %vx = add i32 %pp, %x "
1125+
" %vx3 = add i32 %vx, 3 "
11241126
" br label %loop.body "
11251127
"loop.body: "
11261128
" %iv = phi i32 [ %iv.next, %loop.body ], [ 0, %entry ] "
@@ -1141,6 +1143,9 @@ TEST_F(ScalarEvolutionsTest, SCEVComputeConstantDifference) {
11411143
runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
11421144
auto *ScevV0 = SE.getSCEV(getInstructionByName(F, "v0")); // %pp
11431145
auto *ScevV3 = SE.getSCEV(getInstructionByName(F, "v3")); // (3 + %pp)
1146+
auto *ScevVX = SE.getSCEV(getInstructionByName(F, "vx")); // (%pp + %x)
1147+
// (%pp + %x + 3)
1148+
auto *ScevVX3 = SE.getSCEV(getInstructionByName(F, "vx3"));
11441149
auto *ScevIV = SE.getSCEV(getInstructionByName(F, "iv")); // {0,+,1}
11451150
auto *ScevXA = SE.getSCEV(getInstructionByName(F, "xa")); // {%pp,+,1}
11461151
auto *ScevYY = SE.getSCEV(getInstructionByName(F, "yy")); // {(3 + %pp),+,1}
@@ -1162,6 +1167,7 @@ TEST_F(ScalarEvolutionsTest, SCEVComputeConstantDifference) {
11621167
EXPECT_EQ(diff(ScevV0, ScevV3), -3);
11631168
EXPECT_EQ(diff(ScevV0, ScevV0), 0);
11641169
EXPECT_EQ(diff(ScevV3, ScevV3), 0);
1170+
EXPECT_EQ(diff(ScevVX3, ScevVX), 3);
11651171
EXPECT_EQ(diff(ScevIV, ScevIV), 0);
11661172
EXPECT_EQ(diff(ScevXA, ScevXB), 0);
11671173
EXPECT_EQ(diff(ScevXA, ScevYY), -3);

0 commit comments

Comments
 (0)