@@ -426,13 +426,31 @@ static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
426426
427427 // Check if both operands are the same sign-extension of a single value.
428428 const Value *A = nullptr ;
429-
430429 if (match (Op0, m_SExt (m_Value (A))) && match (Op1, m_SExt (m_Specific (A)))) {
431430 // Product of (sext x) * (sext x) is always non-negative.
432- // So we know the sign bit itself is zero.
433- unsigned SignBits = ComputeNumSignBits (Op0, Q, Depth);
434- if (SignBits > 1 )
435- Known.Zero .setHighBits (SignBits - 1 );
431+ // Compute the maximum possible square and fold all out-of-range bits.
432+ Type *FromTy = A->getType ();
433+ Type *ToTy = Op0->getType ();
434+ if (FromTy->isIntegerTy () && ToTy->isIntegerTy () &&
435+ FromTy->getScalarSizeInBits () < ToTy->getScalarSizeInBits ()) {
436+ unsigned FromBits = FromTy->getScalarSizeInBits ();
437+ unsigned ToBits = ToTy->getScalarSizeInBits ();
438+ // For signed, the maximum absolute value is max(|min|, |max|)
439+ APInt minVal = APInt::getSignedMinValue (FromBits);
440+ APInt maxVal = APInt::getSignedMaxValue (FromBits);
441+ APInt absMin = minVal.abs ();
442+ APInt absMax = maxVal.abs ();
443+ APInt maxAbs = absMin.ugt (absMax) ? absMin : absMax;
444+ APInt maxSquare = maxAbs.zext (ToBits);
445+ maxSquare = maxSquare * maxSquare;
446+ // All bits above the highest set bit in maxSquare are known zero.
447+ unsigned MaxBit = maxSquare.isZero () ? 0 : maxSquare.logBase2 ();
448+ if (MaxBit + 1 < ToBits) {
449+ APInt KnownZeroMask =
450+ APInt::getHighBitsSet (ToBits, ToBits - (MaxBit + 1 ));
451+ Known.Zero |= KnownZeroMask;
452+ }
453+ }
436454 }
437455}
438456
0 commit comments