Skip to content
89 changes: 68 additions & 21 deletions llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2377,6 +2377,7 @@ Value *InstCombinerImpl::reassociateBooleanAndOr(Value *LHS, Value *X, Value *Y,
if (Value *Res = foldBooleanAndOr(LHS, Y, I, IsAnd, /*IsLogical=*/false))
return RHSIsLogical ? Builder.CreateLogicalOp(Opcode, X, Res)
: Builder.CreateBinOp(Opcode, X, Res);

return nullptr;
}

Expand Down Expand Up @@ -3602,6 +3603,11 @@ struct DecomposedBitMaskMul {
APInt Mask;
bool NUW;
bool NSW;

bool isCombineableWith(DecomposedBitMaskMul Other) {
return X == Other.X && (Mask & Other.Mask).isZero() &&
Factor == Other.Factor;
}
};

static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) {
Expand Down Expand Up @@ -3659,6 +3665,64 @@ static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) {
return std::nullopt;
}

/// (A & N) * C + (A & M) * C -> (A & (N + M)) * C
/// This also accepts the equivalent select form of (A & N) * C
/// expressions, i.e. !(A & N) ? 0 : N * C.
///
/// \p Op0 and \p Op1 are the two terms to merge. The fold only fires when
/// both decompose through matchBitmaskMul and the two decompositions report
/// that they are combinable. Any new instructions are emitted via \p Builder.
///
/// \returns the merged multiply, or nullptr when the pattern does not match.
static Value *foldBitmaskMul(Value *Op0, Value *Op1,
                             InstCombiner::BuilderTy &Builder) {
  // Op1 is decomposed first; Op0 is only examined when Op1 matched.
  auto Decomp1 = matchBitmaskMul(Op1);
  if (!Decomp1)
    return nullptr;

  auto Decomp0 = matchBitmaskMul(Op0);
  if (!Decomp0 || !Decomp0->isCombineableWith(*Decomp1))
    return nullptr;

  // Mask the shared value with the sum of both masks.
  Value *NewAnd = Builder.CreateAnd(
      Decomp0->X,
      ConstantInt::get(Decomp0->X->getType(), Decomp0->Mask + Decomp1->Mask));

  // A no-wrap flag is only kept when both original terms carried it.
  return Builder.CreateMul(
      NewAnd, ConstantInt::get(NewAnd->getType(), Decomp1->Factor), "",
      /*HasNUW=*/Decomp0->NUW && Decomp1->NUW,
      /*HasNSW=*/Decomp0->NSW && Decomp1->NSW);
}

/// Try to fold the operand pair (\p LHS, \p RHS) into a single value.
/// The only fold attempted here is foldBitmaskMul.
/// \returns the folded value, or nullptr when nothing applies.
Value *InstCombinerImpl::foldDisjointOr(Value *LHS, Value *RHS) {
  Value *Folded = foldBitmaskMul(LHS, RHS, Builder);
  return Folded ? Folded : nullptr;
}

/// Look through a one-use disjoint or on either side and try to fold the
/// opposite operand with one of its two sub-operands. On success the unused
/// sub-operand is ORed back onto the folded value with the disjoint flag set.
/// \returns the rebuilt value, or nullptr when no combination folds.
Value *InstCombinerImpl::reassociateDisjointOr(Value *LHS, Value *RHS) {
  // Folds (A, B) in exactly that operand order, then reattaches Rest.
  auto FoldThenReattach = [this](Value *A, Value *B, Value *Rest) -> Value * {
    Value *Folded = foldDisjointOr(A, B);
    if (!Folded)
      return nullptr;
    return Builder.CreateOr(Folded, Rest, "", /*IsDisjoint=*/true);
  };

  Value *Inner0, *Inner1;

  // LHS | (Inner0 | Inner1): pair LHS with Inner0 first, then with Inner1.
  if (match(RHS, m_OneUse(m_DisjointOr(m_Value(Inner0), m_Value(Inner1))))) {
    if (Value *Rebuilt = FoldThenReattach(LHS, Inner0, Inner1))
      return Rebuilt;
    if (Value *Rebuilt = FoldThenReattach(LHS, Inner1, Inner0))
      return Rebuilt;
  }

  // (Inner0 | Inner1) | RHS: pair Inner0 with RHS first, then Inner1.
  if (match(LHS, m_OneUse(m_DisjointOr(m_Value(Inner0), m_Value(Inner1))))) {
    if (Value *Rebuilt = FoldThenReattach(Inner0, RHS, Inner1))
      return Rebuilt;
    if (Value *Rebuilt = FoldThenReattach(Inner1, RHS, Inner0))
      return Rebuilt;
  }

  return nullptr;
}

