@@ -423,6 +423,49 @@ static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
423423 Known.makeNonNegative ();
424424 else if (isKnownNegative && !Known.isNonNegative ())
425425 Known.makeNegative ();
426+
427+ // Additional logic: If both operands are the same sign- or zero-extended
428+ // value from a small integer, and the multiplication is (sext x) * (sext x)
429+ // or (zext x) * (zext x), then the result cannot set bits above the maximum
430+ // possible square. This allows InstCombine and other passes to fold (x * x) &
431+ // (1 << N) to 0 when N is out of range.
432+ using namespace PatternMatch ;
433+ const Value *A = nullptr ;
434+ // Only handle the case where both operands are the same extension of the same
435+ // value.
436+ if ((match (Op0, m_SExt (m_Value (A))) && match (Op1, m_SExt (m_Specific (A)))) ||
437+ (match (Op0, m_ZExt (m_Value (A))) && match (Op1, m_ZExt (m_Specific (A))))) {
438+ Type *FromTy = A->getType ();
439+ Type *ToTy = Op0->getType ();
440+ if (FromTy->isIntegerTy () && ToTy->isIntegerTy () &&
441+ FromTy->getScalarSizeInBits () < ToTy->getScalarSizeInBits ()) {
442+ unsigned FromBits = FromTy->getScalarSizeInBits ();
443+ unsigned ToBits = ToTy->getScalarSizeInBits ();
444+ // For both signed and unsigned, the maximum absolute value is max(|min|,
445+ // |max|)
446+ APInt minVal (FromBits, 0 ), maxVal (FromBits, 0 );
447+ bool isSigned = isa<SExtInst>(Op0);
448+ if (isSigned) {
449+ minVal = APInt::getSignedMinValue (FromBits);
450+ maxVal = APInt::getSignedMaxValue (FromBits);
451+ } else {
452+ minVal = APInt::getMinValue (FromBits);
453+ maxVal = APInt::getMaxValue (FromBits);
454+ }
455+ APInt absMin = minVal.abs ();
456+ APInt absMax = maxVal.abs ();
457+ APInt maxAbs = absMin.ugt (absMax) ? absMin : absMax;
458+ APInt maxSquare = maxAbs.zext (ToBits);
459+ maxSquare = maxSquare * maxSquare;
460+ // All bits above the highest set bit in maxSquare are known zero.
461+ unsigned MaxBit = maxSquare.isZero () ? 0 : maxSquare.logBase2 ();
462+ if (MaxBit + 1 < ToBits) {
463+ APInt KnownZeroMask =
464+ APInt::getHighBitsSet (ToBits, ToBits - (MaxBit + 1 ));
465+ Known.Zero |= KnownZeroMask;
466+ }
467+ }
468+ }
426469}
427470
428471void llvm::computeKnownBitsFromRangeMetadata (const MDNode &Ranges,
0 commit comments