Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions llvm/include/llvm/IR/PatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -1430,6 +1430,39 @@ m_NUWAddLike(const LHS &L, const RHS &R) {
return m_CombineOr(m_NUWAdd(L, R), m_DisjointOr(L, R));
}

template <typename LHS, typename RHS, bool Commutable = false>
struct XorLike_match {
LHS L;
RHS R;

XorLike_match(const LHS &L, const RHS &R) : L(L), R(R) {}

template <typename OpTy> bool match(OpTy *V) {
if (auto *Op = dyn_cast<BinaryOperator>(V)) {
bool CheckCommuted = Commutable;
if (Op->getOpcode() == Instruction::Sub && Op->hasNoUnsignedWrap() &&
PatternMatch::match(Op->getOperand(0), m_LowBitMask()))
CheckCommuted = false;
else if (Op->getOpcode() != Instruction::Xor)
return false;
return (L.match(Op->getOperand(0)) && R.match(Op->getOperand(1))) ||
(CheckCommuted && L.match(Op->getOperand(1)) &&
R.match(Op->getOperand(0)));
}
return false;
}
};

template <typename LHS, typename RHS>
inline auto m_XorLike(const LHS &L, const RHS &R) {
return XorLike_match<LHS, RHS>(L, R);
}

template <typename LHS, typename RHS>
inline auto m_c_XorLike(const LHS &L, const RHS &R) {
return XorLike_match<LHS, RHS, /*Commutable=*/true>(L, R);
}

//===----------------------------------------------------------------------===//
// Class that matches a group of binary opcodes.
//
Expand Down
30 changes: 15 additions & 15 deletions llvm/lib/Analysis/InstructionSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,7 @@ static Value *simplifyAddInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW,
// The no-wrapping add guarantees that the top bit will be set by the add.
// Therefore, the xor must be clearing the already set sign bit of Y.
if ((IsNSW || IsNUW) && match(Op1, m_SignMask()) &&
match(Op0, m_Xor(m_Value(Y), m_SignMask())))
match(Op0, m_XorLike(m_Value(Y), m_SignMask())))
return Y;