// FIXME: We use commutative matchers (m_c_*) for some, but not all, matches
// here. We should standardize that construct where it is needed or choose some
// other way to ensure that commutated variants of patterns are not missed.
Expand Down Expand Up @@ -3741,28 +3805,11 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
/*NSW=*/true, /*NUW=*/true))
return R;

// (A & N) * C + (A & M) * C -> (A & (N + M)) & C
// This also accepts the equivalent select form of (A & N) * C
// expressions i.e. !(A & N) ? 0 : N * C)
auto Decomp1 = matchBitmaskMul(I.getOperand(1));
if (Decomp1) {
auto Decomp0 = matchBitmaskMul(I.getOperand(0));
if (Decomp0 && Decomp0->X == Decomp1->X &&
(Decomp0->Mask & Decomp1->Mask).isZero() &&
Decomp0->Factor == Decomp1->Factor) {

Value *NewAnd = Builder.CreateAnd(
Decomp0->X, ConstantInt::get(Decomp0->X->getType(),
(Decomp0->Mask + Decomp1->Mask)));

auto *Combined = BinaryOperator::CreateMul(
NewAnd, ConstantInt::get(NewAnd->getType(), Decomp1->Factor));
if (Value *Res = foldBitmaskMul(I.getOperand(0), I.getOperand(1), Builder))
return replaceInstUsesWith(I, Res);

Combined->setHasNoUnsignedWrap(Decomp0->NUW && Decomp1->NUW);
Combined->setHasNoSignedWrap(Decomp0->NSW && Decomp1->NSW);
return Combined;
}
}
if (Value *Res = reassociateDisjointOr(I.getOperand(0), I.getOperand(1)))
return replaceInstUsesWith(I, Res);
}

Value *X, *Y;
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,10 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
Value *reassociateBooleanAndOr(Value *LHS, Value *X, Value *Y, Instruction &I,
bool IsAnd, bool RHSIsLogical);

/// Try to fold the operand pair (\p LHS, \p RHS) into a single value.
/// Returns the folded value, or nullptr when no fold applies.
Value *foldDisjointOr(Value *LHS, Value *RHS);

/// If \p LHS or \p RHS is a one-use disjoint or, try to fold the other
/// operand with one of its sub-operands via foldDisjointOr, then OR the
/// remaining sub-operand back in as a disjoint or. Returns nullptr when
/// nothing folds.
Value *reassociateDisjointOr(Value *LHS, Value *RHS);

Instruction *
canonicalizeConditionalNegationViaMathToSelect(BinaryOperator &i);

Expand Down
244 changes: 244 additions & 0 deletions llvm/test/Transforms/InstCombine/or-bitmask.ll
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,250 @@ define i32 @and_mul_non_disjoint(i32 %in) {
ret i32 %out
}

; Two masked multiplies of %in (masks 3 and 12, factor 72) separated by an
; unrelated value: the nested disjoint or (%in2 | %temp2) is the right-hand
; operand of the final or. The expected output merges the multiplies into
; (%in & 15) * 72 and ORs %in2 back in.
define i32 @unrelated_ops(i32 %in, i32 %in2) {
; CHECK-LABEL: @unrelated_ops(
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 15
; CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i32 [[TMP1]], 72
; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[TMP2]], [[IN2:%.*]]
; CHECK-NEXT: ret i32 [[OUT]]
;
%and0 = and i32 %in, 3
%temp = mul nuw nsw i32 %and0, 72
%and1 = and i32 %in, 12
%temp2 = mul nuw nsw i32 %and1, 72
%temp3 = or disjoint i32 %in2, %temp2
%out = or disjoint i32 %temp, %temp3
ret i32 %out
}

; Same pair of masked multiplies as @unrelated_ops, but here the nested
; disjoint or (%in2 | %temp) is the left-hand operand of the final or. The
; expected output is again (%in & 15) * 72 ORed with %in2.
define i32 @unrelated_ops1(i32 %in, i32 %in2) {
; CHECK-LABEL: @unrelated_ops1(
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 15
; CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i32 [[TMP1]], 72
; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[TMP2]], [[IN2:%.*]]
; CHECK-NEXT: ret i32 [[OUT]]
;
%and0 = and i32 %in, 3
%temp = mul nuw nsw i32 %and0, 72
%and1 = and i32 %in, 12
%temp2 = mul nuw nsw i32 %and1, 72
%temp3 = or disjoint i32 %in2, %temp
%out = or disjoint i32 %temp3, %temp2
ret i32 %out
}

