Skip to content

Commit 41911d5

Browse files
committed
Reassociate instead of combine
Change-Id: Ib86e8ed347ef60948c3e4cb44c5fab1c3667afc6
1 parent 09fb3f3 commit 41911d5

File tree

2 files changed

+98
-37
lines changed

2 files changed

+98
-37
lines changed

llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp

Lines changed: 40 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -3664,32 +3664,35 @@ static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) {
36643664
return std::nullopt;
36653665
}
36663666

3667-
using CombinedBitmaskMul =
3668-
std::pair<std::optional<DecomposedBitMaskMul>, Value *>;
3667+
struct CombinedBitmaskMul {
3668+
std::optional<DecomposedBitMaskMul> Decomp = std::nullopt;
3669+
Value *DecompOp = nullptr;
3670+
Value *OtherOp = nullptr;
3671+
};
36693672

36703673
static CombinedBitmaskMul matchCombinedBitmaskMul(Value *V) {
36713674
auto DecompBitMaskMul = matchBitmaskMul(V);
36723675
if (DecompBitMaskMul)
3673-
return {DecompBitMaskMul, nullptr};
3676+
return {DecompBitMaskMul, V, nullptr};
36743677

36753678
// Otherwise, check the operands of V for bitmaskmul pattern
36763679
auto BOp = dyn_cast<BinaryOperator>(V);
36773680
if (!BOp)
3678-
return {std::nullopt, nullptr};
3681+
return CombinedBitmaskMul();
36793682

36803683
auto Disj = dyn_cast<PossiblyDisjointInst>(BOp);
36813684
if (!Disj || !Disj->isDisjoint())
3682-
return {std::nullopt, nullptr};
3685+
return CombinedBitmaskMul();
36833686

36843687
auto DecompBitMaskMul0 = matchBitmaskMul(BOp->getOperand(0));
36853688
if (DecompBitMaskMul0)
3686-
return {DecompBitMaskMul0, BOp->getOperand(1)};
3689+
return {DecompBitMaskMul0, BOp->getOperand(0), BOp->getOperand(1)};
36873690

36883691
auto DecompBitMaskMul1 = matchBitmaskMul(BOp->getOperand(1));
36893692
if (DecompBitMaskMul1)
3690-
return {DecompBitMaskMul1, BOp->getOperand(0)};
3693+
return {DecompBitMaskMul1, BOp->getOperand(1), BOp->getOperand(0)};
36913694

3692-
return {std::nullopt, nullptr};
3695+
return CombinedBitmaskMul();
36933696
}
36943697

36953698
// FIXME: We use commutative matchers (m_c_*) for some, but not all, matches
@@ -3778,43 +3781,44 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
37783781
// This also accepts the equivalent select form of (A & N) * C
37793782
// expressions i.e. !(A & N) ? 0 : N * C)
37803783
CombinedBitmaskMul Decomp1 = matchCombinedBitmaskMul(I.getOperand(1));
3781-
auto BMDecomp1 = Decomp1.first;
3784+
auto BMDecomp1 = Decomp1.Decomp;
37823785