// add nuw %x, -1 -> -1, because %x can only be 0.
Expand Down Expand Up @@ -2154,17 +2154,17 @@ static Value *simplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
// ((X | Y) ^ X ) & ((X | Y) ^ Y) --> 0
// ((X | Y) ^ Y ) & ((X | Y) ^ X) --> 0
BinaryOperator *Or;
if (match(Op0, m_c_Xor(m_Value(X),
if (match(Op0, m_c_XorLike(m_Value(X),
m_CombineAnd(m_BinOp(Or),
m_c_Or(m_Deferred(X), m_Value(Y))))) &&
match(Op1, m_c_Xor(m_Specific(Or), m_Specific(Y))))
match(Op1, m_c_XorLike(m_Specific(Or), m_Specific(Y))))
return Constant::getNullValue(Op0->getType());

const APInt *C1;
Value *A;
// (A ^ C) & (A ^ ~C) -> 0
if (match(Op0, m_Xor(m_Value(A), m_APInt(C1))) &&
match(Op1, m_Xor(m_Specific(A), m_SpecificInt(~*C1))))
if (match(Op0, m_XorLike(m_Value(A), m_APInt(C1))) &&
match(Op1, m_XorLike(m_Specific(A), m_SpecificInt(~*C1))))
return Constant::getNullValue(Op0->getType());

if (Op0->getType()->isIntOrIntVectorTy(1)) {
Expand Down Expand Up @@ -2217,13 +2217,13 @@ static Value *simplifyOrLogic(Value *X, Value *Y) {

// (A ^ B) | (A | B) --> A | B
// (A ^ B) | (B | A) --> B | A
if (match(X, m_Xor(m_Value(A), m_Value(B))) &&
if (match(X, m_XorLike(m_Value(A), m_Value(B))) &&
match(Y, m_c_Or(m_Specific(A), m_Specific(B))))
return Y;

// ~(A ^ B) | (A | B) --> -1
// ~(A ^ B) | (B | A) --> -1
if (match(X, m_Not(m_Xor(m_Value(A), m_Value(B)))) &&
if (match(X, m_Not(m_XorLike(m_Value(A), m_Value(B)))) &&
match(Y, m_c_Or(m_Specific(A), m_Specific(B))))
return ConstantInt::getAllOnesValue(Ty);

Expand All @@ -2232,14 +2232,14 @@ static Value *simplifyOrLogic(Value *X, Value *Y) {
// (A & ~B) | (B ^ A) --> B ^ A
// (~B & A) | (B ^ A) --> B ^ A
if (match(X, m_c_And(m_Value(A), m_Not(m_Value(B)))) &&
match(Y, m_c_Xor(m_Specific(A), m_Specific(B))))
match(Y, m_c_XorLike(m_Specific(A), m_Specific(B))))
return Y;

// (~A ^ B) | (A & B) --> ~A ^ B
// (B ^ ~A) | (A & B) --> B ^ ~A
// (~A ^ B) | (B & A) --> ~A ^ B
// (B ^ ~A) | (B & A) --> B ^ ~A
if (match(X, m_c_Xor(m_Not(m_Value(A)), m_Value(B))) &&
if (match(X, m_c_XorLike(m_Not(m_Value(A)), m_Value(B))) &&
match(Y, m_c_And(m_Specific(A), m_Specific(B))))
return X;

Expand All @@ -2248,7 +2248,7 @@ static Value *simplifyOrLogic(Value *X, Value *Y) {
// (B | ~A) | (A ^ B) --> -1
// (B | ~A) | (B ^ A) --> -1
if (match(X, m_c_Or(m_Not(m_Value(A)), m_Value(B))) &&
match(Y, m_c_Xor(m_Specific(A), m_Specific(B))))
match(Y, m_c_XorLike(m_Specific(A), m_Specific(B))))
return ConstantInt::getAllOnesValue(Ty);

// (~A & B) | ~(A | B) --> ~A
Expand All @@ -2271,7 +2271,7 @@ static Value *simplifyOrLogic(Value *X, Value *Y) {
// ~(A ^ B) | (A & B) --> ~(A ^ B)
// ~(A ^ B) | (B & A) --> ~(A ^ B)
Value *NotAB;
if (match(X, m_CombineAnd(m_Not(m_Xor(m_Value(A), m_Value(B))),
if (match(X, m_CombineAnd(m_Not(m_XorLike(m_Value(A), m_Value(B))),
m_Value(NotAB))) &&
match(Y, m_c_And(m_Specific(A), m_Specific(B))))
return NotAB;
Expand All @@ -2280,7 +2280,7 @@ static Value *simplifyOrLogic(Value *X, Value *Y) {
// ~(A & B) | (B ^ A) --> ~(A & B)
if (match(X, m_CombineAnd(m_Not(m_And(m_Value(A), m_Value(B))),
m_Value(NotAB))) &&
match(Y, m_c_Xor(m_Specific(A), m_Specific(B))))
match(Y, m_c_XorLike(m_Specific(A), m_Specific(B))))
return NotAB;

return nullptr;
Expand Down Expand Up @@ -2435,8 +2435,8 @@ static Value *simplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
return V;

// (A ^ C) | (A ^ ~C) -> -1, i.e. all bits set to one.
if (match(Op0, m_Xor(m_Value(A), m_APInt(C1))) &&
match(Op1, m_Xor(m_Specific(A), m_SpecificInt(~*C1))))
if (match(Op0, m_XorLike(m_Value(A), m_APInt(C1))) &&
match(Op1, m_XorLike(m_Specific(A), m_SpecificInt(~*C1))))
return Constant::getAllOnesValue(Op0->getType());

if (Op0->getType()->isIntOrIntVectorTy(1)) {
Expand Down Expand Up @@ -5097,7 +5097,7 @@ static Value *simplifyGEPInst(Type *SrcTy, Value *Ptr,
}
// gep (gep V, C), (xor V, -1) -> C-1
if (match(Indices.back(),
m_Xor(m_PtrToInt(m_Specific(StrippedBasePtr)), m_AllOnes())) &&
m_XorLike(m_PtrToInt(m_Specific(StrippedBasePtr)), m_AllOnes())) &&
!BasePtrOffset.isOne()) {
auto *CI = ConstantInt::get(GEPTy->getContext(), BasePtrOffset - 1);
return ConstantExpr::getIntToPtr(CI, GEPTy);
Expand Down
6 changes: 3 additions & 3 deletions llvm/lib/Analysis/ValueTracking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ static bool haveNoCommonBitsSetSpecialCases(const Value *LHS, const Value *RHS,
// for constant Y.
Value *Y;
if (match(RHS,
m_c_Xor(m_c_And(m_Specific(LHS), m_Value(Y)), m_Deferred(Y))) &&
m_c_XorLike(m_c_And(m_Specific(LHS), m_Value(Y)), m_Deferred(Y))) &&
isGuaranteedNotToBeUndef(LHS, SQ.AC, SQ.CxtI, SQ.DT) &&
isGuaranteedNotToBeUndef(Y, SQ.AC, SQ.CxtI, SQ.DT))
return true;
Expand Down Expand Up @@ -690,7 +690,7 @@ static void computeKnownBitsFromCmp(const Value *V, CmpInst::Predicate Pred,
if (match(Y, m_APInt(Mask)))
Known.One |= *C & ~*Mask;
// assume(V ^ Mask = C)
} else if (match(LHS, m_Xor(m_V, m_APInt(Mask))) &&
} else if (match(LHS, m_XorLike(m_V, m_APInt(Mask))) &&
match(RHS, m_APInt(C))) {
// Equivalent to assume(V == Mask ^ C)
Known = Known.unionWith(KnownBits::makeConstant(*C ^ *Mask));
Expand Down Expand Up @@ -954,7 +954,7 @@ getKnownBitsFromAndXorOr(const Operator *I, const APInt &DemandedElts,
// Demanded) == (xor(x, x-1) & Demanded). Extend the xor pattern
// to use arbitrary C if xor(x, x-C) as the same as xor(x, x-1).
if (HasKnownOne &&
match(I, m_c_Xor(m_Value(X), m_Add(m_Deferred(X), m_AllOnes())))) {
match(I, m_c_XorLike(m_Value(X), m_Add(m_Deferred(X), m_AllOnes())))) {
const KnownBits &XBits = I->getOperand(0) == X ? KnownLHS : KnownRHS;
KnownOut = XBits.blsmsk();
}
Expand Down
34 changes: 12 additions & 22 deletions llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -767,10 +767,10 @@ static Value *checkForNegativeOperand(BinaryOperator &I,

if (match(LHS, m_Add(m_Value(X), m_One()))) {
// if XOR on other side, swap
if (match(RHS, m_Xor(m_Value(Y), m_APInt(C1))))
if (match(RHS, m_XorLike(m_Value(Y), m_APInt(C1))))
std::swap(X, RHS);

if (match(X, m_Xor(m_Value(Y), m_APInt(C1)))) {
if (match(X, m_XorLike(m_Value(Y), m_APInt(C1)))) {
// X = XOR(Y, C1), Y = OR(Z, C2), C2 = NOT(C1) ==> X == NOT(AND(Z, C1))
// ADD(ADD(X, 1), RHS) == ADD(X, ADD(RHS, 1)) == SUB(RHS, AND(Z, C1))
if (match(Y, m_Or(m_Value(Z), m_APInt(C2))) && (*C2 == ~(*C1))) {
Expand All @@ -790,13 +790,13 @@ static Value *checkForNegativeOperand(BinaryOperator &I,
RHS = I.getOperand(1);

// if XOR is on other side, swap
if (match(RHS, m_Xor(m_Value(Y), m_APInt(C1))))
if (match(RHS, m_XorLike(m_Value(Y), m_APInt(C1))))
std::swap(LHS, RHS);

// C2 is ODD
// LHS = XOR(Y, C1), Y = AND(Z, C2), C1 == (C2 + 1) => LHS == NEG(OR(Z, ~C2))
// ADD(LHS, RHS) == SUB(RHS, OR(Z, ~C2))
if (match(LHS, m_Xor(m_Value(Y), m_APInt(C1))))
if (match(LHS, m_XorLike(m_Value(Y), m_APInt(C1))))
if (C1->countr_zero() == 0)
if (match(Y, m_And(m_Value(Z), m_APInt(C2))) && *C1 == (*C2 + 1)) {
Value *NewOr = Builder.CreateOr(Z, ~(*C2));
Expand Down Expand Up @@ -937,11 +937,11 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) {

// Is this add the last step in a convoluted sext?
// add(zext(xor i16 X, -32768), -32768) --> sext X
if (match(Op0, m_ZExt(m_Xor(m_Value(X), m_APInt(C2)))) &&
if (match(Op0, m_ZExt(m_XorLike(m_Value(X), m_APInt(C2)))) &&
C2->isMinSignedValue() && C2->sext(Ty->getScalarSizeInBits()) == *C)
return CastInst::Create(Instruction::SExt, X, Ty);

if (match(Op0, m_Xor(m_Value(X), m_APInt(C2)))) {
if (match(Op0, m_XorLike(m_Value(X), m_APInt(C2)))) {
// (X ^ signmask) + C --> (X + (signmask ^ C))
if (C2->isSignMask())
return BinaryOperator::CreateAdd(X, ConstantInt::get(Ty, *C2 ^ *C));
Expand Down Expand Up @@ -1685,7 +1685,7 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {

// (add (xor A, B) (and A, B)) --> (or A, B)
// (add (and A, B) (xor A, B)) --> (or A, B)
if (match(&I, m_c_BinOp(m_Xor(m_Value(A), m_Value(B)),
if (match(&I, m_c_BinOp(m_XorLike(m_Value(A), m_Value(B)),
m_c_And(m_Deferred(A), m_Deferred(B)))))
return BinaryOperator::CreateOr(A, B);

Expand Down Expand Up @@ -1848,7 +1848,7 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
m_c_Add(
m_ZExt(m_ICmp(Pred, m_Intrinsic<Intrinsic::ctpop>(m_Value(A)),
m_One())),
m_OneUse(m_ZExtOrSelf(m_OneUse(m_Xor(
m_OneUse(m_ZExtOrSelf(m_OneUse(m_XorLike(
m_OneUse(m_TruncOrSelf(m_OneUse(
m_Intrinsic<Intrinsic::ctlz>(m_Deferred(A), m_One())))),
m_APInt(XorC))))))) &&
Expand Down Expand Up @@ -2424,16 +2424,6 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {

const APInt *Op0C;
if (match(Op0, m_APInt(Op0C))) {
if (Op0C->isMask()) {
// Turn this into a xor if LHS is 2^n-1 and the remaining bits are known
// zero. We don't use information from dominating conditions so this
// transform is easier to reverse if necessary.
KnownBits RHSKnown = llvm::computeKnownBits(
Op1, 0, SQ.getWithInstruction(&I).getWithoutDomCondCache());
if ((*Op0C | RHSKnown.Zero).isAllOnes())
return BinaryOperator::CreateXor(Op1, Op0);
}

// C - ((C3 -nuw X) & C2) --> (C - (C2 & C3)) + (X & C2) when:
// (C3 - ((C2 & C3) - 1)) is pow2
// ((C2 + C3) & ((C2 & C3) - 1)) == ((C2 & C3) - 1)
Expand Down Expand Up @@ -2502,15 +2492,15 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
// (sub (or A, B), (xor A, B)) --> (and A, B)
{
Value *A, *B;
if (match(Op1, m_Xor(m_Value(A), m_Value(B))) &&
if (match(Op1, m_XorLike(m_Value(A), m_Value(B))) &&
match(Op0, m_c_Or(m_Specific(A), m_Specific(B))))
return BinaryOperator::CreateAnd(A, B);
}

// (sub (xor A, B) (or A, B)) --> neg (and A, B)
{
Value *A, *B;
if (match(Op0, m_Xor(m_Value(A), m_Value(B))) &&
if (match(Op0, m_XorLike(m_Value(A), m_Value(B))) &&
match(Op1, m_c_Or(m_Specific(A), m_Specific(B))) &&
(Op0->hasOneUse() || Op1->hasOneUse()))
return BinaryOperator::CreateNeg(Builder.CreateAnd(A, B));
Expand Down Expand Up @@ -2548,7 +2538,7 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
// (sub (sext C), (xor X, (sext C))) => (select C, X, (neg X))
Value *C, *X;
auto m_SubXorCmp = [&C, &X](Value *LHS, Value *RHS) {
return match(LHS, m_OneUse(m_c_Xor(m_Value(X), m_Specific(RHS)))) &&
return match(LHS, m_OneUse(m_c_XorLike(m_Value(X), m_Specific(RHS)))) &&
match(RHS, m_SExt(m_Value(C))) &&
(C->getType()->getScalarSizeInBits() == 1);
};
Expand Down Expand Up @@ -2683,7 +2673,7 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
unsigned BitWidth = Ty->getScalarSizeInBits();
if (match(Op1, m_AShr(m_Value(A), m_APInt(ShAmt))) &&
Op1->hasNUses(2) && *ShAmt == BitWidth - 1 &&
match(Op0, m_OneUse(m_c_Xor(m_Specific(A), m_Specific(Op1))))) {
match(Op0, m_OneUse(m_c_XorLike(m_Specific(A), m_Specific(Op1))))) {
// B = ashr i32 A, 31 ; smear the sign bit
// sub (xor A, B), B ; flip bits if negative and subtract -1 (add 1)
// --> (A < 0) ? -A : A
Expand Down
Loading
Loading