; Both operands of the final or are themselves disjoint ors, each holding one
; masked multiply (%temp | %in3 and %in2 | %temp2). The expected output keeps
; the two multiplies (masks 3 and 12) separate.
define i32 @unrelated_ops2(i32 %in, i32 %in2, i32 %in3) {
; CHECK-LABEL: @unrelated_ops2(
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
; CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i32 [[TMP1]], 72
; CHECK-NEXT: [[TEMP3:%.*]] = or disjoint i32 [[TMP2]], [[IN4:%.*]]
; CHECK-NEXT: [[AND1:%.*]] = and i32 [[IN]], 12
; CHECK-NEXT: [[IN2:%.*]] = mul nuw nsw i32 [[AND1]], 72
; CHECK-NEXT: [[TMP3:%.*]] = or disjoint i32 [[IN3:%.*]], [[IN2]]
; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[TEMP3]], [[TMP3]]
; CHECK-NEXT: ret i32 [[OUT]]
;
%and0 = and i32 %in, 3
%temp = mul nuw nsw i32 %and0, 72
%temp3 = or disjoint i32 %temp, %in3
%and1 = and i32 %in, 12
%temp2 = mul nuw nsw i32 %and1, 72
%temp4 = or disjoint i32 %in2, %temp2
%out = or disjoint i32 %temp3, %temp4
ret i32 %out
}

; Select form of a masked multiply (bit 2 of %in selects 0 or 144) in the
; first nested or, and a multiply form (mask 12, factor 72) in the second.
; The expected output leaves both forms in place.
define i32 @unrelated_ops3(i32 %in, i32 %in2, i32 %in3) {
; CHECK-LABEL: @unrelated_ops3(
; CHECK-NEXT: [[AND0:%.*]] = and i32 [[IN:%.*]], 2
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[AND0]], 0
; CHECK-NEXT: [[TEMP:%.*]] = select i1 [[CMP]], i32 0, i32 144
; CHECK-NEXT: [[TEMP3:%.*]] = or disjoint i32 [[TEMP]], [[IN3:%.*]]
; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[IN]], 12
; CHECK-NEXT: [[TEMP2:%.*]] = mul nuw nsw i32 [[TMP2]], 72
; CHECK-NEXT: [[TEMP4:%.*]] = or disjoint i32 [[IN2:%.*]], [[TEMP2]]
; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[TEMP3]], [[TEMP4]]
; CHECK-NEXT: ret i32 [[OUT]]
;
%and0 = and i32 %in, 2
%cmp = icmp eq i32 %and0, 0
%temp = select i1 %cmp, i32 0, i32 144
%temp3 = or disjoint i32 %temp, %in3
%and1 = and i32 %in, 12
%temp2 = mul nuw nsw i32 %and1, 72
%temp4 = or disjoint i32 %in2, %temp2
%out = or disjoint i32 %temp3, %temp4
ret i32 %out
}

; Mirror of @unrelated_ops3: the multiply form (mask 12, factor 72) sits in
; the first nested or and the select form (bit 2 selects 0 or 144) in the
; second. The expected output leaves both forms in place.
define i32 @unrelated_ops4(i32 %in, i32 %in2, i32 %in3) {
; CHECK-LABEL: @unrelated_ops4(
; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[IN:%.*]], 12
; CHECK-NEXT: [[TMP3:%.*]] = mul nuw nsw i32 [[TMP2]], 72
; CHECK-NEXT: [[TEMP3:%.*]] = or disjoint i32 [[IN4:%.*]], [[TMP3]]
; CHECK-NEXT: [[AND1:%.*]] = and i32 [[IN]], 2
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[AND1]], 0
; CHECK-NEXT: [[IN3:%.*]] = select i1 [[CMP]], i32 0, i32 144
; CHECK-NEXT: [[TMP4:%.*]] = or disjoint i32 [[IN3]], [[IN2:%.*]]
; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[TEMP3]], [[TMP4]]
; CHECK-NEXT: ret i32 [[OUT]]
;
%and0 = and i32 %in, 12
%temp = mul nuw nsw i32 %and0, 72
%temp3 = or disjoint i32 %in2, %temp
%and1 = and i32 %in, 2
%cmp = icmp eq i32 %and1, 0
%temp2 = select i1 %cmp, i32 0, i32 144
%temp4 = or disjoint i32 %temp2, %in3
%out = or disjoint i32 %temp3, %temp4
ret i32 %out
}

