Skip to content

Commit 72cd125

Browse files
committed
Added logic to compute max number of valid and sign bits and set to zero
1 parent 0ff6997 commit 72cd125

File tree

1 file changed

+10
-23
lines changed

1 file changed

+10
-23
lines changed

llvm/lib/Analysis/ValueTracking.cpp

Lines changed: 10 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -427,29 +427,16 @@ static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
427427
// Check if both operands are the same sign-extension of a single value.
428428
const Value *A = nullptr;
429429
if (match(Op0, m_SExt(m_Value(A))) && match(Op1, m_SExt(m_Specific(A)))) {
430-
// Product of (sext x) * (sext x) is always non-negative.
431-
// Compute the maximum possible square and fold all out-of-range bits.
432-
Type *FromTy = A->getType();
433-
Type *ToTy = Op0->getType();
434-
if (FromTy->isIntegerTy() && ToTy->isIntegerTy() &&
435-
FromTy->getScalarSizeInBits() < ToTy->getScalarSizeInBits()) {
436-
unsigned FromBits = FromTy->getScalarSizeInBits();
437-
unsigned ToBits = ToTy->getScalarSizeInBits();
438-
// For signed, the maximum absolute value is max(|min|, |max|)
439-
APInt minVal = APInt::getSignedMinValue(FromBits);
440-
APInt maxVal = APInt::getSignedMaxValue(FromBits);
441-
APInt absMin = minVal.abs();
442-
APInt absMax = maxVal.abs();
443-
APInt maxAbs = absMin.ugt(absMax) ? absMin : absMax;
444-
APInt maxSquare = maxAbs.zext(ToBits);
445-
maxSquare = maxSquare * maxSquare;
446-
// All bits above the highest set bit in maxSquare are known zero.
447-
unsigned MaxBit = maxSquare.isZero() ? 0 : maxSquare.logBase2();
448-
if (MaxBit + 1 < ToBits) {
449-
APInt KnownZeroMask =
450-
APInt::getHighBitsSet(ToBits, ToBits - (MaxBit + 1));
451-
Known.Zero |= KnownZeroMask;
452-
}
430+
unsigned SignBits = ComputeNumSignBits(Op0, DemandedElts, Q, Depth + 1);
431+
unsigned TyBits = Op0->getType()->getScalarSizeInBits();
432+
// The output of the Mul can be at most twice the valid bits
433+
unsigned OutValidBits = 2 * (TyBits - SignBits + 1);
434+
unsigned OutSignBits =
435+
OutValidBits > TyBits ? 1 : TyBits - OutValidBits + 1;
436+
437+
if (OutSignBits > 1) {
438+
APInt KnownZeroMask = APInt::getHighBitsSet(TyBits, OutSignBits);
439+
Known.Zero |= KnownZeroMask;
453440
}
454441
}
455442
}

0 commit comments

Comments
 (0)