Skip to content

Commit 04f1601

Browse files
committed
KnownBits: generalize high-bits of mul to overflows
Make the non-overflow case of KnownBits::mul optimal, and smoothly generalize it to the case where overflow occurs by relying on the min-product in addition to the max-product. Note that the overflow case cannot possibly be optimal unless we also look at the bits in between the min-product and the max-product.
1 parent 603ec71 commit 04f1601

File tree

2 files changed

+192
-44
lines changed

2 files changed

+192
-44
lines changed

llvm/lib/Support/KnownBits.cpp

Lines changed: 65 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -796,24 +796,75 @@ KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
796796
assert((!NoUndefSelfMultiply || LHS == RHS) &&
797797
"Self multiplication knownbits mismatch");
798798

799-
// Compute the high known-0 or known-1 bits by multiplying the max of each
800-
// side. Conservatively, M active bits * N active bits results in M + N bits
801-
// in the result. But if we know a value is a power-of-2 for example, then
802-
// this computes one more leading zero or one.
799+
// Compute the high known-0 or known-1 bits by multiplying the min and max of
800+
// each side.
803801
APInt MaxLHS = LHS.isNegative() ? LHS.getMinValue().abs() : LHS.getMaxValue(),
804-
MaxRHS = RHS.isNegative() ? RHS.getMinValue().abs() : RHS.getMaxValue();
802+
MaxRHS = RHS.isNegative() ? RHS.getMinValue().abs() : RHS.getMaxValue(),
803+
MinLHS = LHS.isNegative() ? LHS.getMaxValue().abs() : LHS.getMinValue(),
804+
MinRHS = RHS.isNegative() ? RHS.getMaxValue().abs() : RHS.getMinValue();
805805

806-
// For leading zeros or ones in the result to be valid, the max product must
807-
// fit in the bitwidth (it must not overflow).
806+
// If MaxProduct doesn't overflow, it implies that MinProduct also won't
807+
// overflow. However, if MaxProduct overflows, there is no guarantee on the
808+
// MinProduct overflowing.
808809
bool HasOverflow;
809-
APInt Result = MaxLHS.umul_ov(MaxRHS, HasOverflow);
810+
APInt MaxProduct = MaxLHS.umul_ov(MaxRHS, HasOverflow),
811+
MinProduct = MinLHS * MinRHS;
812+
813+
if (LHS.isNegative() != RHS.isNegative()) {
814+
// The unsigned-multiplication wrapped MinProduct and MaxProduct can be
815+
// negated to turn them into the corresponding signed-multiplication
816+
// wrapped values.
817+
MinProduct.negate();
818+
MaxProduct.negate();
819+
820+
// MinProduct < MaxProduct is now MaxProduct < MinProduct.
821+
std::swap(MinProduct, MaxProduct);
822+
}
823+
824+
// Unless both MinProduct and MaxProduct are the same sign, there won't be any
825+
// leading zeros or ones in the result.
810826
unsigned LeadZ = 0, LeadO = 0;
811-
if (!HasOverflow) {
812-
if (LHS.isNegative() == RHS.isNegative())
813-
LeadZ = Result.countLeadingZeros();
814-
// Do not set leading ones unless the result is known to be non-zero.
815-
else if (LHS.isNonZero() && RHS.isNonZero())
816-
LeadO = (-Result).countLeadingOnes();
827+
if (MinProduct.isNegative() == MaxProduct.isNegative()) {
828+
APInt LHSUnknown = (~LHS.Zero & ~LHS.One),
829+
RHSUnknown = (~RHS.Zero & ~RHS.One);
830+
831+
// A product of M active bits * N active bits results in M + N bits in the
832+
// result. If either of the operands is a power of two, the result has one
833+
// fewer active bit.
834+
auto ProdActiveBits = [](const APInt &A, const APInt &B) -> unsigned {
835+
if (A.isZero() || B.isZero())
836+
return 0;
837+
return A.getActiveBits() + B.getActiveBits() -
838+
(A.isPowerOf2() || B.isPowerOf2());
839+
};
840+
841+
// We want to compute the number of active bits in the difference between
842+
// the non-wrapped max product and non-wrapped min product, but we want to
843+
// avoid computing the non-wrapped max/min product.
844+
unsigned ActiveBitsInDiff;
845+
if (MinLHS.isZero() && MinRHS.isZero())
846+
ActiveBitsInDiff = ProdActiveBits(LHSUnknown, RHSUnknown);
847+
else
848+
ActiveBitsInDiff =
849+
ProdActiveBits(MinLHS.isZero() ? LHSUnknown : MinLHS, RHSUnknown) +
850+
ProdActiveBits(MinRHS.isZero() ? RHSUnknown : MinRHS, LHSUnknown);
851+
852+
// Checks that A.ugt(B), excluding the degenerate case where A is all-ones
853+
// and B is zero.
854+
auto UgtCheckCorner = [](const APInt &A, const APInt &B) {
855+
return (!A.isAllOnes() || !B.isZero()) && A.ugt(B);
856+
};
857+
858+
// We uniformly handle the case where there is no max-overflow, in which
859+
// case the high zeros and ones are computed optimally, and where there is,
860+
// but the result shifts at most by BitWidth, in which case the high zeros
861+
// and ones are not computed optimally.
862+
if ((!HasOverflow || ActiveBitsInDiff <= BitWidth) &&
863+
UgtCheckCorner(MaxProduct, MinProduct)) {
864+
// Set the minimum leading zeros or ones from MaxProduct and MinProduct.
865+
LeadZ = MaxProduct.countLeadingZeros();
866+
LeadO = MinProduct.countLeadingOnes();
867+
}
817868
}
818869

819870
// The result of the bottom bits of an integer multiply can be

llvm/unittests/Support/KnownBitsTest.cpp

Lines changed: 127 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -849,29 +849,83 @@ TEST(KnownBitsTest, MulLowBitsExhaustive) {
849849
}
850850
}
851851

