@@ -427,29 +427,16 @@ static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
427427 // Check if both operands are the same sign-extension of a single value.
428428 const Value *A = nullptr ;
429429 if (match (Op0, m_SExt (m_Value (A))) && match (Op1, m_SExt (m_Specific (A)))) {
430- // Product of (sext x) * (sext x) is always non-negative.
431- // Compute the maximum possible square and fold all out-of-range bits.
432- Type *FromTy = A->getType ();
433- Type *ToTy = Op0->getType ();
434- if (FromTy->isIntegerTy () && ToTy->isIntegerTy () &&
435- FromTy->getScalarSizeInBits () < ToTy->getScalarSizeInBits ()) {
436- unsigned FromBits = FromTy->getScalarSizeInBits ();
437- unsigned ToBits = ToTy->getScalarSizeInBits ();
438- // For signed, the maximum absolute value is max(|min|, |max|)
439- APInt minVal = APInt::getSignedMinValue (FromBits);
440- APInt maxVal = APInt::getSignedMaxValue (FromBits);
441- APInt absMin = minVal.abs ();
442- APInt absMax = maxVal.abs ();
443- APInt maxAbs = absMin.ugt (absMax) ? absMin : absMax;
444- APInt maxSquare = maxAbs.zext (ToBits);
445- maxSquare = maxSquare * maxSquare;
446- // All bits above the highest set bit in maxSquare are known zero.
447- unsigned MaxBit = maxSquare.isZero () ? 0 : maxSquare.logBase2 ();
448- if (MaxBit + 1 < ToBits) {
449- APInt KnownZeroMask =
450- APInt::getHighBitsSet (ToBits, ToBits - (MaxBit + 1 ));
451- Known.Zero |= KnownZeroMask;
452- }
430+ unsigned SignBits = ComputeNumSignBits (Op0, DemandedElts, Q, Depth + 1 );
431+ unsigned TyBits = Op0->getType ()->getScalarSizeInBits ();
432+ // The output of the Mul can be at most twice the valid bits
433+ unsigned OutValidBits = 2 * (TyBits - SignBits + 1 );
434+ unsigned OutSignBits =
435+ OutValidBits > TyBits ? 1 : TyBits - OutValidBits + 1 ;
436+
437+ if (OutSignBits > 1 ) {
438+ APInt KnownZeroMask = APInt::getHighBitsSet (TyBits, OutSignBits);
439+ Known.Zero |= KnownZeroMask;
453440 }
454441 }
455442}
0 commit comments