; Both nested ors hold a select-form masked multiply (bit 2 selects 0 or 144,
; bit 4 selects 0 or 288). The expected output keeps the two selects separate.
define i32 @unrelated_ops5(i32 %in, i32 %in2, i32 %in3) {
; CHECK-LABEL: @unrelated_ops5(
; CHECK-NEXT: [[AND0:%.*]] = and i32 [[IN:%.*]], 2
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[AND0]], 0
; CHECK-NEXT: [[IN3:%.*]] = select i1 [[CMP]], i32 0, i32 144
; CHECK-NEXT: [[TMP4:%.*]] = or disjoint i32 [[IN3]], [[IN2:%.*]]
; CHECK-NEXT: [[AND1:%.*]] = and i32 [[IN]], 4
; CHECK-NEXT: [[CMP2:%.*]] = icmp eq i32 [[AND1]], 0
; CHECK-NEXT: [[TEMP2:%.*]] = select i1 [[CMP2]], i32 0, i32 288
; CHECK-NEXT: [[TMP3:%.*]] = or disjoint i32 [[IN4:%.*]], [[TEMP2]]
; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[TMP4]], [[TMP3]]
; CHECK-NEXT: ret i32 [[OUT]]
;
%and0 = and i32 %in, 2
%cmp = icmp eq i32 %and0, 0
%temp = select i1 %cmp, i32 0, i32 144
%temp3 = or disjoint i32 %temp, %in3
%and1 = and i32 %in, 4
%cmp2 = icmp eq i32 %and1, 0
%temp2 = select i1 %cmp2, i32 0, i32 288
%temp4 = or disjoint i32 %in2, %temp2
%out = or disjoint i32 %temp3, %temp4
ret i32 %out
}

; Operand-order variant of @unrelated_ops2: the masked multiply is the second
; operand of both nested ors (%in3 | %temp and %in2 | %temp2). The expected
; output keeps the two multiplies separate.
define i32 @unrelated_ops6(i32 %in, i32 %in2, i32 %in3) {
; CHECK-LABEL: @unrelated_ops6(
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
; CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i32 [[TMP1]], 72
; CHECK-NEXT: [[TEMP3:%.*]] = or disjoint i32 [[IN4:%.*]], [[TMP2]]
; CHECK-NEXT: [[AND1:%.*]] = and i32 [[IN]], 12
; CHECK-NEXT: [[IN2:%.*]] = mul nuw nsw i32 [[AND1]], 72
; CHECK-NEXT: [[TMP3:%.*]] = or disjoint i32 [[IN3:%.*]], [[IN2]]
; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[TEMP3]], [[TMP3]]
; CHECK-NEXT: ret i32 [[OUT]]
;
%and0 = and i32 %in, 3
%temp = mul nuw nsw i32 %and0, 72
%temp3 = or disjoint i32 %in3, %temp
%and1 = and i32 %in, 12
%temp2 = mul nuw nsw i32 %and1, 72
%temp4 = or disjoint i32 %in2, %temp2
%out = or disjoint i32 %temp3, %temp4
ret i32 %out
}

; Operand-order variant of @unrelated_ops2: the masked multiply is the second
; operand of the first nested or (%in3 | %temp) and the first operand of the
; second (%temp2 | %in2). The expected output keeps the two multiplies
; separate.
define i32 @unrelated_ops7(i32 %in, i32 %in2, i32 %in3) {
; CHECK-LABEL: @unrelated_ops7(
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
; CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i32 [[TMP1]], 72
; CHECK-NEXT: [[TEMP3:%.*]] = or disjoint i32 [[IN4:%.*]], [[TMP2]]
; CHECK-NEXT: [[AND1:%.*]] = and i32 [[IN]], 12
; CHECK-NEXT: [[IN3:%.*]] = mul nuw nsw i32 [[AND1]], 72
; CHECK-NEXT: [[TMP3:%.*]] = or disjoint i32 [[IN3]], [[IN2:%.*]]
; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[TEMP3]], [[TMP3]]
; CHECK-NEXT: ret i32 [[OUT]]
;
%and0 = and i32 %in, 3
%temp = mul nuw nsw i32 %and0, 72
%temp3 = or disjoint i32 %in3, %temp
%and1 = and i32 %in, 12
%temp2 = mul nuw nsw i32 %and1, 72
%temp4 = or disjoint i32 %temp2, %in2
%out = or disjoint i32 %temp3, %temp4
ret i32 %out
}

