@@ -6351,61 +6351,62 @@ const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
6351
6351
return getGEPExpr(GEP, IndexExprs);
6352
6352
}
6353
6353
6354
- APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
6354
+ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S,
6355
+ const Instruction *CtxI) {
6355
6356
uint64_t BitWidth = getTypeSizeInBits(S->getType());
6356
6357
auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
6357
6358
return TrailingZeros >= BitWidth
6358
6359
? APInt::getZero(BitWidth)
6359
6360
: APInt::getOneBitSet(BitWidth, TrailingZeros);
6360
6361
};
6361
- auto GetGCDMultiple = [this](const SCEVNAryExpr *N) {
6362
+ auto GetGCDMultiple = [this, CtxI ](const SCEVNAryExpr *N) {
6362
6363
// The result is GCD of all operands results.
6363
- APInt Res = getConstantMultiple(N->getOperand(0));
6364
+ APInt Res = getConstantMultiple(N->getOperand(0), CtxI );
6364
6365
for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
6365
6366
Res = APIntOps::GreatestCommonDivisor(
6366
- Res, getConstantMultiple(N->getOperand(I)));
6367
+ Res, getConstantMultiple(N->getOperand(I), CtxI ));
6367
6368
return Res;
6368
6369
};
6369
6370
6370
6371
switch (S->getSCEVType()) {
6371
6372
case scConstant:
6372
6373
return cast<SCEVConstant>(S)->getAPInt();
6373
6374
case scPtrToInt:
6374
- return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand());
6375
+ return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand(), CtxI );
6375
6376
case scUDivExpr:
6376
6377
case scVScale:
6377
6378
return APInt(BitWidth, 1);
6378
6379
case scTruncate: {
6379
6380
// Only multiples that are a power of 2 will hold after truncation.
6380
6381
const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
6381
- uint32_t TZ = getMinTrailingZeros(T->getOperand());
6382
+ uint32_t TZ = getMinTrailingZeros(T->getOperand(), CtxI );
6382
6383
return GetShiftedByZeros(TZ);
6383
6384
}
6384
6385
case scZeroExtend: {
6385
6386
const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
6386
- return getConstantMultiple(Z->getOperand()).zext(BitWidth);
6387
+ return getConstantMultiple(Z->getOperand(), CtxI ).zext(BitWidth);
6387
6388
}
6388
6389
case scSignExtend: {
6389
6390
// Only multiples that are a power of 2 will hold after sext.
6390
6391
const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
6391
- uint32_t TZ = getMinTrailingZeros(E->getOperand());
6392
+ uint32_t TZ = getMinTrailingZeros(E->getOperand(), CtxI );
6392
6393
return GetShiftedByZeros(TZ);
6393
6394
}
6394
6395
case scMulExpr: {
6395
6396
const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
6396
6397
if (M->hasNoUnsignedWrap()) {
6397
6398
// The result is the product of all operand results.
6398
- APInt Res = getConstantMultiple(M->getOperand(0));
6399
+ APInt Res = getConstantMultiple(M->getOperand(0), CtxI );
6399
6400
for (const SCEV *Operand : M->operands().drop_front())
6400
- Res = Res * getConstantMultiple(Operand);
6401
+ Res = Res * getConstantMultiple(Operand, CtxI );
6401
6402
return Res;
6402
6403
}
6403
6404
6404
6405
// If there are no wrap guarantees, find the trailing zeros, which is the
6405
6406
// sum of trailing zeros for all its operands.
6406
6407
uint32_t TZ = 0;
6407
6408
for (const SCEV *Operand : M->operands())
6408
- TZ += getMinTrailingZeros(Operand);
6409
+ TZ += getMinTrailingZeros(Operand, CtxI );
6409
6410
return GetShiftedByZeros(TZ);
6410
6411
}
6411
6412
case scAddExpr:
@@ -6414,9 +6415,9 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
6414
6415
if (N->hasNoUnsignedWrap())
6415
6416
return GetGCDMultiple(N);
6416
6417
// Find the trailing zeros, which is the minimum of its operands.
6417
- uint32_t TZ = getMinTrailingZeros(N->getOperand(0));
6418
+ uint32_t TZ = getMinTrailingZeros(N->getOperand(0), CtxI );
6418
6419
for (const SCEV *Operand : N->operands().drop_front())
6419
- TZ = std::min(TZ, getMinTrailingZeros(Operand));
6420
+ TZ = std::min(TZ, getMinTrailingZeros(Operand, CtxI ));
6420
6421
return GetShiftedByZeros(TZ);
6421
6422
}
6422
6423
case scUMaxExpr:
@@ -6429,7 +6430,7 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
6429
6430
// ask ValueTracking for known bits
6430
6431
const SCEVUnknown *U = cast<SCEVUnknown>(S);
6431
6432
unsigned Known =
6432
- computeKnownBits(U->getValue(), getDataLayout(), &AC, nullptr , &DT)
6433
+ computeKnownBits(U->getValue(), getDataLayout(), &AC, CtxI , &DT)
6433
6434
.countMinTrailingZeros();
6434
6435
return GetShiftedByZeros(Known);
6435
6436
}
@@ -6439,12 +6440,18 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
6439
6440
llvm_unreachable("Unknown SCEV kind!");
6440
6441
}
6441
6442
6442
- APInt ScalarEvolution::getConstantMultiple(const SCEV *S) {
6443
+ APInt ScalarEvolution::getConstantMultiple(const SCEV *S,
6444
+ const Instruction *CtxI) {
6445
+ // Skip looking up and updating the cache if there is a context instruction,
6446
+ // as the result will only be valid in the specified context.
6447
+ if (CtxI)
6448
+ return getConstantMultipleImpl(S, CtxI);
6449
+
6443
6450
auto I = ConstantMultipleCache.find(S);
6444
6451
if (I != ConstantMultipleCache.end())
6445
6452
return I->second;
6446
6453
6447
- APInt Result = getConstantMultipleImpl(S);
6454
+ APInt Result = getConstantMultipleImpl(S, CtxI );
6448
6455
auto InsertPair = ConstantMultipleCache.insert({S, Result});
6449
6456
assert(InsertPair.second && "Should insert a new key");
6450
6457
return InsertPair.first->second;
@@ -6455,8 +6462,9 @@ APInt ScalarEvolution::getNonZeroConstantMultiple(const SCEV *S) {
6455
6462
return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
6456
6463
}
6457
6464
6458
- uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S) {
6459
- return std::min(getConstantMultiple(S).countTrailingZeros(),
6465
+ uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S,
6466
+ const Instruction *CtxI) {
6467
+ return std::min(getConstantMultiple(S, CtxI).countTrailingZeros(),
6460
6468
(unsigned)getTypeSizeInBits(S->getType()));
6461
6469
}
6462
6470
@@ -10243,8 +10251,7 @@ const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
10243
10251
static const SCEV *
10244
10252
SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
10245
10253
SmallVectorImpl<const SCEVPredicate *> *Predicates,
10246
-
10247
- ScalarEvolution &SE) {
10254
+ ScalarEvolution &SE, const Loop *L) {
10248
10255
uint32_t BW = A.getBitWidth();
10249
10256
assert(BW == SE.getTypeSizeInBits(B->getType()));
10250
10257
assert(A != 0 && "A must be non-zero.");
@@ -10260,7 +10267,12 @@ SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
10260
10267
//
10261
10268
// B is divisible by D if and only if the multiplicity of prime factor 2 for B
10262
10269
// is not less than multiplicity of this prime factor for D.
10263
- if (SE.getMinTrailingZeros(B) < Mult2) {
10270
+ unsigned MinTZ = SE.getMinTrailingZeros(B);
10271
+ // Try again with the terminator of the loop predecessor for context-specific
10272
+ // result, if MinTZ is too small.
10273
+ if (MinTZ < Mult2 && L->getLoopPredecessor())
10274
+ MinTZ = SE.getMinTrailingZeros(B, L->getLoopPredecessor()->getTerminator());
10275
+ if (MinTZ < Mult2) {
10264
10276
// Check if we can prove there's no remainder using URem.
10265
10277
const SCEV *URem =
10266
10278
SE.getURemExpr(B, SE.getConstant(APInt::getOneBitSet(BW, Mult2)));
@@ -10708,7 +10720,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
10708
10720
return getCouldNotCompute();
10709
10721
const SCEV *E = SolveLinEquationWithOverflow(
10710
10722
StepC->getAPInt(), getNegativeSCEV(Start),
10711
- AllowPredicates ? &Predicates : nullptr, *this);
10723
+ AllowPredicates ? &Predicates : nullptr, *this, L );
10712
10724
10713
10725
const SCEV *M = E;
10714
10726
if (E != getCouldNotCompute()) {
0 commit comments