diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp index 89668af378070..68fef9dac5bb1 100644 --- a/llvm/lib/Support/KnownBits.cpp +++ b/llvm/lib/Support/KnownBits.cpp @@ -796,19 +796,78 @@ KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS, assert((!NoUndefSelfMultiply || LHS == RHS) && "Self multiplication knownbits mismatch"); - // Compute the high known-0 bits by multiplying the unsigned max of each side. - // Conservatively, M active bits * N active bits results in M + N bits in the - // result. But if we know a value is a power-of-2 for example, then this - // computes one more leading zero. - // TODO: This could be generalized to number of sign bits (negative numbers). - APInt UMaxLHS = LHS.getMaxValue(); - APInt UMaxRHS = RHS.getMaxValue(); - - // For leading zeros in the result to be valid, the unsigned max product must - // fit in the bitwidth (it must not overflow). - bool HasOverflow; - APInt UMaxResult = UMaxLHS.umul_ov(UMaxRHS, HasOverflow); - unsigned LeadZ = HasOverflow ? 0 : UMaxResult.countl_zero(); + // Compute the high known-0 or known-1 bits by multiplying the min and max of + // each side. + APInt MaxLHS = LHS.isNegative() ? LHS.getMinValue().abs() : LHS.getMaxValue(), + MaxRHS = RHS.isNegative() ? RHS.getMinValue().abs() : RHS.getMaxValue(), + MinLHS = LHS.isNegative() ? LHS.getMaxValue().abs() : LHS.getMinValue(), + MinRHS = RHS.isNegative() ? RHS.getMaxValue().abs() : RHS.getMinValue(); + + APInt MaxProduct = MaxLHS * MaxRHS, MinProduct = MinLHS * MinRHS; + + if (LHS.isNegative() != RHS.isNegative()) { + // The unsigned-multiplication wrapped MinProduct and MaxProduct can be + // negated to turn them into the corresponding signed-multiplication + // wrapped values. + MinProduct.negate(); + MaxProduct.negate(); + + // MinProduct < MaxProduct is now MaxProduct < MinProduct. 
+ std::swap(MinProduct, MaxProduct); + } + + // Unless both MinProduct and MaxProduct are the same sign, there won't be any + // leading zeros or ones in the result. Unless MaxProduct.ugt(MinProduct), it + // is not safe to set any leading zeros or ones. + unsigned LeadZ = 0, LeadO = 0; + if (MinProduct.isNegative() == MaxProduct.isNegative() && + MaxProduct.ugt(MinProduct)) { + APInt LHSUnknown = (~LHS.Zero & ~LHS.One), + RHSUnknown = (~RHS.Zero & ~RHS.One); + + // A product of M active bits * N active bits results in M + N bits in the + // result. If either of the operands is a power of two, the result has one + // less active bit. + auto ProdActiveBits = [](const APInt &A, const APInt &B) { + if (A.isZero() || B.isZero()) + return 0u; + return A.getActiveBits() + B.getActiveBits() - + (A.isPowerOf2() || B.isPowerOf2()); + }; + + // We want to compute the number of active bits in the difference between + // the non-wrapped max product and non-wrapped min product, but we want to + // avoid computing the non-wrapped max/min product. + unsigned ActiveBitsInDiff = BitWidth + 1; + if (LHSUnknown.isZero()) { + ActiveBitsInDiff = + ProdActiveBits(MinLHS.isZero() ? LHSUnknown : MinLHS, RHSUnknown); + } else if (RHSUnknown.isZero()) { + ActiveBitsInDiff = + ProdActiveBits(MinRHS.isZero() ? RHSUnknown : MinRHS, LHSUnknown); + } else if (ProdActiveBits(MinLHS, RHSUnknown) <= BitWidth && + ProdActiveBits(MinRHS, LHSUnknown) <= BitWidth && + ProdActiveBits(LHSUnknown, RHSUnknown) <= BitWidth) { + // Slow path, which is seldom taken in practice. + // (MinLHS + LHSUnknown) * (MinRHS + RHSUnknown) - (MinLHS * MinRHS) + // = MinLHS * RHSUnknown + MinRHS * LHSUnknown + LHSUnknown * RHSUnknown. 
+ APInt Res = MinLHS.umul_sat(RHSUnknown) + .uadd_sat(MinRHS.umul_sat(LHSUnknown)) + .uadd_sat(LHSUnknown.umul_sat(RHSUnknown)); + if (!Res.isMaxValue()) + ActiveBitsInDiff = Res.getActiveBits(); + } + + // We uniformly handle the case where there is no max-overflow, in which + // case the high zeros and ones are computed optimally, and where there is, + // but the result shifts at most by BitWidth, in which case the high zeros + // and ones are not computed optimally. + if (ActiveBitsInDiff <= BitWidth) { + // Set the minimum leading zeros or ones from MaxProduct and MinProduct. + LeadZ = MaxProduct.countLeadingZeros(); + LeadO = MinProduct.countLeadingOnes(); + } + } // The result of the bottom bits of an integer multiply can be // inferred by looking at the bottom bits of both operands and @@ -873,8 +932,9 @@ KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS, KnownBits Res(BitWidth); Res.Zero.setHighBits(LeadZ); + Res.One.setHighBits(LeadO); Res.Zero |= (~BottomKnown).getLoBits(ResultBitsKnown); - Res.One = BottomKnown.getLoBits(ResultBitsKnown); + Res.One |= BottomKnown.getLoBits(ResultBitsKnown); // If we're self-multiplying then bit[1] is guaranteed to be zero. 
if (NoUndefSelfMultiply && BitWidth > 1) { diff --git a/llvm/test/Analysis/ValueTracking/knownbits-mul.ll b/llvm/test/Analysis/ValueTracking/knownbits-mul.ll new file mode 100644 index 0000000000000..37526c67f0d9e --- /dev/null +++ b/llvm/test/Analysis/ValueTracking/knownbits-mul.ll @@ -0,0 +1,143 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 +; RUN: opt < %s -passes=instcombine -S | FileCheck %s + +define i8 @mul_low_bits_know(i8 %xx, i8 %yy) { +; CHECK-LABEL: define i8 @mul_low_bits_know( +; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) { +; CHECK-NEXT: ret i8 0 +; + %x = and i8 %xx, 2 + %y = and i8 %yy, 4 + %mul = mul i8 %x, %y + %r = and i8 %mul, 6 + ret i8 %r +} + +define i8 @mul_low_bits_know2(i8 %xx, i8 %yy) { +; CHECK-LABEL: define i8 @mul_low_bits_know2( +; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) { +; CHECK-NEXT: ret i8 0 +; + %x = or i8 %xx, -2 + %y = and i8 %yy, 4 + %mul = mul i8 %x, %y + %r = and i8 %mul, 2 + ret i8 %r +} + +define i8 @mul_low_bits_partially_known(i8 %xx, i8 %yy) { +; CHECK-LABEL: define i8 @mul_low_bits_partially_known( +; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) { +; CHECK-NEXT: [[Y:%.*]] = or i8 [[YY]], 2 +; CHECK-NEXT: [[MUL:%.*]] = sub nsw i8 0, [[Y]] +; CHECK-NEXT: [[R:%.*]] = and i8 [[MUL]], 2 +; CHECK-NEXT: ret i8 [[R]] +; + %x = or i8 %xx, -4 + %x.notsmin = or i8 %x, 3 + %y = or i8 %yy, -2 + %mul = mul i8 %x.notsmin, %y + %r = and i8 %mul, 6 + ret i8 %r +} + +define i8 @mul_low_bits_unknown(i8 %xx, i8 %yy) { +; CHECK-LABEL: define i8 @mul_low_bits_unknown( +; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) { +; CHECK-NEXT: [[X:%.*]] = or i8 [[XX]], 4 +; CHECK-NEXT: [[Y:%.*]] = or i8 [[YY]], 6 +; CHECK-NEXT: [[MUL:%.*]] = mul i8 [[X]], [[Y]] +; CHECK-NEXT: [[R:%.*]] = and i8 [[MUL]], 6 +; CHECK-NEXT: ret i8 [[R]] +; + %x = or i8 %xx, -4 + %y = or i8 %yy, -2 + %mul = mul i8 %x, %y + %r = and i8 %mul, 6 + ret i8 %r +} + +define i8 @mul_high_bits_know(i8 %xx, i8 %yy) { +; CHECK-LABEL: 
define i8 @mul_high_bits_know( +; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) { +; CHECK-NEXT: ret i8 0 +; + %x = and i8 %xx, 2 + %y = and i8 %yy, 4 + %mul = mul i8 %x, %y + %r = and i8 %mul, 16 + ret i8 %r +} + +define i8 @mul_high_bits_know2(i8 %xx, i8 %yy) { +; CHECK-LABEL: define i8 @mul_high_bits_know2( +; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) { +; CHECK-NEXT: ret i8 -16 +; + %x = or i8 %xx, -2 + %y = and i8 %yy, 4 + %y.nonzero = or i8 %y, 1 + %mul = mul i8 %x, %y.nonzero + %r = and i8 %mul, -16 + ret i8 %r +} + +define i8 @mul_high_bits_know3(i8 %xx, i8 %yy) { +; CHECK-LABEL: define i8 @mul_high_bits_know3( +; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) { +; CHECK-NEXT: ret i8 0 +; + %x = or i8 %xx, -4 + %y = or i8 %yy, -2 + %mul = mul i8 %x, %y + %r = and i8 %mul, -16 + ret i8 %r +} + +define i8 @mul_high_bits_unknown(i8 %xx, i8 %yy) { +; CHECK-LABEL: define i8 @mul_high_bits_unknown( +; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) { +; CHECK-NEXT: [[X:%.*]] = and i8 [[XX]], 2 +; CHECK-NEXT: [[Y:%.*]] = and i8 [[YY]], 4 +; CHECK-NEXT: [[MUL:%.*]] = mul nuw nsw i8 [[X]], [[Y]] +; CHECK-NEXT: ret i8 [[MUL]] +; + %x = and i8 %xx, 2 + %y = and i8 %yy, 4 + %mul = mul i8 %x, %y + %r = and i8 %mul, 8 + ret i8 %r +} + +define i8 @mul_high_bits_unknown2(i8 %xx, i8 %yy) { +; CHECK-LABEL: define i8 @mul_high_bits_unknown2( +; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) { +; CHECK-NEXT: [[X:%.*]] = or i8 [[XX]], -2 +; CHECK-NEXT: [[Y:%.*]] = and i8 [[YY]], 4 +; CHECK-NEXT: [[MUL:%.*]] = mul nsw i8 [[X]], [[Y]] +; CHECK-NEXT: [[R:%.*]] = and i8 [[MUL]], -16 +; CHECK-NEXT: ret i8 [[R]] +; + %x = or i8 %xx, -2 + %y = and i8 %yy, 4 + %mul = mul i8 %x, %y + %r = and i8 %mul, -16 + ret i8 %r +} + +; TODO: This can be reduced to zero. 
+define i8 @mul_high_bits_unknown3(i8 %xx, i8 %yy) { +; CHECK-LABEL: define i8 @mul_high_bits_unknown3( +; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) { +; CHECK-NEXT: [[X:%.*]] = or i8 [[XX]], 28 +; CHECK-NEXT: [[Y:%.*]] = or i8 [[YY]], 30 +; CHECK-NEXT: [[MUL:%.*]] = mul i8 [[X]], [[Y]] +; CHECK-NEXT: [[R:%.*]] = and i8 [[MUL]], 16 +; CHECK-NEXT: ret i8 [[R]] +; + %x = or i8 %xx, -4 + %y = or i8 %yy, -2 + %mul = mul i8 %x, %y + %r = and i8 %mul, 16 + ret i8 %r +} diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp index b16368de17648..2be2e1d093315 100644 --- a/llvm/unittests/Support/KnownBitsTest.cpp +++ b/llvm/unittests/Support/KnownBitsTest.cpp @@ -815,7 +815,7 @@ TEST(KnownBitsTest, ConcatBits) { } } -TEST(KnownBitsTest, MulExhaustive) { +TEST(KnownBitsTest, MulLowBitsExhaustive) { for (unsigned Bits : {1, 4}) { ForeachKnownBits(Bits, [&](const KnownBits &Known1) { ForeachKnownBits(Bits, [&](const KnownBits &Known2) { @@ -849,4 +849,151 @@ TEST(KnownBitsTest, MulExhaustive) { } } +TEST(KnownBitsTest, MulHighBitsNoOverflow) { + for (unsigned Bits : {1, 4}) { + ForeachKnownBits(Bits, [&](const KnownBits &Known1) { + ForeachKnownBits(Bits, [&](const KnownBits &Known2) { + KnownBits Computed = KnownBits::mul(Known1, Known2); + KnownBits Exact(Bits), WideExact(2 * Bits); + Exact.Zero.setAllBits(); + Exact.One.setAllBits(); + + bool HasOverflow; + ForeachNumInKnownBits(Known1, [&](const APInt &N1) { + ForeachNumInKnownBits(Known2, [&](const APInt &N2) { + // The final value of HasOverflow corresponds to the multiplication + // in the last iteration, which is the max product. + APInt Res = N1.umul_ov(N2, HasOverflow); + Exact.One &= Res; + Exact.Zero &= ~Res; + }); + }); + + if (!Exact.hasConflict() && !HasOverflow) { + // Check that leading zeros and leading ones are optimal in the + // result, provided there is no overflow. 
+ APInt ZerosMask = + APInt::getHighBitsSet(Bits, Exact.Zero.countLeadingOnes()), + OnesMask = + APInt::getHighBitsSet(Bits, Exact.One.countLeadingOnes()); + + KnownBits ExactZeros(Bits), ComputedZeros(Bits); + KnownBits ExactOnes(Bits), ComputedOnes(Bits); + ExactZeros.Zero.setAllBits(); + ExactZeros.One.setAllBits(); + ExactOnes.Zero.setAllBits(); + ExactOnes.One.setAllBits(); + + ExactZeros.Zero = Exact.Zero & ZerosMask; + ExactZeros.One = Exact.One & ZerosMask; + ComputedZeros.Zero = Computed.Zero & ZerosMask; + ComputedZeros.One = Computed.One & ZerosMask; + EXPECT_TRUE(checkResult("mul", ExactZeros, ComputedZeros, + {Known1, Known2}, + /*CheckOptimality=*/true)); + + ExactOnes.Zero = Exact.Zero & OnesMask; + ExactOnes.One = Exact.One & OnesMask; + ComputedOnes.Zero = Computed.Zero & OnesMask; + ComputedOnes.One = Computed.One & OnesMask; + EXPECT_TRUE(checkResult("mul", ExactOnes, ComputedOnes, + {Known1, Known2}, + /*CheckOptimality=*/true)); + } + }); + }); + } +} + +TEST(KnownBitsTest, MulHighBitsOverflow) { + unsigned Bits = 4; + using KnownUnknownPair = std::pair<int, int>; + SmallVector<std::pair<KnownUnknownPair, KnownUnknownPair>> TestPairs = { + {{2, 0}, {7, -1}}, // 001?, 0111 + {{2, -1}, {10, 0}}, // 0010, 101? 
+ {{9, 2}, {9, 1}}, // 1?01, 10?1 + {{5, 1}, {3, 2}}}; // 01?1, 0?11 + for (auto [P1, P2] : TestPairs) { + KnownBits Known1(Bits), Known2(Bits); + auto [K1, U1] = P1; + auto [K2, U2] = P2; + Known1 = KnownBits::makeConstant(APInt(Bits, K1)); + Known2 = KnownBits::makeConstant(APInt(Bits, K2)); + if (U1 > -1) { + Known1.Zero.setBitVal(U1, 0); + Known1.One.setBitVal(U1, 0); + } + if (U2 > -1) { + Known2.Zero.setBitVal(U2, 0); + Known2.One.setBitVal(U2, 0); + } + KnownBits Computed = KnownBits::mul(Known1, Known2); + KnownBits Exact(Bits); + Exact.Zero.setAllBits(); + Exact.One.setAllBits(); + + ForeachNumInKnownBits(Known1, [&](const APInt &N1) { + ForeachNumInKnownBits(Known2, [&](const APInt &N2) { + APInt Res = N1 * N2; + Exact.One &= Res; + Exact.Zero &= ~Res; + }); + }); + + // Check that the leading zeros or ones are optimal for the given examples, + // which overflow. It is certainly sub-optimal on other examples. + APInt ZerosMask = + APInt::getHighBitsSet(Bits, Exact.Zero.countLeadingOnes()), + OnesMask = APInt::getHighBitsSet(Bits, Exact.One.countLeadingOnes()); + + KnownBits ExactZeros(Bits), ComputedZeros(Bits); + KnownBits ExactOnes(Bits), ComputedOnes(Bits); + ExactZeros.Zero.setAllBits(); + ExactZeros.One.setAllBits(); + ExactOnes.Zero.setAllBits(); + ExactOnes.One.setAllBits(); + + ExactZeros.Zero = Exact.Zero & ZerosMask; + ExactZeros.One = Exact.One & ZerosMask; + ComputedZeros.Zero = Computed.Zero & ZerosMask; + ComputedZeros.One = Computed.One & ZerosMask; + EXPECT_TRUE(checkResult("mul", ExactZeros, ComputedZeros, {Known1, Known2}, + /*CheckOptimality=*/true)); + + ExactOnes.Zero = Exact.Zero & OnesMask; + ExactOnes.One = Exact.One & OnesMask; + ComputedOnes.Zero = Computed.Zero & OnesMask; + ComputedOnes.One = Computed.One & OnesMask; + EXPECT_TRUE(checkResult("mul", ExactOnes, ComputedOnes, {Known1, Known2}, + /*CheckOptimality=*/true)); + } +} + +TEST(KnownBitsTest, MulStress) { + // Stress test KnownBits::mul on 5 and 6 bits, checking that 
the result is + // correct, even if not optimal. + for (unsigned Bits : {5, 6}) { + ForeachKnownBits(Bits, [&](const KnownBits &Known1) { + ForeachKnownBits(Bits, [&](const KnownBits &Known2) { + KnownBits Computed = KnownBits::mul(Known1, Known2); + KnownBits Exact(Bits); + Exact.Zero.setAllBits(); + Exact.One.setAllBits(); + + ForeachNumInKnownBits(Known1, [&](const APInt &N1) { + ForeachNumInKnownBits(Known2, [&](const APInt &N2) { + APInt Res = N1 * N2; + Exact.One &= Res; + Exact.Zero &= ~Res; + }); + }); + + if (!Exact.hasConflict()) { + EXPECT_TRUE(checkResult("mul", Exact, Computed, {Known1, Known2}, + /*CheckOptimality=*/false)); + } + }); + }); + } +} } // end anonymous namespace