@@ -3602,6 +3602,11 @@ struct DecomposedBitMaskMul {
36023602 APInt Mask;
36033603 bool NUW;
36043604 bool NSW;
3605+
3606+ bool isCombineableWith (const DecomposedBitMaskMul Other) {
3607+ return X == Other.X && !Mask.intersects (Other.Mask ) &&
3608+ Factor == Other.Factor ;
3609+ }
36053610};
36063611
36073612static std::optional<DecomposedBitMaskMul> matchBitmaskMul (Value *V) {
@@ -3659,6 +3664,59 @@ static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) {
36593664 return std::nullopt ;
36603665}
36613666
3667+ // / (A & N) * C + (A & M) * C -> (A & (N + M)) & C
3668+ // / This also accepts the equivalent select form of (A & N) * C
3669+ // / expressions i.e. !(A & N) ? 0 : N * C)
3670+ static Value *foldBitmaskMul (Value *Op0, Value *Op1,
3671+ InstCombiner::BuilderTy &Builder) {
3672+ auto Decomp1 = matchBitmaskMul (Op1);
3673+ if (!Decomp1)
3674+ return nullptr ;
3675+
3676+ auto Decomp0 = matchBitmaskMul (Op0);
3677+ if (!Decomp0)
3678+ return nullptr ;
3679+
3680+ if (Decomp0->isCombineableWith (*Decomp1)) {
3681+ Value *NewAnd = Builder.CreateAnd (
3682+ Decomp0->X ,
3683+ ConstantInt::get (Decomp0->X ->getType (), Decomp0->Mask + Decomp1->Mask ));
3684+
3685+ return Builder.CreateMul (
3686+ NewAnd, ConstantInt::get (NewAnd->getType (), Decomp1->Factor ), " " ,
3687+ Decomp0->NUW && Decomp1->NUW , Decomp0->NSW && Decomp1->NSW );
3688+ }
3689+
3690+ return nullptr ;
3691+ }
3692+
3693+ Value *InstCombinerImpl::foldDisjointOr (Value *LHS, Value *RHS) {
3694+ if (Value *Res = foldBitmaskMul (LHS, RHS, Builder))
3695+ return Res;
3696+
3697+ return nullptr ;
3698+ }
3699+
3700+ Value *InstCombinerImpl::reassociateDisjointOr (Value *LHS, Value *RHS) {
3701+
3702+ Value *X, *Y;
3703+ if (match (RHS, m_OneUse (m_DisjointOr (m_Value (X), m_Value (Y))))) {
3704+ if (Value *Res = foldDisjointOr (LHS, X))
3705+ return Builder.CreateOr (Res, Y, " " , /* IsDisjoint=*/ true );
3706+ if (Value *Res = foldDisjointOr (LHS, Y))
3707+ return Builder.CreateOr (Res, X, " " , /* IsDisjoint=*/ true );
3708+ }
3709+
3710+ if (match (LHS, m_OneUse (m_DisjointOr (m_Value (X), m_Value (Y))))) {
3711+ if (Value *Res = foldDisjointOr (X, RHS))
3712+ return Builder.CreateOr (Res, Y, " " , /* IsDisjoint=*/ true );
3713+ if (Value *Res = foldDisjointOr (Y, RHS))
3714+ return Builder.CreateOr (Res, X, " " , /* IsDisjoint=*/ true );
3715+ }
3716+
3717+ return nullptr ;
3718+ }
3719+
36623720// FIXME: We use commutative matchers (m_c_*) for some, but not all, matches
36633721// here. We should standardize that construct where it is needed or choose some
36643722// other way to ensure that commutated variants of patterns are not missed.
@@ -3741,28 +3799,11 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
37413799 /* NSW=*/ true , /* NUW=*/ true ))
37423800 return R;
37433801
3744- // (A & N) * C + (A & M) * C -> (A & (N + M)) & C
3745- // This also accepts the equivalent select form of (A & N) * C
3746- // expressions i.e. !(A & N) ? 0 : N * C)
3747- auto Decomp1 = matchBitmaskMul (I.getOperand (1 ));
3748- if (Decomp1) {
3749- auto Decomp0 = matchBitmaskMul (I.getOperand (0 ));
3750- if (Decomp0 && Decomp0->X == Decomp1->X &&
3751- (Decomp0->Mask & Decomp1->Mask ).isZero () &&
3752- Decomp0->Factor == Decomp1->Factor ) {
3753-
3754- Value *NewAnd = Builder.CreateAnd (
3755- Decomp0->X , ConstantInt::get (Decomp0->X ->getType (),
3756- (Decomp0->Mask + Decomp1->Mask )));
3757-
3758- auto *Combined = BinaryOperator::CreateMul (
3759- NewAnd, ConstantInt::get (NewAnd->getType (), Decomp1->Factor ));
3760-
3761- Combined->setHasNoUnsignedWrap (Decomp0->NUW && Decomp1->NUW );
3762- Combined->setHasNoSignedWrap (Decomp0->NSW && Decomp1->NSW );
3763- return Combined;
3764- }
3765- }
3802+ if (Value *Res = foldBitmaskMul (I.getOperand (0 ), I.getOperand (1 ), Builder))
3803+ return replaceInstUsesWith (I, Res);
3804+
3805+ if (Value *Res = reassociateDisjointOr (I.getOperand (0 ), I.getOperand (1 )))
3806+ return replaceInstUsesWith (I, Res);
37663807 }
37673808
37683809 Value *X, *Y;
0 commit comments