diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp index 89668af378070..bed1a45568c1c 100644 --- a/llvm/lib/Support/KnownBits.cpp +++ b/llvm/lib/Support/KnownBits.cpp @@ -796,19 +796,25 @@ KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS, assert((!NoUndefSelfMultiply || LHS == RHS) && "Self multiplication knownbits mismatch"); - // Compute the high known-0 bits by multiplying the unsigned max of each side. - // Conservatively, M active bits * N active bits results in M + N bits in the - // result. But if we know a value is a power-of-2 for example, then this - // computes one more leading zero. - // TODO: This could be generalized to number of sign bits (negative numbers). - APInt UMaxLHS = LHS.getMaxValue(); - APInt UMaxRHS = RHS.getMaxValue(); - - // For leading zeros in the result to be valid, the unsigned max product must + // Compute the high known-0 or known-1 bits by multiplying the max of each + // side. Conservatively, M active bits * N active bits results in M + N bits + // in the result. But if we know a value is a power-of-2 for example, then + // this computes one more leading zero or one. + APInt MaxLHS = LHS.isNegative() ? LHS.getMinValue().abs() : LHS.getMaxValue(), + MaxRHS = RHS.isNegative() ? RHS.getMinValue().abs() : RHS.getMaxValue(); + + // For leading zeros or ones in the result to be valid, the max product must // fit in the bitwidth (it must not overflow). bool HasOverflow; - APInt UMaxResult = UMaxLHS.umul_ov(UMaxRHS, HasOverflow); - unsigned LeadZ = HasOverflow ? 0 : UMaxResult.countl_zero(); + APInt Result = MaxLHS.umul_ov(MaxRHS, HasOverflow); + unsigned LeadZ = 0, LeadO = 0; + if (!HasOverflow) { + if (LHS.isNegative() == RHS.isNegative()) + LeadZ = Result.countLeadingZeros(); + // Do not set leading ones unless the result is known to be non-zero. + else if (LHS.isNonZero() && RHS.isNonZero()) + LeadO = (-Result).countLeadingOnes(); + } // The result of the bottom bits of an integer multiply can be // inferred by looking at the bottom bits of both operands and @@ -873,8 +879,9 @@ KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS, KnownBits Res(BitWidth); Res.Zero.setHighBits(LeadZ); + Res.One.setHighBits(LeadO); Res.Zero |= (~BottomKnown).getLoBits(ResultBitsKnown); - Res.One = BottomKnown.getLoBits(ResultBitsKnown); + Res.One |= BottomKnown.getLoBits(ResultBitsKnown); // If we're self-multiplying then bit[1] is guaranteed to be zero. if (NoUndefSelfMultiply && BitWidth > 1) { diff --git a/llvm/test/Analysis/ValueTracking/knownbits-mul.ll b/llvm/test/Analysis/ValueTracking/knownbits-mul.ll new file mode 100644 index 0000000000000..37526c67f0d9e --- /dev/null +++ b/llvm/test/Analysis/ValueTracking/knownbits-mul.ll @@ -0,0 +1,143 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 +; RUN: opt < %s -passes=instcombine -S | FileCheck %s + +define i8 @mul_low_bits_know(i8 %xx, i8 %yy) { +; CHECK-LABEL: define i8 @mul_low_bits_know( +; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) { +; CHECK-NEXT: ret i8 0 +; + %x = and i8 %xx, 2 + %y = and i8 %yy, 4 + %mul = mul i8 %x, %y + %r = and i8 %mul, 6 + ret i8 %r +} + +define i8 @mul_low_bits_know2(i8 %xx, i8 %yy) { +; CHECK-LABEL: define i8 @mul_low_bits_know2( +; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) { +; CHECK-NEXT: ret i8 0 +; + %x = or i8 %xx, -2 + %y = and i8 %yy, 4 + %mul = mul i8 %x, %y + %r = and i8 %mul, 2 + ret i8 %r +} + +define i8 @mul_low_bits_partially_known(i8 %xx, i8 %yy) { +; CHECK-LABEL: define i8 @mul_low_bits_partially_known( +; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) { +; CHECK-NEXT: [[Y:%.*]] = or i8 [[YY]], 2 +; CHECK-NEXT: [[MUL:%.*]] = sub nsw i8 0, [[Y]] +; CHECK-NEXT: [[R:%.*]] = and i8 [[MUL]], 2 +; CHECK-NEXT: ret i8 [[R]] +; + %x = or i8 %xx, -4 + %x.notsmin = or i8 %x, 3 + %y = or i8 %yy, -2 + %mul = mul i8 %x.notsmin, %y + %r = and i8 %mul, 6 + ret i8 %r +} + +define i8 @mul_low_bits_unknown(i8 %xx, i8 %yy) { +; CHECK-LABEL: define i8 @mul_low_bits_unknown( +; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) { +; CHECK-NEXT: [[X:%.*]] = or i8 [[XX]], 4 +; CHECK-NEXT: [[Y:%.*]] = or i8 [[YY]], 6 +; CHECK-NEXT: [[MUL:%.*]] = mul i8 [[X]], [[Y]] +; CHECK-NEXT: [[R:%.*]] = and i8 [[MUL]], 6 +; CHECK-NEXT: ret i8 [[R]] +; + %x = or i8 %xx, -4 + %y = or i8 %yy, -2 + %mul = mul i8 %x, %y + %r = and i8 %mul, 6 + ret i8 %r +} + +define i8 @mul_high_bits_know(i8 %xx, i8 %yy) { +; CHECK-LABEL: define i8 @mul_high_bits_know( +; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) { +; CHECK-NEXT: ret i8 0 +; + %x = and i8 %xx, 2 + %y = and i8 %yy, 4 + %mul = mul i8 %x, %y + %r = and i8 %mul, 16 + ret i8 %r +} + +define i8 @mul_high_bits_know2(i8 %xx, i8 %yy) { +; CHECK-LABEL: define i8 @mul_high_bits_know2( +; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) { +; CHECK-NEXT: ret i8 -16 +; + %x = or i8 %xx, -2 + %y = and i8 %yy, 4 + %y.nonzero = or i8 %y, 1 + %mul = mul i8 %x, %y.nonzero + %r = and i8 %mul, -16 + ret i8 %r +} + +define i8 @mul_high_bits_know3(i8 %xx, i8 %yy) { +; CHECK-LABEL: define i8 @mul_high_bits_know3( +; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) { +; CHECK-NEXT: ret i8 0 +; + %x = or i8 %xx, -4 + %y = or i8 %yy, -2 + %mul = mul i8 %x, %y + %r = and i8 %mul, -16 + ret i8 %r +} + +define i8 @mul_high_bits_unknown(i8 %xx, i8 %yy) { +; CHECK-LABEL: define i8 @mul_high_bits_unknown( +; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) { +; CHECK-NEXT: [[X:%.*]] = and i8 [[XX]], 2 +; CHECK-NEXT: [[Y:%.*]] = and i8 [[YY]], 4 +; CHECK-NEXT: [[MUL:%.*]] = mul nuw nsw i8 [[X]], [[Y]] +; CHECK-NEXT: ret i8 [[MUL]] +; + %x = and i8 %xx, 2 + %y = and i8 %yy, 4 + %mul = mul i8 %x, %y + %r = and i8 %mul, 8 + ret i8 %r +} + +define i8 @mul_high_bits_unknown2(i8 %xx, i8 %yy) { +; CHECK-LABEL: define i8 @mul_high_bits_unknown2( +; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) { +; CHECK-NEXT: [[X:%.*]] = or i8 [[XX]], -2 +; CHECK-NEXT: [[Y:%.*]] = and i8 [[YY]], 4 +; CHECK-NEXT: [[MUL:%.*]] = mul nsw i8 [[X]], [[Y]] +; CHECK-NEXT: [[R:%.*]] = and i8 [[MUL]], -16 +; CHECK-NEXT: ret i8 [[R]] +; + %x = or i8 %xx, -2 + %y = and i8 %yy, 4 + %mul = mul i8 %x, %y + %r = and i8 %mul, -16 + ret i8 %r +} + +; TODO: This can be reduced to zero. +define i8 @mul_high_bits_unknown3(i8 %xx, i8 %yy) { +; CHECK-LABEL: define i8 @mul_high_bits_unknown3( +; CHECK-SAME: i8 [[XX:%.*]], i8 [[YY:%.*]]) { +; CHECK-NEXT: [[X:%.*]] = or i8 [[XX]], 28 +; CHECK-NEXT: [[Y:%.*]] = or i8 [[YY]], 30 +; CHECK-NEXT: [[MUL:%.*]] = mul i8 [[X]], [[Y]] +; CHECK-NEXT: [[R:%.*]] = and i8 [[MUL]], 16 +; CHECK-NEXT: ret i8 [[R]] +; + %x = or i8 %xx, -4 + %y = or i8 %yy, -2 + %mul = mul i8 %x, %y + %r = and i8 %mul, 16 + ret i8 %r +} diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp index b16368de17648..e374b46492622 100644 --- a/llvm/unittests/Support/KnownBitsTest.cpp +++ b/llvm/unittests/Support/KnownBitsTest.cpp @@ -815,7 +815,7 @@ TEST(KnownBitsTest, ConcatBits) { } } -TEST(KnownBitsTest, MulExhaustive) { +TEST(KnownBitsTest, MulLowBitsExhaustive) { for (unsigned Bits : {1, 4}) { ForeachKnownBits(Bits, [&](const KnownBits &Known1) { ForeachKnownBits(Bits, [&](const KnownBits &Known2) { @@ -849,4 +849,54 @@ TEST(KnownBitsTest, MulExhaustive) { } } +TEST(KnownBitsTest, MulHighBits) { + unsigned Bits = 8; + SmallVector, 4> TestPairs = { + {2, 4}, {-2, -4}, {2, -4}, {-2, 4}}; + for (auto [K1, K2] : TestPairs) { + KnownBits Known1(Bits), Known2(Bits); + if (K1 > 0) { + // If we only set the zeros of ~K1, Known1 could be zero. Avoid this case, + // as we can only set leading ones in the case where LHS and RHS have + // different signs, when the result is known non-zero. + Known1.Zero |= ~(K1 | 1); + Known1.One |= 1; + } else { + Known1.One |= K1; + } + if (K2 > 0) { + // If we only set the zeros of ~K1, Known1 could be zero. Avoid this case, + // as we can only set leading ones in the case where LHS and RHS have + // different signs, when the result is known non-zero. + Known2.Zero |= ~(K2 | 1); + Known2.One |= 1; + } else { + Known2.One |= K2; + } + KnownBits Computed = KnownBits::mul(Known1, Known2); + KnownBits Exact(Bits); + Exact.Zero.setAllBits(); + Exact.One.setAllBits(); + + ForeachNumInKnownBits(Known1, [&](const APInt &N1) { + ForeachNumInKnownBits(Known2, [&](const APInt &N2) { + APInt Res = N1 * N2; + Exact.One &= Res; + Exact.Zero &= ~Res; + }); + }); + + // Check that the high bits are optimal, with the caveat that mul_ov of LHS + // and RHS doesn't overflow, which is the case for our TestPairs. + APInt Mask = APInt::getHighBitsSet( + Bits, (Exact.Zero | Exact.One).countLeadingOnes()); + Exact.Zero &= Mask; + Exact.One &= Mask; + Computed.Zero &= Mask; + Computed.One &= Mask; + EXPECT_TRUE(checkResult("mul", Exact, Computed, {Known1, Known2}, + /*CheckOptimality=*/true)); + } +} + } // end anonymous namespace