@@ -3602,6 +3602,11 @@ struct DecomposedBitMaskMul {
36023602 APInt Mask;
36033603 bool NUW;
36043604 bool NSW;
3605+
3606+ bool isCombineableWith (DecomposedBitMaskMul Other) {
3607+ return X == Other.X && (Mask & Other.Mask ).isZero () &&
3608+ Factor == Other.Factor ;
3609+ }
36053610};
36063611
36073612static std::optional<DecomposedBitMaskMul> matchBitmaskMul (Value *V) {
@@ -3659,6 +3664,34 @@ static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) {
36593664 return std::nullopt ;
36603665}
36613666
3667+ using CombinedBitmaskMul =
3668+ std::pair<std::optional<DecomposedBitMaskMul>, Value *>;
3669+
3670+ static CombinedBitmaskMul matchCombinedBitmaskMul (Value *V) {
3671+ auto DecompBitMaskMul = matchBitmaskMul (V);
3672+ if (DecompBitMaskMul)
3673+ return {DecompBitMaskMul, nullptr };
3674+
3675+ // Otherwise, check the operands of V for bitmaskmul pattern
3676+ auto BOp = dyn_cast<BinaryOperator>(V);
3677+ if (!BOp)
3678+ return {std::nullopt , nullptr };
3679+
3680+ auto Disj = dyn_cast<PossiblyDisjointInst>(BOp);
3681+ if (!Disj || !Disj->isDisjoint ())
3682+ return {std::nullopt , nullptr };
3683+
3684+ auto DecompBitMaskMul0 = matchBitmaskMul (BOp->getOperand (0 ));
3685+ if (DecompBitMaskMul0)
3686+ return {DecompBitMaskMul0, BOp->getOperand (1 )};
3687+
3688+ auto DecompBitMaskMul1 = matchBitmaskMul (BOp->getOperand (1 ));
3689+ if (DecompBitMaskMul1)
3690+ return {DecompBitMaskMul1, BOp->getOperand (0 )};
3691+
3692+ return {std::nullopt , nullptr };
3693+ }
3694+
36623695// FIXME: We use commutative matchers (m_c_*) for some, but not all, matches
36633696// here. We should standardize that construct where it is needed or choose some
36643697// other way to ensure that commutated variants of patterns are not missed.
@@ -3741,25 +3774,46 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
37413774 /* NSW=*/ true , /* NUW=*/ true ))
37423775 return R;
37433776
3744- // (A & N) * C + (A & M) * C -> (A & (N + M)) & C
3745- // This also accepts the equivalent select form of (A & N) * C
3746- // expressions i.e. !(A & N) ? 0 : N * C)
3747- auto Decomp1 = matchBitmaskMul (I.getOperand (1 ));
3748- if (Decomp1) {
3749- auto Decomp0 = matchBitmaskMul (I.getOperand (0 ));
3750- if (Decomp0 && Decomp0->X == Decomp1->X &&
3751- (Decomp0->Mask & Decomp1->Mask ).isZero () &&
3752- Decomp0->Factor == Decomp1->Factor ) {
3753-
3754- Value *NewAnd = Builder.CreateAnd (
3755- Decomp0->X , ConstantInt::get (Decomp0->X ->getType (),
3756- (Decomp0->Mask + Decomp1->Mask )));
3757-
3758- auto *Combined = BinaryOperator::CreateMul (
3759- NewAnd, ConstantInt::get (NewAnd->getType (), Decomp1->Factor ));
3760-
3761- Combined->setHasNoUnsignedWrap (Decomp0->NUW && Decomp1->NUW );
3762- Combined->setHasNoSignedWrap (Decomp0->NSW && Decomp1->NSW );
3777+ // (!(A & N) ? 0 : N * C) + (!(A & M) ? 0 : M * C) -> A & (N + M) * C
3778+ // This also accepts the equivalent mul form of (A & N) ? 0 : N * C)
3779+ // expressions i.e. (A & N) * C
3780+ CombinedBitmaskMul Decomp1 = matchCombinedBitmaskMul (I.getOperand (1 ));
3781+ auto BMDecomp1 = Decomp1.first ;
3782+
3783+ if (BMDecomp1) {
3784+ CombinedBitmaskMul Decomp0 = matchCombinedBitmaskMul (I.getOperand (0 ));
3785+ auto BMDecomp0 = Decomp0.first ;
3786+
3787+ if (BMDecomp0 && BMDecomp0->isCombineableWith (*BMDecomp1)) {
3788+ auto NewAnd = Builder.CreateAnd (
3789+ BMDecomp0->X ,
3790+ ConstantInt::get (BMDecomp0->X ->getType (),
3791+ (BMDecomp0->Mask + BMDecomp1->Mask )));
3792+
3793+ BinaryOperator *Combined = cast<BinaryOperator>(Builder.CreateMul (
3794+ NewAnd, ConstantInt::get (NewAnd->getType (), BMDecomp1->Factor )));
3795+
3796+ Combined->setHasNoUnsignedWrap (BMDecomp0->NUW && BMDecomp1->NUW );
3797+ Combined->setHasNoSignedWrap (BMDecomp0->NSW && BMDecomp1->NSW );
3798+
3799+ // If our tree has indepdent or-disjoint operands, bring them in.
3800+ auto OtherOp0 = Decomp0.second ;
3801+ auto OtherOp1 = Decomp1.second ;
3802+
3803+ if (OtherOp0 || OtherOp1) {
3804+ Value *OtherOp;
3805+ if (OtherOp0 && OtherOp1) {
3806+ OtherOp = Builder.CreateOr (OtherOp0, OtherOp1);
3807+ cast<PossiblyDisjointInst>(OtherOp)->setIsDisjoint (true );
3808+ } else {
3809+ OtherOp = OtherOp0 ? OtherOp0 : OtherOp1;
3810+ }
3811+ Combined = cast<BinaryOperator>(Builder.CreateOr (Combined, OtherOp));
3812+ cast<PossiblyDisjointInst>(Combined)->setIsDisjoint (true );
3813+ }
3814+
3815+ // Caller expects detached instruction
3816+ Combined->removeFromParent ();
37633817 return Combined;
37643818 }
37653819 }
0 commit comments