7979#include "llvm/Analysis/LoopInfo.h"
8080#include "llvm/Analysis/MemoryBuiltins.h"
8181#include "llvm/Analysis/ScalarEvolutionExpressions.h"
82+ #include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
8283#include "llvm/Analysis/TargetLibraryInfo.h"
8384#include "llvm/Analysis/ValueTracking.h"
8485#include "llvm/Config/llvm-config.h"
133134
134135using namespace llvm;
135136using namespace PatternMatch;
137+ using namespace SCEVPatternMatch;
136138
137139#define DEBUG_TYPE "scalar-evolution"
138140
@@ -443,23 +445,11 @@ ArrayRef<const SCEV *> SCEV::operands() const {
443445 llvm_unreachable("Unknown SCEV kind!");
444446}
445447
446- bool SCEV::isZero() const {
447- if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
448- return SC->getValue()->isZero();
449- return false;
450- }
448+ bool SCEV::isZero() const { return match(this, m_scev_Zero()); }
451449
452- bool SCEV::isOne() const {
453- if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
454- return SC->getValue()->isOne();
455- return false;
456- }
450+ bool SCEV::isOne() const { return match(this, m_scev_One()); }
457451
458- bool SCEV::isAllOnesValue() const {
459- if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
460- return SC->getValue()->isMinusOne();
461- return false;
462- }
452+ bool SCEV::isAllOnesValue() const { return match(this, m_scev_AllOnes()); }
463453
464454bool SCEV::isNonConstantNegative() const {
465455 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
@@ -3423,9 +3413,8 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
34233413 return S;
34243414
34253415 // 0 udiv Y == 0
3426- if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3427- if (LHSC->getValue()->isZero())
3428- return LHS;
3416+ if (match(LHS, m_scev_Zero()))
3417+ return LHS;
34293418
34303419 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
34313420 if (RHSC->getValue()->isOne())
@@ -10593,7 +10582,6 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
1059310582 // Get the initial value for the loop.
1059410583 const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
1059510584 const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
10596- const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
1059710585
1059810586 if (!isLoopInvariant(Step, L))
1059910587 return getCouldNotCompute();
@@ -10615,8 +10603,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
1061510603 // Handle unitary steps, which cannot wraparound.
1061610604 // 1*N = -Start; -1*N = Start (mod 2^BW), so:
1061710605 // N = Distance (as unsigned)
10618- if (StepC &&
10619- (StepC->getValue()->isOne() || StepC->getValue()->isMinusOne( ))) {
10606+
10607+ if (match(Step, m_CombineOr(m_scev_One(), m_scev_AllOnes() ))) {
1062010608 APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
1062110609 MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
1062210610
@@ -10668,6 +10656,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
1066810656 }
1066910657
1067010658 // Solve the general equation.
10659+ const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
1067110660 if (!StepC || StepC->getValue()->isZero())
1067210661 return getCouldNotCompute();
1067310662 const SCEV *E = SolveLinEquationWithOverflow(
@@ -15510,9 +15499,7 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1551015499
1551115500 // If we have LHS == 0, check if LHS is computing a property of some unknown
1551215501 // SCEV %v which we can rewrite %v to express explicitly.
15513- const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS);
15514- if (Predicate == CmpInst::ICMP_EQ && RHSC &&
15515- RHSC->getValue()->isNullValue()) {
15502+ if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) {
1551615503 // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
1551715504 // explicitly express that.
1551815505 const SCEV *URemLHS = nullptr;
@@ -15693,8 +15680,7 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1569315680 To = RHS;
1569415681 break;
1569515682 case CmpInst::ICMP_NE:
15696- if (isa<SCEVConstant>(RHS) &&
15697- cast<SCEVConstant>(RHS)->getValue()->isNullValue()) {
15683+ if (match(RHS, m_scev_Zero())) {
1569815684 const SCEV *OneAlignedUp =
1569915685 DividesBy ? GetNextSCEVDividesByDivisor(One, DividesBy) : One;
1570015686 To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
0 commit comments