Skip to content

Commit 22352cc

Browse files
committed
[CodeGenPrepare] Create USubWithOverflow_match (NFC)
To make it consistent with m_UAddWithOverflow_match.
1 parent 2bc46ae commit 22352cc

File tree

3 files changed

+137
-15
lines changed

3 files changed

+137
-15
lines changed

llvm/include/llvm/IR/PatternMatch.h

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2685,6 +2685,81 @@ m_UAddWithOverflow(const LHS_t &L, const RHS_t &R, const Sum_t &S) {
26852685
return UAddWithOverflow_match<LHS_t, RHS_t, Sum_t>(L, R, S);
26862686
}
26872687

2688+
template <typename LHS_t, typename RHS_t, typename Diff_t>
2689+
struct USubWithOverflow_match {
2690+
LHS_t L;
2691+
RHS_t R;
2692+
Diff_t S;
2693+
2694+
USubWithOverflow_match(const LHS_t &L, const RHS_t &R, const Diff_t &S)
2695+
: L(L), R(R), S(S) {}
2696+
2697+
template <typename OpTy> bool match(OpTy *V) const {
2698+
Value *ICmpLHS = nullptr, *ICmpRHS = nullptr;
2699+
CmpPredicate Pred;
2700+
if (!m_ICmp(Pred, m_Value(ICmpLHS), m_Value(ICmpRHS)).match(V))
2701+
return false;
2702+
2703+
Value *SubLHS = nullptr, *SubRHS = nullptr;
2704+
auto SubExpr = m_Sub(m_Value(SubLHS), m_Value(SubRHS));
2705+
2706+
Value *AddLHS = nullptr, *AddRHS = nullptr;
2707+
auto AddExpr = m_Add(m_Value(AddLHS), m_Value(AddRHS));
2708+
2709+
// (a - b) >u a OR (a + (-c)) >u a (allow add-canonicalized forms
2710+
// but only where the RHS is a constant APInt that is negative)
2711+
if (Pred == ICmpInst::ICMP_UGT) {
2712+
if (SubExpr.match(ICmpLHS) && ICmpRHS == SubLHS)
2713+
return L.match(SubLHS) && R.match(SubRHS) && S.match(ICmpLHS);
2714+
2715+
if (AddExpr.match(ICmpLHS)) {
2716+
const APInt *AddC = nullptr;
2717+
if (m_APInt(AddC).match(AddRHS) && ICmpRHS == AddLHS) {
2718+
APInt NegC = -(*AddC);
2719+
Constant *NegConst = ConstantInt::get(AddRHS->getType(), NegC);
2720+
return L.match(AddLHS) && R.match(NegConst) && S.match(ICmpLHS);
2721+
}
2722+
}
2723+
}
2724+
2725+
// a <u (a - b) OR a <u (a + (-c))
2726+
if (Pred == ICmpInst::ICMP_ULT) {
2727+
if (SubExpr.match(ICmpRHS) && ICmpLHS == SubLHS)
2728+
return L.match(SubLHS) && R.match(SubRHS) && S.match(ICmpRHS);
2729+
2730+
if (AddExpr.match(ICmpRHS)) {
2731+
const APInt *AddC = nullptr;
2732+
if (m_APInt(AddC).match(AddRHS) && ICmpLHS == AddLHS) {
2733+
APInt NegC = -(*AddC);
2734+
Constant *NegConst = ConstantInt::get(AddRHS->getType(), NegC);
2735+
return L.match(AddLHS) && R.match(NegConst) && S.match(ICmpRHS);
2736+
}
2737+
}
2738+
}
2739+
2740+
// Special-case for 0 - a != 0 (common canonicalization)
2741+
if (Pred == ICmpInst::ICMP_NE) {
2742+
// (0 - a) != 0
2743+
if (SubExpr.match(ICmpLHS) && m_Zero().match(ICmpRHS) &&
2744+
m_Zero().match(SubLHS))
2745+
return L.match(SubLHS) && R.match(SubRHS) && S.match(ICmpLHS);
2746+
2747+
// 0 != (0 - a)
2748+
if (m_Zero().match(ICmpLHS) && SubExpr.match(ICmpRHS) &&
2749+
m_Zero().match(SubLHS))
2750+
return L.match(SubLHS) && R.match(SubRHS) && S.match(ICmpRHS);
2751+
}
2752+
2753+
return false;
2754+
}
2755+
};
2756+
2757+
template <typename LHS_t, typename RHS_t, typename Diff_t>
2758+
USubWithOverflow_match<LHS_t, RHS_t, Diff_t>
2759+
m_USubWithOverflow(const LHS_t &L, const RHS_t &R, const Diff_t &S) {
2760+
return USubWithOverflow_match<LHS_t, RHS_t, Diff_t>(L, R, S);
2761+
}
2762+
26882763
template <typename Opnd_t> struct Argument_match {
26892764
unsigned OpI;
26902765
Opnd_t Val;

llvm/lib/CodeGen/CodeGenPrepare.cpp

Lines changed: 45 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1695,19 +1695,23 @@ bool CodeGenPrepare::combineToUAddWithOverflow(CmpInst *Cmp,
16951695
return true;
16961696
}
16971697

1698-
bool CodeGenPrepare::combineToUSubWithOverflow(CmpInst *Cmp,
1699-
ModifyDT &ModifiedDT) {
1700-
// We are not expecting non-canonical/degenerate code. Just bail out.
1698+
static bool matchUSubWithOverflowConstantEdgeCases(CmpInst *Cmp,
1699+
BinaryOperator *&Sub) {
1700+
// A - B, A u> B --> usubo(A, B)
17011701
Value *A = Cmp->getOperand(0), *B = Cmp->getOperand(1);
1702+
1703+
// We are not expecting non-canonical/degenerate code. Just bail out.
17021704
if (isa<Constant>(A) && isa<Constant>(B))
17031705
return false;
17041706

1705-
// Convert (A u> B) to (A u< B) to simplify pattern matching.
17061707
ICmpInst::Predicate Pred = Cmp->getPredicate();
1708+
1709+
// Normalize: convert (A u> B) -> (B u< A)
17071710
if (Pred == ICmpInst::ICMP_UGT) {
17081711
std::swap(A, B);
17091712
Pred = ICmpInst::ICMP_ULT;
17101713
}
1714+
17111715
// Convert special-case: (A == 0) is the same as (A u< 1).
17121716
if (Pred == ICmpInst::ICMP_EQ && match(B, m_ZeroInt())) {
17131717
B = ConstantInt::get(B->getType(), 1);
@@ -1718,39 +1722,65 @@ bool CodeGenPrepare::combineToUSubWithOverflow(CmpInst *Cmp,
17181722
std::swap(A, B);
17191723
Pred = ICmpInst::ICMP_ULT;
17201724
}
1725+
17211726
if (Pred != ICmpInst::ICMP_ULT)
17221727
return false;
17231728

1724-
// Walk the users of a variable operand of a compare looking for a subtract or
1725-
// add with that same operand. Also match the 2nd operand of the compare to
1726-
// the add/sub, but that may be a negated constant operand of an add.
1729+
// Walk the users of the variable operand of the compare looking for a
1730+
// subtract or add with that same operand. Also match the 2nd operand of the
1731+
// compare to the add/sub, but that may be a negated constant operand of an
1732+
// add.
17271733
Value *CmpVariableOperand = isa<Constant>(A) ? B : A;
1728-
BinaryOperator *Sub = nullptr;
1734+
Sub = nullptr;
1735+
17291736
for (User *U : CmpVariableOperand->users()) {
17301737
// A - B, A u< B --> usubo(A, B)
17311738
if (match(U, m_Sub(m_Specific(A), m_Specific(B)))) {
17321739
Sub = cast<BinaryOperator>(U);
1733-
break;
1740+
return true;
17341741
}
17351742

17361743
// A + (-C), A u< C (canonicalized form of (sub A, C))
17371744
const APInt *CmpC, *AddC;
17381745
if (match(U, m_Add(m_Specific(A), m_APInt(AddC))) &&
17391746
match(B, m_APInt(CmpC)) && *AddC == -(*CmpC)) {
17401747
Sub = cast<BinaryOperator>(U);
1741-
break;
1748+
return true;
17421749
}
17431750
}
1744-
if (!Sub)
1745-
return false;
17461751

1752+
return false;
1753+
}
1754+
1755+
bool CodeGenPrepare::combineToUSubWithOverflow(CmpInst *Cmp,
1756+
ModifyDT &ModifiedDT) {
1757+
bool EdgeCase = false;
1758+
Value *A = nullptr, *B = nullptr;
1759+
BinaryOperator *Sub = nullptr;
1760+
1761+
// If the compare already matches the (sub, icmp) pattern use it directly.
1762+
if (!match(Cmp, m_USubWithOverflow(m_Value(A), m_Value(B), m_BinOp(Sub)))) {
1763+
// Otherwise try to recognize constant-edge-case forms like
1764+
// icmp ne (sub 0, B), 0 or
1765+
// icmp eq (sub A, 1), 0
1766+
if (!matchUSubWithOverflowConstantEdgeCases(Cmp, Sub))
1767+
return false;
1768+
// Set A/B from the discovered Sub and record that this was an edge-case
1769+
// match.
1770+
A = Sub->getOperand(0);
1771+
B = Sub->getOperand(1);
1772+
EdgeCase = true;
1773+
}
1774+
1775+
// Check target wants the overflow intrinsic formed. When matching an
1776+
// edge-case we allow forming the intrinsic with fewer uses.
17471777
if (!TLI->shouldFormOverflowOp(ISD::USUBO,
17481778
TLI->getValueType(*DL, Sub->getType()),
1749-
Sub->hasNUsesOrMore(1)))
1779+
Sub->hasNUsesOrMore(EdgeCase ? 1 : 2)))
17501780
return false;
17511781

1752-
if (!replaceMathCmpWithIntrinsic(Sub, Sub->getOperand(0), Sub->getOperand(1),
1753-
Cmp, Intrinsic::usub_with_overflow))
1782+
if (!replaceMathCmpWithIntrinsic(Sub, A, B, Cmp,
1783+
Intrinsic::usub_with_overflow))
17541784
return false;
17551785

17561786
// Reset callers - do not crash by iterating over a dead instruction.

llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7829,6 +7829,23 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
78297829
}
78307830
}
78317831

7832+
Instruction *SubI = nullptr;
7833+
if (match(&I, m_USubWithOverflow(m_Value(X), m_Value(Y),
7834+
m_Instruction(SubI))) &&
7835+
isa<IntegerType>(X->getType())) {
7836+
Value *Result;
7837+
Constant *Overflow;
7838+
// m_UAddWithOverflow can match patterns that do not include an explicit
7839+
// "add" instruction, so check the opcode of the matched op.
7840+
if (SubI->getOpcode() == Instruction::Sub &&
7841+
OptimizeOverflowCheck(Instruction::Sub, /*Signed*/ false, X, Y, *SubI,
7842+
Result, Overflow)) {
7843+
replaceInstUsesWith(*SubI, Result);
7844+
eraseInstFromFunction(*SubI);
7845+
return replaceInstUsesWith(I, Overflow);
7846+
}
7847+
}
7848+
78327849
// (zext X) * (zext Y) --> llvm.umul.with.overflow.
78337850
if (match(Op0, m_NUWMul(m_ZExt(m_Value(X)), m_ZExt(m_Value(Y)))) &&
78347851
match(Op1, m_APInt(C))) {

0 commit comments

Comments
 (0)