@@ -3593,10 +3593,16 @@ static Value *foldOrOfInversions(BinaryOperator &I,
35933593 return nullptr ;
35943594}
35953595
3596+ // A decomposition of ((A & N) ? 0 : N * C) . Where X = A, Factor = C, Mask = N.
3597+ // The NUW / NSW bools
3598+ // Note that we can decompose equivalent forms of this expression (e.g. ((A & N)
3599+ // * C))
35963600struct DecomposedBitMaskMul {
35973601 Value *X;
35983602 APInt Factor;
35993603 APInt Mask;
3604+ bool NUW;
3605+ bool NSW;
36003606};
36013607
36023608static std::optional<DecomposedBitMaskMul> matchBitmaskMul (Value *V) {
@@ -3606,28 +3612,29 @@ static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) {
36063612
36073613 Value *MulOp = nullptr ;
36083614 const APInt *MulConst = nullptr ;
3615+
3616+ // Decompose (A & N) * C) into BitMaskMul
36093617 if (match (Op, m_Mul (m_Value (MulOp), m_APInt (MulConst)))) {
36103618 Value *Original = nullptr ;
36113619 const APInt *Mask = nullptr ;
3612- if (! MulConst->isStrictlyPositive ())
3620+ if (MulConst->isZero ())
36133621 return std::nullopt ;
36143622
36153623 if (match (MulOp, m_And (m_Value (Original), m_APInt (Mask)))) {
3616- if (! Mask->isStrictlyPositive ())
3624+ if (Mask->isZero ())
36173625 return std::nullopt ;
3618- DecomposedBitMaskMul Ret;
3619- Ret.X = Original;
3620- Ret.Mask = *Mask;
3621- Ret.Factor = *MulConst;
3622- return Ret;
3626+ return std::optional<DecomposedBitMaskMul>(
3627+ {Original, *MulConst, *Mask,
3628+ cast<BinaryOperator>(Op)->hasNoUnsignedWrap (),
3629+ cast<BinaryOperator>(Op)->hasNoSignedWrap ()});
36233630 }
36243631 return std::nullopt ;
36253632 }
36263633
36273634 Value *Cond = nullptr ;
36283635 const APInt *EqZero = nullptr , *NeZero = nullptr ;
36293636
3630- // (! (A & N) ? 0 : N * C) + (!(A & M) ? 0 : M * C) -> A & (N + M) * C
3637+ // Decompose ( (A & N) ? 0 : N * C) into BitMaskMul
36313638 if (match (Op, m_Select (m_Value (Cond), m_APInt (EqZero), m_APInt (NeZero)))) {
36323639 auto ICmpDecompose =
36333640 decomposeBitTest (Cond, /* LookThruTrunc=*/ true ,
@@ -3638,22 +3645,20 @@ static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) {
36383645 if (ICmpDecompose->Pred == ICmpInst::ICMP_NE)
36393646 std::swap (EqZero, NeZero);
36403647
3641- if (!EqZero->isZero () || ! NeZero->isStrictlyPositive ())
3648+ if (!EqZero->isZero () || NeZero->isZero ())
36423649 return std::nullopt ;
36433650
36443651 if (!ICmpInst::isEquality (ICmpDecompose->Pred ) ||
36453652 !ICmpDecompose->C .isZero () || !ICmpDecompose->Mask .isPowerOf2 () ||
3646- ICmpDecompose->Mask .isNegative ())
3653+ ICmpDecompose->Mask .isZero ())
36473654 return std::nullopt ;
36483655
36493656 if (!NeZero->urem (ICmpDecompose->Mask ).isZero ())
36503657 return std::nullopt ;
36513658
3652- DecomposedBitMaskMul Ret;
3653- Ret.X = ICmpDecompose->X ;
3654- Ret.Mask = ICmpDecompose->Mask ;
3655- Ret.Factor = NeZero->udiv (ICmpDecompose->Mask );
3656- return Ret;
3659+ return std::optional<DecomposedBitMaskMul>(
3660+ {ICmpDecompose->X , NeZero->udiv (ICmpDecompose->Mask ),
3661+ ICmpDecompose->Mask , /* NUW=*/ false , /* NSW=*/ false });
36573662 }
36583663
36593664 return std::nullopt ;
@@ -3741,19 +3746,26 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
37413746 /* NSW=*/ true , /* NUW=*/ true ))
37423747 return R;
37433748
3744- auto Decomp0 = matchBitmaskMul (I.getOperand (0 ));
3749+ // (!(A & N) ? 0 : N * C) + (!(A & M) ? 0 : M * C) -> A & (N + M) * C
3750+ // This also accepts the equivalent mul form of (A & N) ? 0 : N * C)
3751+ // expressions i.e. (A & N) * C
37453752 auto Decomp1 = matchBitmaskMul (I.getOperand (1 ));
3746-
3747- if ( Decomp0 && Decomp1) {
3748- if (Decomp0->X == Decomp1->X &&
3753+ if (Decomp1) {
3754+ auto Decomp0 = matchBitmaskMul (I. getOperand ( 0 ));
3755+ if (Decomp0 && Decomp0 ->X == Decomp1->X &&
37493756 (Decomp0->Mask & Decomp1->Mask ).isZero () &&
37503757 Decomp0->Factor == Decomp1->Factor ) {
3758+
37513759 auto NewAnd = Builder.CreateAnd (
37523760 Decomp0->X , ConstantInt::get (Decomp0->X ->getType (),
37533761 (Decomp0->Mask + Decomp1->Mask )));
37543762
3755- return BinaryOperator::CreateMul (
3763+ auto Combined = BinaryOperator::CreateMul (
37563764 NewAnd, ConstantInt::get (NewAnd->getType (), Decomp1->Factor ));
3765+
3766+ Combined->setHasNoUnsignedWrap (Decomp0->NUW && Decomp1->NUW );
3767+ Combined->setHasNoSignedWrap (Decomp0->NSW && Decomp1->NSW );
3768+ return Combined;
37573769 }
37583770 }
37593771 }
0 commit comments