Skip to content

Commit 6d81fef

Browse files
committed
Moved logic into separate function.
1 parent 73598e0 commit 6d81fef

File tree

1 file changed

+58
-44
lines changed

1 file changed

+58
-44
lines changed

llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp

Lines changed: 58 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,58 @@ static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo,
6060
return true;
6161
}
6262

63+
/// Let N = 2 * M.
64+
/// Given an N-bit integer representing a pack of two M-bit integers,
65+
/// we can select one of the packed integers by right-shifting by either
66+
/// zero or M (which is the most straightforward to check if M is a power
67+
/// of 2), and then isolating the lower M bits. In this case, we can
68+
/// represent the shift as a select on whether the shr amount is nonzero.
69+
static Value *simplifyShiftSelectingPackedElement(Instruction *I,
70+
const APInt &DemandedMask,
71+
InstCombinerImpl &IC,
72+
unsigned Depth) {
73+
assert(I->getOpcode() == Instruction::LShr &&
74+
"Only lshr instruction supported");
75+
76+
uint64_t ShlAmt;
77+
Value *Upper, *Lower;
78+
if (!match(I->getOperand(0),
79+
m_OneUse(m_c_DisjointOr(
80+
m_OneUse(m_Shl(m_Value(Upper), m_ConstantInt(ShlAmt))),
81+
m_Value(Lower)))))
82+
return nullptr;
83+
84+
if (!isPowerOf2_64(ShlAmt))
85+
return nullptr;
86+
87+
const uint64_t DemandedBitWidth = DemandedMask.getActiveBits();
88+
if (DemandedBitWidth > ShlAmt)
89+
return nullptr;
90+
91+
// Check that upper demanded bits are not lost from lshift.
92+
if (Upper->getType()->getScalarSizeInBits() < ShlAmt + DemandedBitWidth)
93+
return nullptr;
94+
95+
KnownBits KnownLowerBits = IC.computeKnownBits(Lower, I, Depth);
96+
if (!KnownLowerBits.getMaxValue().isIntN(ShlAmt))
97+
return nullptr;
98+
99+
Value *ShrAmt = I->getOperand(1);
100+
KnownBits KnownShrBits = IC.computeKnownBits(ShrAmt, I, Depth);
101+
102+
// Verify that ShrAmt is either exactly ShlAmt (which is a power of 2) or
103+
// zero.
104+
if (~KnownShrBits.Zero != ShlAmt)
105+
return nullptr;
106+
107+
Value *ShrAmtZ =
108+
IC.Builder.CreateICmpEQ(ShrAmt, Constant::getNullValue(ShrAmt->getType()),
109+
ShrAmt->getName() + ".z");
110+
Value *Select = IC.Builder.CreateSelect(ShrAmtZ, Lower, Upper);
111+
Select->takeName(I);
112+
return Select;
113+
}
114+
63115
/// Returns the bitwidth of the given scalar or pointer type. For vector types,
64116
/// returns the element type's bitwidth.
65117
static unsigned getBitWidth(Type *Ty, const DataLayout &DL) {
@@ -798,51 +850,13 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Instruction *I,
798850
Known >>= ShiftAmt;
799851
if (ShiftAmt)
800852
Known.Zero.setHighBits(ShiftAmt); // high bits known zero.
801-
} else {
802-
llvm::computeKnownBits(I, Known, Q, Depth);
803-
804-
// Let N = 2 * M.
805-
// Given an N-bit integer representing a pack of two M-bit integers,
806-
// we can select one of the packed integers by right-shifting by either
807-
// zero or M (which is the most straightforward to check if M is a power
808-
// of 2), and then isolating the lower M bits. In this case, we can
809-
// represent the shift as a select on whether the shr amount is nonzero.
810-
uint64_t ShlAmt;
811-
Value *Upper, *Lower;
812-
if (!match(I->getOperand(0),
813-
m_OneUse(m_c_DisjointOr(
814-
m_OneUse(m_Shl(m_Value(Upper), m_ConstantInt(ShlAmt))),
815-
m_Value(Lower)))))
816-
break;
817-
if (!isPowerOf2_64(ShlAmt))
818-
break;
819-
820-
const uint64_t DemandedBitWidth = DemandedMask.getActiveBits();
821-
if (DemandedBitWidth > ShlAmt)
822-
break;
823-
824-
// Check that upper demanded bits are not lost from lshift.
825-
if (Upper->getType()->getScalarSizeInBits() < ShlAmt + DemandedBitWidth)
826-
break;
827-
828-
KnownBits KnownLowerBits = computeKnownBits(Lower, I, Depth);
829-
if (!KnownLowerBits.getMaxValue().isIntN(ShlAmt))
830-
break;
831-
832-
Value *ShrAmt = I->getOperand(1);
833-
KnownBits KnownShrBits = computeKnownBits(ShrAmt, I, Depth);
834-
// Verify that ShrAmt is either exactly ShlAmt (which is a power of 2) or
835-
// zero.
836-
if (~KnownShrBits.Zero != ShlAmt)
837-
break;
838-
839-
Value *ShrAmtZ = Builder.CreateICmpEQ(
840-
ShrAmt, Constant::getNullValue(ShrAmt->getType()),
841-
ShrAmt->getName() + ".z");
842-
Value *Select = Builder.CreateSelect(ShrAmtZ, Lower, Upper);
843-
Select->takeName(I);
844-
return Select;
853+
break;
845854
}
855+
if (Value *V =
856+
simplifyShiftSelectingPackedElement(I, DemandedMask, *this, Depth))
857+
return V;
858+
859+
llvm::computeKnownBits(I, Known, Q, Depth);
846860
break;
847861
}
848862
case Instruction::AShr: {

0 commit comments

Comments
 (0)