Skip to content

Commit 98e536c

Browse files
committed
Added additional logic to fold (x * x) masks for out-of-range bits
1 parent 59fdd97 commit 98e536c

File tree

2 files changed

+76
-0
lines changed

2 files changed

+76
-0
lines changed

llvm/lib/Analysis/ValueTracking.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,49 @@ static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
423423
Known.makeNonNegative();
424424
else if (isKnownNegative && !Known.isNonNegative())
425425
Known.makeNegative();
426+
427+
// Additional logic: If both operands are the same sign- or zero-extended
428+
// value from a small integer, and the multiplication is (sext x) * (sext x)
429+
// or (zext x) * (zext x), then the result cannot set bits above the maximum
430+
// possible square. This allows InstCombine and other passes to fold (x * x) &
431+
// (1 << N) to 0 when N is out of range.
432+
using namespace PatternMatch;
433+
const Value *A = nullptr;
434+
// Only handle the case where both operands are the same extension of the same
435+
// value.
436+
if ((match(Op0, m_SExt(m_Value(A))) && match(Op1, m_SExt(m_Specific(A)))) ||
437+
(match(Op0, m_ZExt(m_Value(A))) && match(Op1, m_ZExt(m_Specific(A))))) {
438+
Type *FromTy = A->getType();
439+
Type *ToTy = Op0->getType();
440+
if (FromTy->isIntegerTy() && ToTy->isIntegerTy() &&
441+
FromTy->getScalarSizeInBits() < ToTy->getScalarSizeInBits()) {
442+
unsigned FromBits = FromTy->getScalarSizeInBits();
443+
unsigned ToBits = ToTy->getScalarSizeInBits();
444+
// For both signed and unsigned, the maximum absolute value is max(|min|,
445+
// |max|)
446+
APInt minVal(FromBits, 0), maxVal(FromBits, 0);
447+
bool isSigned = isa<SExtInst>(Op0);
448+
if (isSigned) {
449+
minVal = APInt::getSignedMinValue(FromBits);
450+
maxVal = APInt::getSignedMaxValue(FromBits);
451+
} else {
452+
minVal = APInt::getMinValue(FromBits);
453+
maxVal = APInt::getMaxValue(FromBits);
454+
}
455+
APInt absMin = minVal.abs();
456+
APInt absMax = maxVal.abs();
457+
APInt maxAbs = absMin.ugt(absMax) ? absMin : absMax;
458+
APInt maxSquare = maxAbs.zext(ToBits);
459+
maxSquare = maxSquare * maxSquare;
460+
// All bits above the highest set bit in maxSquare are known zero.
461+
unsigned MaxBit = maxSquare.isZero() ? 0 : maxSquare.logBase2();
462+
if (MaxBit + 1 < ToBits) {
463+
APInt KnownZeroMask =
464+
APInt::getHighBitsSet(ToBits, ToBits - (MaxBit + 1));
465+
Known.Zero |= KnownZeroMask;
466+
}
467+
}
468+
}
426469
}
427470

428471
void llvm::computeKnownBitsFromRangeMetadata(const MDNode &Ranges,

llvm/test/Analysis/ValueTracking/known-bits.ll

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,36 @@ define i1 @vec_reverse_known_bits_demanded_fail(<4 x i8> %xx) {
4949
%r = icmp slt i8 %ele, 0
5050
ret i1 %r
5151
}
52+
53+
; Test known bits for (sext i8 x) * (sext i8 x)
54+
; RUN: opt -passes=instcombine < %s -S | FileCheck %s --check-prefix=SEXT_SQUARE
55+
56+
define i1 @sext_square_bit31(i8 %x) {
57+
; SEXT_SQUARE-LABEL: @sext_square_bit31(
58+
; SEXT_SQUARE-NEXT: ret i1 false
59+
%sx = sext i8 %x to i32
60+
%mul = mul nsw i32 %sx, %sx
61+
%and = and i32 %mul, 2147483648 ; 1 << 31
62+
%cmp = icmp ne i32 %and, 0
63+
ret i1 %cmp
64+
}
65+
66+
define i1 @sext_square_bit30(i8 %x) {
67+
; SEXT_SQUARE-LABEL: @sext_square_bit30(
68+
; SEXT_SQUARE-NEXT: ret i1 false
69+
%sx = sext i8 %x to i32
70+
%mul = mul nsw i32 %sx, %sx
71+
%and = and i32 %mul, 1073741824 ; 1 << 30
72+
%cmp = icmp ne i32 %and, 0
73+
ret i1 %cmp
74+
}
75+
76+
define i1 @sext_square_bit14(i8 %x) {
77+
; SEXT_SQUARE-LABEL: @sext_square_bit14(
78+
; SEXT_SQUARE-NOT: ret i1 false
79+
%sx = sext i8 %x to i32
80+
%mul = mul nsw i32 %sx, %sx
81+
%and = and i32 %mul, 16384 ; 1 << 14
82+
%cmp = icmp ne i32 %and, 0
83+
ret i1 %cmp
84+
}

0 commit comments

Comments
 (0)