Skip to content
85 changes: 63 additions & 22 deletions llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3602,6 +3602,11 @@ struct DecomposedBitMaskMul {
APInt Mask;
bool NUW;
bool NSW;

bool isCombineableWith(const DecomposedBitMaskMul Other) {
return X == Other.X && !Mask.intersects(Other.Mask) &&
Factor == Other.Factor;
}
};

static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) {
Expand Down Expand Up @@ -3659,6 +3664,59 @@ static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) {
return std::nullopt;
}

/// (A & N) * C + (A & M) * C -> (A & (N + M)) & C
/// This also accepts the equivalent select form of (A & N) * C
/// expressions i.e. !(A & N) ? 0 : N * C)
static Value *foldBitmaskMul(Value *Op0, Value *Op1,
InstCombiner::BuilderTy &Builder) {
auto Decomp1 = matchBitmaskMul(Op1);
if (!Decomp1)
return nullptr;

auto Decomp0 = matchBitmaskMul(Op0);
if (!Decomp0)
return nullptr;

if (Decomp0->isCombineableWith(*Decomp1)) {
Value *NewAnd = Builder.CreateAnd(
Decomp0->X,
ConstantInt::get(Decomp0->X->getType(), Decomp0->Mask + Decomp1->Mask));

return Builder.CreateMul(
NewAnd, ConstantInt::get(NewAnd->getType(), Decomp1->Factor), "",
Decomp0->NUW && Decomp1->NUW, Decomp0->NSW && Decomp1->NSW);
}

return nullptr;
}

Value *InstCombinerImpl::foldDisjointOr(Value *LHS, Value *RHS) {
if (Value *Res = foldBitmaskMul(LHS, RHS, Builder))
return Res;

return nullptr;
}

Value *InstCombinerImpl::reassociateDisjointOr(Value *LHS, Value *RHS) {

Value *X, *Y;
if (match(RHS, m_OneUse(m_DisjointOr(m_Value(X), m_Value(Y))))) {
if (Value *Res = foldDisjointOr(LHS, X))
return Builder.CreateOr(Res, Y, "", /*IsDisjoint=*/true);
if (Value *Res = foldDisjointOr(LHS, Y))
return Builder.CreateOr(Res, X, "", /*IsDisjoint=*/true);
}

if (match(LHS, m_OneUse(m_DisjointOr(m_Value(X), m_Value(Y))))) {
if (Value *Res = foldDisjointOr(X, RHS))
return Builder.CreateOr(Res, Y, "", /*IsDisjoint=*/true);
if (Value *Res = foldDisjointOr(Y, RHS))
return Builder.CreateOr(Res, X, "", /*IsDisjoint=*/true);
}

return nullptr;
}

// FIXME: We use commutative matchers (m_c_*) for some, but not all, matches
// here. We should standardize that construct where it is needed or choose some
// other way to ensure that commutated variants of patterns are not missed.
Expand Down Expand Up @@ -3741,28 +3799,11 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
/*NSW=*/true, /*NUW=*/true))
return R;

// (A & N) * C + (A & M) * C -> (A & (N + M)) & C
// This also accepts the equivalent select form of (A & N) * C
// expressions i.e. !(A & N) ? 0 : N * C)
auto Decomp1 = matchBitmaskMul(I.getOperand(1));
if (Decomp1) {
auto Decomp0 = matchBitmaskMul(I.getOperand(0));
if (Decomp0 && Decomp0->X == Decomp1->X &&
(Decomp0->Mask & Decomp1->Mask).isZero() &&
Decomp0->Factor == Decomp1->Factor) {

Value *NewAnd = Builder.CreateAnd(
Decomp0->X, ConstantInt::get(Decomp0->X->getType(),
(Decomp0->Mask + Decomp1->Mask)));

auto *Combined = BinaryOperator::CreateMul(
NewAnd, ConstantInt::get(NewAnd->getType(), Decomp1->Factor));

Combined->setHasNoUnsignedWrap(Decomp0->NUW && Decomp1->NUW);
Combined->setHasNoSignedWrap(Decomp0->NSW && Decomp1->NSW);
return Combined;
}
}
if (Value *Res = foldBitmaskMul(I.getOperand(0), I.getOperand(1), Builder))
return replaceInstUsesWith(I, Res);

if (Value *Res = reassociateDisjointOr(I.getOperand(0), I.getOperand(1)))
return replaceInstUsesWith(I, Res);
}

Value *X, *Y;
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,10 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
Value *reassociateBooleanAndOr(Value *LHS, Value *X, Value *Y, Instruction &I,
bool IsAnd, bool RHSIsLogical);

Value *foldDisjointOr(Value *LHS, Value *RHS);

Value *reassociateDisjointOr(Value *LHS, Value *RHS);

Instruction *
canonicalizeConditionalNegationViaMathToSelect(BinaryOperator &i);

Expand Down
Loading