Skip to content
Merged
42 changes: 42 additions & 0 deletions llvm/lib/Analysis/ValueTracking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,48 @@ static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
Known.makeNonNegative();
else if (isKnownNegative && !Known.isNonNegative())
Known.makeNegative();

// Additional logic: If both operands are the same sign- or zero-extended
// value from a small integer, and the multiplication is (sext x) * (sext x)
// or (zext x) * (zext x), then the result cannot set bits above the maximum
// possible square. This allows InstCombine and other passes to fold (x * x) &
// (1 << N) to 0 when N is out of range.
const Value *A = nullptr;
// Only handle the case where both operands are the same extension of the same
// value.
if ((match(Op0, m_SExt(m_Value(A))) && match(Op1, m_SExt(m_Specific(A)))) ||
(match(Op0, m_ZExt(m_Value(A))) && match(Op1, m_ZExt(m_Specific(A))))) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think the zext handling here is useful. This will be handled by the generic code.

For the sext case, we know that the result is non-negative (due to self-multiply) and that we have a certain number of sign bits (due to multiply of sext), so together we know that the sign bits are actually zero bits.

I think the principled thing to do here would be, for self-multiplies, to call ComputeNumSignBits() and then set all those bits to zero.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh right, the zext is redundant. I’ve updated the code so that for self-multiplies using sext, we now call ComputeNumSignBits() to determine the number of sign bits and mark them as known zero.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, after reviewing the previous commit, how should we call ComputeNumSignBits() and set the corresponding bits to zero? In this function, we only track known bits and don’t explicitly compute the product, so it’s unclear how to determine the exact number of sign bits.

I’ve made another commit that reverts to the previous approach using max/min value boundaries and removed the zext handling for now.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use the same logic as ComputeNumSignBits:

unsigned OutValidBits =
(TyBits - SignBitsOp0 + 1) + (TyBits - SignBitsOp1 + 1);
return OutValidBits > TyBits ? 1 : TyBits - OutValidBits + 1;

Adjusted for the case where the sign bits are the same for both operands:

      unsigned OutValidBits = 2 * (TyBits - SignBits + 1);
      unsigned OutSignBits = OutValidBits > TyBits ? 1 : TyBits - OutValidBits + 1;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok thanks, i have added this into last commit. One question: currently my code uses match while other parts of this function use Op0 == Op1. Should we only handle the explicit self-multiply case (x * x), or also consider cases where both operands are sign-extensions of the same value?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not necessary to handle sign extensions of the same value, as CSE will convert this into one sign extension used in both operands. So we should use Op0 == Op1.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh ok, i moved it into the selfmultiply handling instead which uses Op0 == Op1

Type *FromTy = A->getType();
Type *ToTy = Op0->getType();
if (FromTy->isIntegerTy() && ToTy->isIntegerTy() &&
FromTy->getScalarSizeInBits() < ToTy->getScalarSizeInBits()) {
unsigned FromBits = FromTy->getScalarSizeInBits();
unsigned ToBits = ToTy->getScalarSizeInBits();
// For both signed and unsigned, the maximum absolute value is max(|min|,
// |max|)
APInt minVal(FromBits, 0), maxVal(FromBits, 0);
bool isSigned = isa<SExtInst>(Op0);
if (isSigned) {
minVal = APInt::getSignedMinValue(FromBits);
maxVal = APInt::getSignedMaxValue(FromBits);
} else {
minVal = APInt::getMinValue(FromBits);
maxVal = APInt::getMaxValue(FromBits);
}
APInt absMin = minVal.abs();
APInt absMax = maxVal.abs();
APInt maxAbs = absMin.ugt(absMax) ? absMin : absMax;
APInt maxSquare = maxAbs.zext(ToBits);
maxSquare = maxSquare * maxSquare;
// All bits above the highest set bit in maxSquare are known zero.
unsigned MaxBit = maxSquare.isZero() ? 0 : maxSquare.logBase2();
if (MaxBit + 1 < ToBits) {
APInt KnownZeroMask =
APInt::getHighBitsSet(ToBits, ToBits - (MaxBit + 1));
Known.Zero |= KnownZeroMask;
}
}
}
}

void llvm::computeKnownBitsFromRangeMetadata(const MDNode &Ranges,
Expand Down
33 changes: 33 additions & 0 deletions llvm/test/Analysis/ValueTracking/known-bits.ll
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,36 @@ define i1 @vec_reverse_known_bits_demanded_fail(<4 x i8> %xx) {
%r = icmp slt i8 %ele, 0
ret i1 %r
}

; Test known bits for (sext i8 x) * (sext i8 x)
; RUN: opt -passes=instcombine < %s -S | FileCheck %s --check-prefix=SEXT_SQUARE
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do not add an extra run line to this test. If this does not fold through -passes=instsimplify, then this should be tested inside llvm/test/Transforms/InstCombine.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't fold with instsimplify so I moved it into llvm/test/Transforms/InstCombine


define i1 @sext_square_bit31(i8 %x) {
; SEXT_SQUARE-LABEL: @sext_square_bit31(
; SEXT_SQUARE-NEXT: ret i1 false
%sx = sext i8 %x to i32
%mul = mul nsw i32 %sx, %sx
%and = and i32 %mul, 2147483648 ; 1 << 31
%cmp = icmp ne i32 %and, 0
ret i1 %cmp
}

define i1 @sext_square_bit30(i8 %x) {
; SEXT_SQUARE-LABEL: @sext_square_bit30(
; SEXT_SQUARE-NEXT: ret i1 false
%sx = sext i8 %x to i32
%mul = mul nsw i32 %sx, %sx
%and = and i32 %mul, 1073741824 ; 1 << 30
%cmp = icmp ne i32 %and, 0
ret i1 %cmp
}

define i1 @sext_square_bit14(i8 %x) {
; SEXT_SQUARE-LABEL: @sext_square_bit14(
; SEXT_SQUARE-NOT: ret i1 false
%sx = sext i8 %x to i32
%mul = mul nsw i32 %sx, %sx
%and = and i32 %mul, 16384 ; 1 << 14
%cmp = icmp ne i32 %and, 0
ret i1 %cmp
}