Skip to content

Commit a0dc428

Browse files
committed
!fixup add unary matcher for ZExt and SExt
1 parent ecaa67e commit a0dc428

File tree

2 files changed

+41
-14
lines changed

2 files changed

+41
-14
lines changed

llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,18 @@ inline bind_ty<const SCEVUnknown> m_SCEVUnknown(const SCEVUnknown *&V) {
7979
return V;
8080
}
8181

82+
/// Match a specified const SCEV *.
83+
struct specificscev_ty {
84+
const SCEV *Expr;
85+
86+
specificscev_ty(const SCEV *Expr) : Expr(Expr) {}
87+
88+
template <typename ITy> bool match(ITy *S) { return S == Expr; }
89+
};
90+
91+
/// Match if we have a specific specified SCEV.
92+
inline specificscev_ty m_Specific(const SCEV *S) { return S; }
93+
8294
namespace detail {
8395

8496
template <typename TupleTy, typename Fn, std::size_t... Is>
@@ -133,6 +145,26 @@ m_scev_Add(const Op0_t &Op0, const Op1_t &Op1) {
133145
return BinarySCEV_match<Op0_t, Op1_t, SCEVAddExpr>(Op0, Op1);
134146
}
135147

148+
template <typename Op0_t, typename SCEVTy>
149+
using UnarySCEV_match = SCEV_match<std::tuple<Op0_t>, SCEVTy>;
150+
151+
template <typename Op0_t, typename Op1_t, typename SCEVTy>
152+
inline UnarySCEV_match<Op0_t, SCEVTy> m_scev_Unary(const Op0_t &Op0) {
153+
return UnarySCEV_match<Op0_t, SCEVTy>(Op0);
154+
}
155+
156+
template <typename Op0_t>
157+
inline UnarySCEV_match<Op0_t, SCEVSignExtendExpr>
158+
m_scev_SExt(const Op0_t &Op0) {
159+
return UnarySCEV_match<Op0_t, SCEVSignExtendExpr>(Op0);
160+
}
161+
162+
template <typename Op0_t>
163+
inline UnarySCEV_match<Op0_t, SCEVZeroExtendExpr>
164+
m_scev_ZExt(const Op0_t &Op0) {
165+
return UnarySCEV_match<Op0_t, SCEVZeroExtendExpr>(Op0);
166+
}
167+
136168
} // namespace SCEVPatternMatch
137169
} // namespace llvm
138170

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12724,33 +12724,28 @@ bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
1272412724
static bool isKnownPredicateExtendIdiom(ICmpInst::Predicate Pred,
1272512725
const SCEV *LHS, const SCEV *RHS) {
1272612726
// zext x u<= sext x, sext x s<= zext x
12727+
const SCEV *Op;
1272712728
switch (Pred) {
1272812729
case ICmpInst::ICMP_SGE:
1272912730
std::swap(LHS, RHS);
1273012731
[[fallthrough]];
1273112732
case ICmpInst::ICMP_SLE: {
12732-
// If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
12733-
const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(LHS);
12734-
const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(RHS);
12735-
if (SExt && ZExt && SExt->getOperand() == ZExt->getOperand())
12736-
return true;
12737-
break;
12733+
// If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
12734+
return match(LHS, m_scev_SExt(m_SCEV(Op))) &&
12735+
match(RHS, m_scev_ZExt(m_Specific(Op)));
1273812736
}
1273912737
case ICmpInst::ICMP_UGE:
1274012738
std::swap(LHS, RHS);
1274112739
[[fallthrough]];
1274212740
case ICmpInst::ICMP_ULE: {
12743-
// If operand >=s 0 then ZExt == SExt. If operand <s 0 then ZExt <u SExt.
12744-
const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS);
12745-
const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(RHS);
12746-
if (SExt && ZExt && SExt->getOperand() == ZExt->getOperand())
12747-
return true;
12748-
break;
12741+
// If operand >=u 0 then ZExt == SExt. If operand <u 0 then ZExt <u SExt.
12742+
return match(LHS, m_scev_ZExt(m_SCEV(Op))) &&
12743+
match(RHS, m_scev_SExt(m_Specific(Op)));
1274912744
}
1275012745
default:
12751-
break;
12746+
return false;
1275212747
};
12753-
return false;
12748+
llvm_unreachable("unhandled case");
1275412749
}
1275512750

1275612751
bool

0 commit comments

Comments
 (0)