Commit 2c69a9e

[InstCombine] Added pattern for recognising the construction of packed integers. (llvm#3581)
This patch extends the instruction combiner to simplify the construction of a packed scalar integer from a vector type, such as:

```llvm
target datalayout = "e"

define i32 @src(<4 x i8> %v) {
  %v.0 = extractelement <4 x i8> %v, i32 0
  %z.0 = zext i8 %v.0 to i32
  %v.1 = extractelement <4 x i8> %v, i32 1
  %z.1 = zext i8 %v.1 to i32
  %s.1 = shl i32 %z.1, 8
  %x.1 = or i32 %z.0, %s.1
  %v.2 = extractelement <4 x i8> %v, i32 2
  %z.2 = zext i8 %v.2 to i32
  %s.2 = shl i32 %z.2, 16
  %x.2 = or i32 %x.1, %s.2
  %v.3 = extractelement <4 x i8> %v, i32 3
  %z.3 = zext i8 %v.3 to i32
  %s.3 = shl i32 %z.3, 24
  %x.3 = or i32 %x.2, %s.3
  ret i32 %x.3
}

; ===============

define i32 @tgt(<4 x i8> %v) {
  %x.3 = bitcast <4 x i8> %v to i32
  ret i32 %x.3
}
```

Alive2 proofs (little-endian): [YKdMeg](https://alive2.llvm.org/ce/z/YKdMeg)
Alive2 proofs (big-endian): [vU6iKc](https://alive2.llvm.org/ce/z/vU6iKc)
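For intuition on the endianness handling, here is a sketch of the corresponding big-endian pattern (illustrative only, not taken from the commit or its tests; function names are hypothetical). On a big-endian target, element 0 of the vector lands in the most significant byte of the packed integer, so the shift amounts are mirrored and the same fold applies:

```llvm
target datalayout = "E"

define i32 @src.be(<4 x i8> %v) {
  %v.0 = extractelement <4 x i8> %v, i32 0
  %z.0 = zext i8 %v.0 to i32
  %s.0 = shl i32 %z.0, 24
  %v.1 = extractelement <4 x i8> %v, i32 1
  %z.1 = zext i8 %v.1 to i32
  %s.1 = shl i32 %z.1, 16
  %x.1 = or i32 %s.0, %s.1
  %v.2 = extractelement <4 x i8> %v, i32 2
  %z.2 = zext i8 %v.2 to i32
  %s.2 = shl i32 %z.2, 8
  %x.2 = or i32 %x.1, %s.2
  %v.3 = extractelement <4 x i8> %v, i32 3
  %z.3 = zext i8 %v.3 to i32
  %x.3 = or i32 %x.2, %z.3
  ret i32 %x.3
}

; ---> should likewise fold to:

define i32 @tgt.be(<4 x i8> %v) {
  %x.3 = bitcast <4 x i8> %v to i32
  ret i32 %x.3
}
```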
2 parents: 7757f13 + 185ddcf

File tree

2 files changed: +1080 −0 lines changed


llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp

Lines changed: 154 additions & 0 deletions
```diff
@@ -11,9 +11,12 @@
 //===----------------------------------------------------------------------===//
 
 #include "InstCombineInternal.h"
+#include "llvm/ADT/SmallBitVector.h"
 #include "llvm/Analysis/CmpInstAnalysis.h"
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/IR/ConstantRange.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/Transforms/InstCombine/InstCombiner.h"
```
```diff
@@ -3546,6 +3549,154 @@ static Value *foldOrOfInversions(BinaryOperator &I,
   return nullptr;
 }
 
+/// Match \p V as "shufflevector -> bitcast" or "extractelement -> zext -> shl"
+/// patterns, which extract vector elements and pack them in the same relative
+/// positions.
+///
+/// \p Vec is the underlying vector being extracted from.
+/// \p Mask is a bitmask identifying which packed elements are obtained from the
+/// vector.
+/// \p VecOffset is the vector element corresponding to index 0 of the
+/// mask.
+static bool matchSubIntegerPackFromVector(Value *V, Value *&Vec,
+                                          int64_t &VecOffset,
+                                          SmallBitVector &Mask,
+                                          const DataLayout &DL) {
+  static const auto m_ConstShlOrSelf = [](const auto &Base, uint64_t &ShlAmt) {
+    ShlAmt = 0;
+    return m_CombineOr(m_Shl(Base, m_ConstantInt(ShlAmt)), Base);
+  };
+
+  // First try to match extractelement -> zext -> shl
+  uint64_t VecIdx, ShlAmt;
+  if (match(V, m_ConstShlOrSelf(m_ZExtOrSelf(m_ExtractElt(
+                                    m_Value(Vec), m_ConstantInt(VecIdx))),
+                                ShlAmt))) {
+    auto *VecTy = dyn_cast<FixedVectorType>(Vec->getType());
+    if (!VecTy)
+      return false;
+    auto *EltTy = dyn_cast<IntegerType>(VecTy->getElementType());
+    if (!EltTy)
+      return false;
+
+    const unsigned EltBitWidth = EltTy->getBitWidth();
+    const unsigned TargetBitWidth = V->getType()->getIntegerBitWidth();
+    if (TargetBitWidth % EltBitWidth != 0 || ShlAmt % EltBitWidth != 0)
+      return false;
+    const unsigned TargetEltWidth = TargetBitWidth / EltBitWidth;
+    const unsigned ShlEltAmt = ShlAmt / EltBitWidth;
+
+    const unsigned MaskIdx =
+        DL.isLittleEndian() ? ShlEltAmt : TargetEltWidth - ShlEltAmt - 1;
+
+    VecOffset = static_cast<int64_t>(VecIdx) - static_cast<int64_t>(MaskIdx);
+    Mask.resize(TargetEltWidth);
+    Mask.set(MaskIdx);
+    return true;
+  }
+
+  // Now try to match a bitcasted subvector.
+  Instruction *SrcVecI;
+  if (!match(V, m_BitCast(m_Instruction(SrcVecI))))
+    return false;
+
+  auto *SrcTy = dyn_cast<FixedVectorType>(SrcVecI->getType());
+  if (!SrcTy)
+    return false;
+
+  Mask.resize(SrcTy->getNumElements());
+
+  // First check for a subvector obtained from a shufflevector.
+  if (isa<ShuffleVectorInst>(SrcVecI)) {
+    Constant *ConstVec;
+    ArrayRef<int> ShuffleMask;
+    if (!match(SrcVecI, m_Shuffle(m_Value(Vec), m_Constant(ConstVec),
+                                  m_Mask(ShuffleMask))))
+      return false;
+
+    auto *VecTy = dyn_cast<FixedVectorType>(Vec->getType());
+    if (!VecTy)
+      return false;
+
+    const unsigned NumVecElts = VecTy->getNumElements();
+    bool FoundVecOffset = false;
+    for (unsigned Idx = 0; Idx < ShuffleMask.size(); ++Idx) {
+      if (ShuffleMask[Idx] == PoisonMaskElem)
+        return false;
+      const unsigned ShuffleIdx = ShuffleMask[Idx];
+      if (ShuffleIdx >= NumVecElts) {
+        const unsigned ConstIdx = ShuffleIdx - NumVecElts;
+        auto *ConstElt =
+            dyn_cast<ConstantInt>(ConstVec->getAggregateElement(ConstIdx));
+        if (!ConstElt || !ConstElt->isNullValue())
+          return false;
+        continue;
+      }
+
+      if (FoundVecOffset) {
+        if (VecOffset + Idx != ShuffleIdx)
+          return false;
+      } else {
+        if (ShuffleIdx < Idx)
+          return false;
+        VecOffset = ShuffleIdx - Idx;
+        FoundVecOffset = true;
+      }
+      Mask.set(Idx);
+    }
+    return FoundVecOffset;
+  }
+
+  // Check for a subvector obtained as an (insertelement V, 0, idx)
+  uint64_t InsertIdx;
+  if (!match(SrcVecI,
+             m_InsertElt(m_Value(Vec), m_Zero(), m_ConstantInt(InsertIdx))))
+    return false;
+
+  auto *VecTy = dyn_cast<FixedVectorType>(Vec->getType());
+  if (!VecTy)
+    return false;
+  VecOffset = 0;
+  bool AlreadyInsertedMaskedElt = Mask.test(InsertIdx);
+  Mask.set();
+  if (!AlreadyInsertedMaskedElt)
+    Mask.reset(InsertIdx);
+  return true;
+}
+
+/// Try to fold the join of two scalar integers whose contents are packed
+/// elements of the same vector.
+static Instruction *foldIntegerPackFromVector(Instruction &I,
+                                              InstCombiner::BuilderTy &Builder,
+                                              const DataLayout &DL) {
+  assert(I.getOpcode() == Instruction::Or);
+  Value *LhsVec, *RhsVec;
+  int64_t LhsVecOffset, RhsVecOffset;
+  SmallBitVector Mask;
+  if (!matchSubIntegerPackFromVector(I.getOperand(0), LhsVec, LhsVecOffset,
+                                     Mask, DL))
+    return nullptr;
+  if (!matchSubIntegerPackFromVector(I.getOperand(1), RhsVec, RhsVecOffset,
+                                     Mask, DL))
+    return nullptr;
+  if (LhsVec != RhsVec || LhsVecOffset != RhsVecOffset)
+    return nullptr;
+
+  // Convert into shufflevector -> bitcast;
+  const unsigned ZeroVecIdx =
+      cast<FixedVectorType>(LhsVec->getType())->getNumElements();
+  SmallVector<int> ShuffleMask(Mask.size(), ZeroVecIdx);
+  for (unsigned Idx : Mask.set_bits()) {
+    assert(LhsVecOffset + Idx >= 0);
+    ShuffleMask[Idx] = LhsVecOffset + Idx;
+  }
+
+  Value *MaskedVec = Builder.CreateShuffleVector(
+      LhsVec, Constant::getNullValue(LhsVec->getType()), ShuffleMask,
+      I.getName() + ".v");
+  return CastInst::Create(Instruction::BitCast, MaskedVec, I.getType());
+}
+
 // FIXME: We use commutative matchers (m_c_*) for some, but not all, matches
 // here. We should standardize that construct where it is needed or choose some
 // other way to ensure that commutated variants of patterns are not missed.
```
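To make the matcher's bookkeeping concrete, here is a sketch of a partial pack and the shape the fold above should produce for it (an illustrative example assuming a little-endian datalayout; the function names are hypothetical and this is not taken from the patch's tests). Matching the final `or`, each operand sets one bit of `Mask` (bits 0 and 1) with `VecOffset` 0, and the rewrite emits a shuffle that takes the unclaimed lanes from the zero vector (`ZeroVecIdx == 4`):

```llvm
target datalayout = "e"

; Packs only lanes 0 and 1; the upper two bytes of the result are zero.
define i32 @partial(<4 x i8> %v) {
  %v.0 = extractelement <4 x i8> %v, i32 0
  %z.0 = zext i8 %v.0 to i32
  %v.1 = extractelement <4 x i8> %v, i32 1
  %z.1 = zext i8 %v.1 to i32
  %s.1 = shl i32 %z.1, 8
  %x.1 = or i32 %z.0, %s.1
  ret i32 %x.1
}

; Expected shape after the fold: the shuffle zero-fills lanes 2 and 3,
; then the whole vector is bitcast to the scalar type.
define i32 @partial.folded(<4 x i8> %v) {
  %x.1.v = shufflevector <4 x i8> %v, <4 x i8> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 4, i32 4>
  %x.1 = bitcast <4 x i8> %x.1.v to i32
  ret i32 %x.1
}
```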
```diff
@@ -3575,6 +3726,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
   if (Instruction *X = foldComplexAndOrPatterns(I, Builder))
     return X;
 
+  if (Instruction *X = foldIntegerPackFromVector(I, Builder, DL))
+    return X;
+
   // (A & B) | (C & D) -> A ^ D where A == ~C && B == ~D
   // (A & B) | (C & D) -> A ^ C where A == ~D && B == ~C
   if (Value *V = foldOrOfInversions(I, Builder))
```
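Because `visitOr` applies the fold to every `or`, an inner `or` of a wider pack can be rewritten first into the shufflevector/bitcast form, and the bitcast path of the matcher then lets an outer `or` absorb it. A hypothetical shape exercising that path (again illustrative, little-endian assumed, names invented for this sketch):

```llvm
target datalayout = "e"

define i32 @mixed(<4 x i8> %v) {
  ; Lanes 0-2 already packed as a zero-filled shuffle + bitcast, e.g. as
  ; produced by an earlier application of the fold.
  %sub = shufflevector <4 x i8> %v, <4 x i8> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 4>
  %lo = bitcast <4 x i8> %sub to i32
  ; Lane 3 packed the scalar way.
  %v.3 = extractelement <4 x i8> %v, i32 3
  %z.3 = zext i8 %v.3 to i32
  %s.3 = shl i32 %z.3, 24
  ; Both operands match with the same Vec and VecOffset and complementary
  ; masks, so this chain should reduce to a plain bitcast of %v (the
  ; resulting identity shuffle is cleaned up by later combines).
  %x = or i32 %lo, %s.3
  ret i32 %x
}
```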
