@@ -803,12 +803,7 @@ KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
803803 MinLHS = LHS.isNegative () ? LHS.getMaxValue ().abs () : LHS.getMinValue (),
804804 MinRHS = RHS.isNegative () ? RHS.getMaxValue ().abs () : RHS.getMinValue ();
805805
806- // If MaxProduct doesn't overflow, it implies that MinProduct also won't
807- // overflow. However, if MaxProduct overflows, there is no guarantee on the
808- // MinProduct overflowing.
809- bool HasOverflow;
810- APInt MaxProduct = MaxLHS.umul_ov (MaxRHS, HasOverflow),
811- MinProduct = MinLHS * MinRHS;
806+ APInt MaxProduct = MaxLHS * MaxRHS, MinProduct = MinLHS * MinRHS;
812807
813808 if (LHS.isNegative () != RHS.isNegative ()) {
814809 // The unsigned-multiplication wrapped MinProduct and MaxProduct can be
@@ -822,45 +817,52 @@ KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
822817 }
823818
824819 // Unless both MinProduct and MaxProduct are the same sign, there won't be any
825- // leading zeros or ones in the result.
820+ // leading zeros or ones in the result. Unless MaxProduct.ugt(MinProduct), it
821+ // is not safe to set any leading zeros or ones.
826822 unsigned LeadZ = 0 , LeadO = 0 ;
827- if (MinProduct.isNegative () == MaxProduct.isNegative ()) {
823+ if (MinProduct.isNegative () == MaxProduct.isNegative () &&
824+ MaxProduct.ugt (MinProduct)) {
828825 APInt LHSUnknown = (~LHS.Zero & ~LHS.One ),
829826 RHSUnknown = (~RHS.Zero & ~RHS.One );
830827
831828 // A product of M active bits * N active bits results in M + N bits in the
832829 // result. If either of the operands is a power of two, the result has one
833830 // less active bit.
834- auto ProdActiveBits = [](const APInt &A, const APInt &B) -> unsigned {
831+ auto ProdActiveBits = [](const APInt &A, const APInt &B) {
835832 if (A.isZero () || B.isZero ())
836- return 0 ;
833+ return 0u ;
837834 return A.getActiveBits () + B.getActiveBits () -
838835 (A.isPowerOf2 () || B.isPowerOf2 ());
839836 };
840837
841838 // We want to compute the number of active bits in the difference between
842839 // the non-wrapped max product and non-wrapped min product, but we want to
843840 // avoid camputing the non-wrapped max/min product.
844- unsigned ActiveBitsInDiff;
845- if (MinLHS.isZero () && MinRHS.isZero ())
846- ActiveBitsInDiff = ProdActiveBits (LHSUnknown, RHSUnknown);
847- else
841+ unsigned ActiveBitsInDiff = BitWidth + 1 ;
842+ if (LHSUnknown.isZero ()) {
843+ ActiveBitsInDiff =
844+ ProdActiveBits (MinLHS.isZero () ? LHSUnknown : MinLHS, RHSUnknown);
845+ } else if (RHSUnknown.isZero ()) {
848846 ActiveBitsInDiff =
849- ProdActiveBits (MinLHS.isZero () ? LHSUnknown : MinLHS, RHSUnknown) +
850847 ProdActiveBits (MinRHS.isZero () ? RHSUnknown : MinRHS, LHSUnknown);
851-
852- // Checks that A.ugt(B), excluding the degenerate case where A is all-ones
853- // and B is zero.
854- auto UgtCheckCorner = [](const APInt &A, const APInt &B) {
855- return (!A.isAllOnes () || !B.isZero ()) && A.ugt (B);
856- };
848+ } else if (ProdActiveBits (MinLHS, RHSUnknown) <= BitWidth &&
849+ ProdActiveBits (MinRHS, LHSUnknown) <= BitWidth &&
850+ ProdActiveBits (LHSUnknown, RHSUnknown) <= BitWidth) {
851+ // Slow path, which is seldom taken in practice.
852+ // (MinLHS + LHSUnknown) * (MinRHS + RHSUnknown) - (MinLHS * MinRHS)
853+ // = MinLHS * RHSUnknown + MinRHS * LHSUnknown + LHSUnknown * RHSUnknown.
854+ APInt Res = MinLHS.umul_sat (RHSUnknown)
855+ .uadd_sat (MinRHS.umul_sat (LHSUnknown))
856+ .uadd_sat (LHSUnknown.umul_sat (RHSUnknown));
857+ if (!Res.isMaxValue ())
858+ ActiveBitsInDiff = Res.getActiveBits ();
859+ }
857860
858861 // We uniformly handle the case where there is no max-overflow, in which
859862 // case the high zeros and ones are computed optimally, and where there is,
860863 // but the result shifts at most by BitWidth, in which case the high zeros
861864 // and ones are not computed optimally.
862- if ((!HasOverflow || ActiveBitsInDiff <= BitWidth) &&
863- UgtCheckCorner (MaxProduct, MinProduct)) {
865+ if (ActiveBitsInDiff <= BitWidth) {
864866 // Set the minimum leading zeros or ones from MaxProduct and MinProduct.
865867 LeadZ = MaxProduct.countLeadingZeros ();
866868 LeadO = MinProduct.countLeadingOnes ();
0 commit comments