; Operand-order variant of @unrelated_ops2: the masked multiply is the first
; operand of both nested ors (%temp | %in3 and %temp2 | %in2). The expected
; output keeps the two multiplies separate.
define i32 @unrelated_ops8(i32 %in, i32 %in2, i32 %in3) {
; CHECK-LABEL: @unrelated_ops8(
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
; CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i32 [[TMP1]], 72
; CHECK-NEXT: [[TEMP3:%.*]] = or disjoint i32 [[TMP2]], [[IN4:%.*]]
; CHECK-NEXT: [[AND1:%.*]] = and i32 [[IN]], 12
; CHECK-NEXT: [[IN3:%.*]] = mul nuw nsw i32 [[AND1]], 72
; CHECK-NEXT: [[TMP3:%.*]] = or disjoint i32 [[IN3]], [[IN2:%.*]]
; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[TEMP3]], [[TMP3]]
; CHECK-NEXT: ret i32 [[OUT]]
;
%and0 = and i32 %in, 3
%temp = mul nuw nsw i32 %and0, 72
%temp3 = or disjoint i32 %temp, %in3
%and1 = and i32 %in, 12
%temp2 = mul nuw nsw i32 %and1, 72
%temp4 = or disjoint i32 %temp2, %in2
%out = or disjoint i32 %temp3, %temp4
ret i32 %out
}

; Negative test: the two masks (3 and 7) share set bits while the factor (72)
; is the same. The expected output keeps the two multiplies separate.
define i32 @unrelated_ops_nocombine(i32 %in, i32 %in2, i32 %in3) {
; CHECK-LABEL: @unrelated_ops_nocombine(
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
; CHECK-NEXT: [[TEMP:%.*]] = mul nuw nsw i32 [[TMP1]], 72
; CHECK-NEXT: [[TMP4:%.*]] = or disjoint i32 [[TEMP]], [[IN3:%.*]]
; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[IN]], 7
; CHECK-NEXT: [[TEMP2:%.*]] = mul nuw nsw i32 [[TMP2]], 72
; CHECK-NEXT: [[TMP3:%.*]] = or disjoint i32 [[IN2:%.*]], [[TEMP2]]
; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[TMP4]], [[TMP3]]
; CHECK-NEXT: ret i32 [[OUT]]
;
%and0 = and i32 %in, 3
%temp = mul nuw nsw i32 %and0, 72
%temp3 = or disjoint i32 %temp, %in3
%and1 = and i32 %in, 7
%temp2 = mul nuw nsw i32 %and1, 72
%temp4 = or disjoint i32 %in2, %temp2
%out = or disjoint i32 %temp3, %temp4
ret i32 %out
}

; Negative test: the masks (3 and 12) share no bits, but the multiply factors
; differ (72 vs 36). The expected output keeps the two multiplies separate.
define i32 @unrelated_ops_nocombine1(i32 %in, i32 %in2, i32 %in3) {
; CHECK-LABEL: @unrelated_ops_nocombine1(
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
; CHECK-NEXT: [[TEMP:%.*]] = mul nuw nsw i32 [[TMP1]], 72
; CHECK-NEXT: [[TMP4:%.*]] = or disjoint i32 [[TEMP]], [[IN3:%.*]]
; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[IN]], 12
; CHECK-NEXT: [[TEMP2:%.*]] = mul nuw nsw i32 [[TMP2]], 36
; CHECK-NEXT: [[TMP3:%.*]] = or disjoint i32 [[IN2:%.*]], [[TEMP2]]
; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[TMP4]], [[TMP3]]
; CHECK-NEXT: ret i32 [[OUT]]
;
%and0 = and i32 %in, 3
%temp = mul nuw nsw i32 %and0, 72
%temp3 = or disjoint i32 %temp, %in3
%and1 = and i32 %in, 12
%temp2 = mul nuw nsw i32 %and1, 36
%temp4 = or disjoint i32 %in2, %temp2
%out = or disjoint i32 %temp3, %temp4
ret i32 %out
}

; Negative test: only one masked multiply is present in the or chain
; ((%temp | %in3) | %in2), so there is nothing to pair it with. The expected
; output matches the input.
define i32 @no_chain(i32 %in, i32 %in2, i32 %in3) {
; CHECK-LABEL: @no_chain(
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
; CHECK-NEXT: [[TEMP:%.*]] = mul nuw nsw i32 [[TMP1]], 72
; CHECK-NEXT: [[TEMP3:%.*]] = or disjoint i32 [[TEMP]], [[IN3:%.*]]
; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[TEMP3]], [[IN2:%.*]]
; CHECK-NEXT: ret i32 [[OUT]]
;
%and0 = and i32 %in, 3
%temp = mul nuw nsw i32 %and0, 72
%temp3 = or disjoint i32 %temp, %in3
%out = or disjoint i32 %temp3, %in2
ret i32 %out
}

;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
; CONSTSPLAT: {{.*}}
; CONSTVEC: {{.*}}