37833786
if (BMDecomp1) {
37843787
CombinedBitmaskMul Decomp0 = matchCombinedBitmaskMul(I.getOperand(0));
3785-
auto BMDecomp0 = Decomp0.first;
3786-
3787-
if (BMDecomp0 && BMDecomp0->isCombineableWith(*BMDecomp1)) {
3788-
auto NewAnd = Builder.CreateAnd(
3789-
BMDecomp0->X,
3790-
ConstantInt::get(BMDecomp0->X->getType(),
3791-
(BMDecomp0->Mask + BMDecomp1->Mask)));
3792-
3793-
BinaryOperator *Combined = cast<BinaryOperator>(Builder.CreateMul(
3794-
NewAnd, ConstantInt::get(NewAnd->getType(), BMDecomp1->Factor)));
3788+
auto BMDecomp0 = Decomp0.Decomp;
37953789

3796-
Combined->setHasNoUnsignedWrap(BMDecomp0->NUW && BMDecomp1->NUW);
3797-
Combined->setHasNoSignedWrap(BMDecomp0->NSW && BMDecomp1->NSW);
3790+
if (BMDecomp0) {
3791+
// If we have independent operands in the BitmaskMul chain, then just
3792+
// reassociate to encourage combining in future iterations.
3793+
if (Decomp0.OtherOp || Decomp1.OtherOp) {
3794+
Value *OtherOp = Decomp0.OtherOp ? Decomp0.OtherOp : Decomp1.OtherOp;
37983795

3799-
// If our tree has indepdent or-disjoint operands, bring them in.
3800-
auto OtherOp0 = Decomp0.second;
3801-
auto OtherOp1 = Decomp1.second;
3802-
3803-
if (OtherOp0 || OtherOp1) {
3804-
Value *OtherOp;
3805-
if (OtherOp0 && OtherOp1) {
3806-
OtherOp = Builder.CreateOr(OtherOp0, OtherOp1);
3796+
if (Decomp0.OtherOp && Decomp1.OtherOp) {
3797+
OtherOp = Builder.CreateOr(Decomp0.OtherOp, Decomp1.OtherOp);
38073798
cast<PossiblyDisjointInst>(OtherOp)->setIsDisjoint(true);
3808-
} else {
3809-
OtherOp = OtherOp0 ? OtherOp0 : OtherOp1;
38103799
}
3811-
Combined = cast<BinaryOperator>(Builder.CreateOr(Combined, OtherOp));
3812-
cast<PossiblyDisjointInst>(Combined)->setIsDisjoint(true);
3800+
3801+
auto CombinedOp =
3802+
Builder.CreateOr(Decomp0.DecompOp, Decomp1.DecompOp);
3803+
cast<PossiblyDisjointInst>(CombinedOp)->setIsDisjoint(true);
3804+
3805+
return BinaryOperator::CreateDisjointOr(CombinedOp, OtherOp);
38133806
}
38143807

3815-
// Caller expects detached instruction
3816-
Combined->removeFromParent();
3817-
return Combined;
3808+
if (BMDecomp0->isCombineableWith(*BMDecomp1)) {
3809+
auto NewAnd = Builder.CreateAnd(
3810+
BMDecomp0->X,
3811+
ConstantInt::get(BMDecomp0->X->getType(),
3812+
(BMDecomp0->Mask + BMDecomp1->Mask)));
3813+
3814+
auto *Combined = BinaryOperator::CreateMul(
3815+
NewAnd, ConstantInt::get(NewAnd->getType(), BMDecomp1->Factor));
3816+
3817+
Combined->setHasNoUnsignedWrap(BMDecomp0->NUW && BMDecomp1->NUW);
3818+
Combined->setHasNoSignedWrap(BMDecomp0->NSW && BMDecomp1->NSW);
3819+
3820+
return Combined;
3821+
}
38183822
}
38193823
}
38203824
}

llvm/test/Transforms/InstCombine/or-bitmask.ll

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -485,9 +485,9 @@ define i32 @unrelated_ops1(i32 %in, i32 %in2) {
485485

486486
define i32 @unrelated_ops2(i32 %in, i32 %in2, i32 %in3) {
487487
; CHECK-LABEL: @unrelated_ops2(
488+
; CHECK-NEXT: [[TMP3:%.*]] = or disjoint i32 [[IN3:%.*]], [[IN2:%.*]]
488489
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 15
489490
; CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i32 [[TMP1]], 72
490-
; CHECK-NEXT: [[TMP3:%.*]] = or disjoint i32 [[IN3:%.*]], [[IN2:%.*]]
491491
; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[TMP2]], [[TMP3]]
492492
; CHECK-NEXT: ret i32 [[OUT]]
493493
;
@@ -501,6 +501,63 @@ define i32 @unrelated_ops2(i32 %in, i32 %in2, i32 %in3) {
501501
ret i32 %out
502502
}
503503

504+
define i32 @unrelated_ops_nocombine(i32 %in, i32 %in2, i32 %in3) {
505+
; CHECK-LABEL: @unrelated_ops_nocombine(
506+
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
507+
; CHECK-NEXT: [[TEMP:%.*]] = mul nuw nsw i32 [[TMP1]], 72
508+
; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[IN]], 7
509+
; CHECK-NEXT: [[TEMP2:%.*]] = mul nuw nsw i32 [[TMP2]], 72
510+
; CHECK-NEXT: [[TMP3:%.*]] = or disjoint i32 [[IN3:%.*]], [[IN2:%.*]]
511+
; CHECK-NEXT: [[TMP4:%.*]] = or disjoint i32 [[TEMP]], [[TEMP2]]
512+
; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[TMP4]], [[TMP3]]
513+
; CHECK-NEXT: ret i32 [[OUT]]
514+
;
515+
%1 = and i32 %in, 3
516+
%temp = mul nuw nsw i32 %1, 72
517+
%temp3 = or disjoint i32 %temp, %in3
518+
%2 = and i32 %in, 7
519+
%temp2 = mul nuw nsw i32 %2, 72
520+
%temp4 = or disjoint i32 %in2, %temp2
521+
%out = or disjoint i32 %temp3, %temp4
522+
ret i32 %out
523+
}
524+
525+
define i32 @unrelated_ops_nocombine1(i32 %in, i32 %in2, i32 %in3) {
526+
; CHECK-LABEL: @unrelated_ops_nocombine1(
527+
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
528+
; CHECK-NEXT: [[TEMP:%.*]] = mul nuw nsw i32 [[TMP1]], 72
529+
; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[IN]], 12
530+
; CHECK-NEXT: [[TEMP2:%.*]] = mul nuw nsw i32 [[TMP2]], 36
531+
; CHECK-NEXT: [[TMP3:%.*]] = or disjoint i32 [[IN3:%.*]], [[IN2:%.*]]
532+
; CHECK-NEXT: [[TMP4:%.*]] = or disjoint i32 [[TEMP]], [[TEMP2]]
533+
; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[TMP4]], [[TMP3]]
534+
; CHECK-NEXT: ret i32 [[OUT]]
535+
;
536+
%1 = and i32 %in, 3
537+
%temp = mul nuw nsw i32 %1, 72
538+
%temp3 = or disjoint i32 %temp, %in3
539+
%2 = and i32 %in, 12
540+
%temp2 = mul nuw nsw i32 %2, 36
541+
%temp4 = or disjoint i32 %in2, %temp2
542+
%out = or disjoint i32 %temp3, %temp4
543+
ret i32 %out
544+
}
545+
546+
define i32 @no_chain(i32 %in, i32 %in2, i32 %in3) {
547+
; CHECK-LABEL: @no_chain(
548+
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
549+
; CHECK-NEXT: [[TEMP:%.*]] = mul nuw nsw i32 [[TMP1]], 72
550+
; CHECK-NEXT: [[TEMP3:%.*]] = or disjoint i32 [[TEMP]], [[IN3:%.*]]
551+
; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[TEMP3]], [[IN2:%.*]]
552+
; CHECK-NEXT: ret i32 [[OUT]]
553+
;
554+
%1 = and i32 %in, 3
555+
%temp = mul nuw nsw i32 %1, 72
556+
%temp3 = or disjoint i32 %temp, %in3
557+
%out = or disjoint i32 %temp3, %in2
558+
ret i32 %out
559+
}
560+
504561
;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
505562
; CONSTSPLAT: {{.*}}
506563
; CONSTVEC: {{.*}}

0 commit comments

Comments
 (0)