852-
TEST(KnownBitsTest, MulHighBits) {
853-
unsigned Bits = 8;
854-
SmallVector<std::pair<int, int>, 4> TestPairs = {
855-
{2, 4}, {-2, -4}, {2, -4}, {-2, 4}};
856-
for (auto [K1, K2] : TestPairs) {
852+
// Exhaustively enumerates every pair of KnownBits at the tested bit widths and
// checks that the leading known-zero and leading known-one bits produced by
// KnownBits::mul are optimal whenever no unsigned product of the two operand
// sets overflows.
TEST(KnownBitsTest, MulHighBitsNoOverflow) {
  for (unsigned Bits : {1, 4}) {
    ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
      ForeachKnownBits(Bits, [&](const KnownBits &Known2) {
        KnownBits Computed = KnownBits::mul(Known1, Known2);

        // Start with every bit claimed as both known-zero and known-one, then
        // intersect with each concrete product to obtain the exact known bits.
        KnownBits Exact(Bits);
        Exact.Zero.setAllBits();
        Exact.One.setAllBits();

        // Unsigned multiplication is monotonic in both operands, so some
        // product overflows if and only if the max product overflows.
        // Accumulating the flag avoids depending on the order in which
        // ForeachNumInKnownBits enumerates the values.
        bool AnyOverflow = false;
        ForeachNumInKnownBits(Known1, [&](const APInt &N1) {
          ForeachNumInKnownBits(Known2, [&](const APInt &N2) {
            bool Overflowed = false;
            APInt Res = N1.umul_ov(N2, Overflowed);
            AnyOverflow |= Overflowed;
            Exact.One &= Res;
            Exact.Zero &= ~Res;
          });
        });

        // Exact still conflicts only if no product was folded into it.
        if (Exact.hasConflict() || AnyOverflow)
          return;

        // Restricts both the exact and the computed result to the bits selected
        // by Mask and requires the computed result to be optimal on them.
        auto ExpectOptimalUnderMask = [&](const APInt &Mask) {
          KnownBits MaskedExact(Bits), MaskedComputed(Bits);
          MaskedExact.Zero = Exact.Zero & Mask;
          MaskedExact.One = Exact.One & Mask;
          MaskedComputed.Zero = Computed.Zero & Mask;
          MaskedComputed.One = Computed.One & Mask;
          EXPECT_TRUE(checkResult("mul", MaskedExact, MaskedComputed,
                                  {Known1, Known2},
                                  /*CheckOptimality=*/true));
        };

        // Leading zeros of the exact result.
        ExpectOptimalUnderMask(
            APInt::getHighBitsSet(Bits, Exact.Zero.countLeadingOnes()));
        // Leading ones of the exact result.
        ExpectOptimalUnderMask(
            APInt::getHighBitsSet(Bits, Exact.One.countLeadingOnes()));
      });
    });
  }
}
907+
908+
TEST(KnownBitsTest, MulHighBitsOverflow) {
909+
unsigned Bits = 4;
910+
using KnownUnknownPair = std::pair<int, int>;
911+
SmallVector<std::pair<KnownUnknownPair, KnownUnknownPair>> TestPairs = {
912+
{{2, 0}, {7, -1}}, // 001?, 0111
913+
{{2, -1}, {10, 0}}, // 0010, 101?
914+
{{9, 2}, {9, 1}}, // 1?01, 10?1
915+
{{5, 1}, {3, 2}}}; // 01?1, 0?11
916+
for (auto [P1, P2] : TestPairs) {
857917
KnownBits Known1(Bits), Known2(Bits);
858-
if (K1 > 0) {
859-
// If we only set the zeros of ~K1, Known1 could be zero. Avoid this case,
860-
// as we can only set leading ones in the case where LHS and RHS have
861-
// different signs, when the result is known non-zero.
862-
Known1.Zero |= ~(K1 | 1);
863-
Known1.One |= 1;
864-
} else {
865-
Known1.One |= K1;
918+
auto [K1, U1] = P1;
919+
auto [K2, U2] = P2;
920+
Known1 = KnownBits::makeConstant(APInt(Bits, K1));
921+
Known2 = KnownBits::makeConstant(APInt(Bits, K2));
922+
if (U1 > -1) {
923+
Known1.Zero.setBitVal(U1, 0);
924+
Known1.One.setBitVal(U1, 0);
866925
}
867-
if (K2 > 0) {
868-
// If we only set the zeros of ~K1, Known1 could be zero. Avoid this case,
869-
// as we can only set leading ones in the case where LHS and RHS have
870-
// different signs, when the result is known non-zero.
871-
Known2.Zero |= ~(K2 | 1);
872-
Known2.One |= 1;
873-
} else {
874-
Known2.One |= K2;
926+
if (U2 > -1) {
927+
Known2.Zero.setBitVal(U2, 0);
928+
Known2.One.setBitVal(U2, 0);
875929
}
876930
KnownBits Computed = KnownBits::mul(Known1, Known2);
877931
KnownBits Exact(Bits);
@@ -886,17 +940,60 @@ TEST(KnownBitsTest, MulHighBits) {
886940
});
887941
});
888942

889-
// Check that the high bits are optimal, with the caveat that mul_ov of LHS
890-
// and RHS doesn't overflow, which is the case for our TestPairs.
891-
APInt Mask = APInt::getHighBitsSet(
892-
Bits, (Exact.Zero | Exact.One).countLeadingOnes());
893-
Exact.Zero &= Mask;
894-
Exact.One &= Mask;
895-
Computed.Zero &= Mask;
896-
Computed.One &= Mask;
897-
EXPECT_TRUE(checkResult("mul", Exact, Computed, {Known1, Known2},
943+
// Check that the leading zeros or ones are optimal for the given examples,
944+
// which overflow. It is certainly sub-optimal on other examples.
945+
APInt ZerosMask =
946+
APInt::getHighBitsSet(Bits, Exact.Zero.countLeadingOnes()),
947+
OnesMask = APInt::getHighBitsSet(Bits, Exact.One.countLeadingOnes());
948+
949+
KnownBits ExactZeros(Bits), ComputedZeros(Bits);
950+
KnownBits ExactOnes(Bits), ComputedOnes(Bits);
951+
ExactZeros.Zero.setAllBits();
952+
ExactZeros.One.setAllBits();
953+
ExactOnes.Zero.setAllBits();
954+
ExactOnes.One.setAllBits();
955+
956+
ExactZeros.Zero = Exact.Zero & ZerosMask;
957+
ExactZeros.One = Exact.One & ZerosMask;
958+
ComputedZeros.Zero = Computed.Zero & ZerosMask;
959+
ComputedZeros.One = Computed.One & ZerosMask;
960+
EXPECT_TRUE(checkResult("mul", ExactZeros, ComputedZeros, {Known1, Known2},
961+
/*CheckOptimality=*/true));
962+
963+
ExactOnes.Zero = Exact.Zero & OnesMask;
964+
ExactOnes.One = Exact.One & OnesMask;
965+
ComputedOnes.Zero = Computed.Zero & OnesMask;
966+
ComputedOnes.One = Computed.One & OnesMask;
967+
EXPECT_TRUE(checkResult("mul", ExactOnes, ComputedOnes, {Known1, Known2},
898968
/*CheckOptimality=*/true));
899969
}
900970
}
901971

972+
// Soundness-only stress test for KnownBits::mul at 5 and 6 bits: for every
// pair of KnownBits the computed result must be consistent with the exact
// known bits obtained by brute force. Optimality is deliberately not checked.
TEST(KnownBitsTest, MulStress) {
  for (unsigned Width : {5, 6}) {
    ForeachKnownBits(Width, [&](const KnownBits &Lhs) {
      ForeachKnownBits(Width, [&](const KnownBits &Rhs) {
        KnownBits FromMul = KnownBits::mul(Lhs, Rhs);

        // Begin with all bits marked both known-zero and known-one; each
        // concrete (wrapping) product then clears whatever it contradicts.
        KnownBits BruteForce(Width);
        BruteForce.Zero.setAllBits();
        BruteForce.One.setAllBits();
        ForeachNumInKnownBits(Lhs, [&](const APInt &A) {
          ForeachNumInKnownBits(Rhs, [&](const APInt &B) {
            APInt Product = A * B;
            BruteForce.One &= Product;
            BruteForce.Zero &= ~Product;
          });
        });

        // A remaining conflict means no product was folded in; nothing to
        // compare against in that case.
        if (BruteForce.hasConflict())
          return;

        EXPECT_TRUE(checkResult("mul", BruteForce, FromMul, {Lhs, Rhs},
                                /*CheckOptimality=*/false));
      });
    });
  }
}
902999
} // end anonymous namespace

0 commit comments

Comments
 (0)