Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,101 @@ inline cst_pred_ty<is_all_ones> m_scev_AllOnes() {
return cst_pred_ty<is_all_ones>();
}

template <typename Class> struct class_match {
template <typename ITy> bool match(ITy *V) const { return isa<Class>(V); }
};

template <typename Class> struct bind_ty {
Class *&VR;

bind_ty(Class *&V) : VR(V) {}

template <typename ITy> bool match(ITy *V) const {
if (auto *CV = dyn_cast<Class>(V)) {
VR = CV;
return true;
}
return false;
}
};

/// Match a SCEV, capturing it if we match.
inline bind_ty<const SCEV> m_SCEV(const SCEV *&V) { return V; }
inline bind_ty<const SCEVConstant> m_SCEVConstant(const SCEVConstant *&V) {
return V;
}
inline bind_ty<const SCEVUnknown> m_SCEVUnknown(const SCEVUnknown *&V) {
return V;
}

/// Match a specified const SCEV *.
struct specificscev_ty {
const SCEV *Expr;

specificscev_ty(const SCEV *Expr) : Expr(Expr) {}

template <typename ITy> bool match(ITy *S) { return S == Expr; }
};

/// Match if we have a specific specified SCEV.
inline specificscev_ty m_Specific(const SCEV *S) { return S; }

/// Match a unary SCEV.
template <typename SCEVTy, typename Op0_t> struct SCEVUnaryExpr_match {
Op0_t Op0;

SCEVUnaryExpr_match(Op0_t Op0) : Op0(Op0) {}

bool match(const SCEV *S) {
auto *E = dyn_cast<SCEVTy>(S);
return E && E->getNumOperands() == 1 && Op0.match(E->getOperand(0));
}
};

template <typename SCEVTy, typename Op0_t>
inline SCEVUnaryExpr_match<SCEVTy, Op0_t> m_scev_Unary(const Op0_t &Op0) {
return SCEVUnaryExpr_match<SCEVTy, Op0_t>(Op0);
}

template <typename Op0_t>
inline SCEVUnaryExpr_match<SCEVSignExtendExpr, Op0_t>
m_scev_SExt(const Op0_t &Op0) {
return m_scev_Unary<SCEVSignExtendExpr>(Op0);
}

template <typename Op0_t>
inline SCEVUnaryExpr_match<SCEVZeroExtendExpr, Op0_t>
m_scev_ZExt(const Op0_t &Op0) {
return m_scev_Unary<SCEVZeroExtendExpr>(Op0);
}

/// Match a binary SCEV.
template <typename SCEVTy, typename Op0_t, typename Op1_t>
struct SCEVBinaryExpr_match {
Op0_t Op0;
Op1_t Op1;

SCEVBinaryExpr_match(Op0_t Op0, Op1_t Op1) : Op0(Op0), Op1(Op1) {}

bool match(const SCEV *S) {
auto *E = dyn_cast<SCEVTy>(S);
return E && E->getNumOperands() == 2 && Op0.match(E->getOperand(0)) &&
Op1.match(E->getOperand(1));
}
};

template <typename SCEVTy, typename Op0_t, typename Op1_t>
inline SCEVBinaryExpr_match<SCEVTy, Op0_t, Op1_t>
m_scev_Binary(const Op0_t &Op0, const Op1_t &Op1) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How are you supposed to actually use this one? Don't you have to make SCEVTy the first template parameter to make this ergnomic?

return SCEVBinaryExpr_match<SCEVTy, Op0_t, Op1_t>(Op0, Op1);
}

template <typename Op0_t, typename Op1_t>
inline SCEVBinaryExpr_match<SCEVAddExpr, Op0_t, Op1_t>
m_scev_Add(const Op0_t &Op0, const Op1_t &Op1) {
return m_scev_Binary<SCEVAddExpr>(Op0, Op1);
}

} // namespace SCEVPatternMatch
} // namespace llvm

Expand Down
35 changes: 14 additions & 21 deletions llvm/lib/Analysis/ScalarEvolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12725,33 +12725,28 @@ bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
static bool isKnownPredicateExtendIdiom(ICmpInst::Predicate Pred,
const SCEV *LHS, const SCEV *RHS) {
// zext x u<= sext x, sext x s<= zext x
const SCEV *Op;
switch (Pred) {
case ICmpInst::ICMP_SGE:
std::swap(LHS, RHS);
[[fallthrough]];
case ICmpInst::ICMP_SLE: {
// If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(LHS);
const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(RHS);
if (SExt && ZExt && SExt->getOperand() == ZExt->getOperand())
return true;
break;
// If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
return match(LHS, m_scev_SExt(m_SCEV(Op))) &&
match(RHS, m_scev_ZExt(m_Specific(Op)));
}
case ICmpInst::ICMP_UGE:
std::swap(LHS, RHS);
[[fallthrough]];
case ICmpInst::ICMP_ULE: {
// If operand >=s 0 then ZExt == SExt. If operand <s 0 then ZExt <u SExt.
const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS);
const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(RHS);
if (SExt && ZExt && SExt->getOperand() == ZExt->getOperand())
return true;
break;
// If operand >=u 0 then ZExt == SExt. If operand <u 0 then ZExt <u SExt.
return match(LHS, m_scev_ZExt(m_SCEV(Op))) &&
match(RHS, m_scev_SExt(m_Specific(Op)));
}
default:
break;
return false;
};
return false;
llvm_unreachable("unhandled case");
}

bool
Expand Down Expand Up @@ -15417,14 +15412,12 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
// (X >=u C1).
auto MatchRangeCheckIdiom = [&SE, Predicate, LHS, RHS, &RewriteMap,
&ExprsToRewrite]() {
auto *AddExpr = dyn_cast<SCEVAddExpr>(LHS);
if (!AddExpr || AddExpr->getNumOperands() != 2)
return false;

auto *C1 = dyn_cast<SCEVConstant>(AddExpr->getOperand(0));
auto *LHSUnknown = dyn_cast<SCEVUnknown>(AddExpr->getOperand(1));
const SCEVConstant *C1;
const SCEVUnknown *LHSUnknown;
auto *C2 = dyn_cast<SCEVConstant>(RHS);
if (!C1 || !C2 || !LHSUnknown)
if (!match(LHS,
m_scev_Add(m_SCEVConstant(C1), m_SCEVUnknown(LHSUnknown))) ||
!C2)
return false;

auto ExactRegion =
Expand Down
Loading