Skip to content

Commit 71d6762

Browse files
authored
[InstCombine] Added pattern for recognising the construction of packed integers. (#147414)
This patch extends the instruction combiner to simplify the construction of a packed scalar integer from a vector type, such as:

```llvm
target datalayout = "e"

define i32 @src(<4 x i8> %v) {
  %v.0 = extractelement <4 x i8> %v, i32 0
  %z.0 = zext i8 %v.0 to i32
  %v.1 = extractelement <4 x i8> %v, i32 1
  %z.1 = zext i8 %v.1 to i32
  %s.1 = shl i32 %z.1, 8
  %x.1 = or i32 %z.0, %s.1
  %v.2 = extractelement <4 x i8> %v, i32 2
  %z.2 = zext i8 %v.2 to i32
  %s.2 = shl i32 %z.2, 16
  %x.2 = or i32 %x.1, %s.2
  %v.3 = extractelement <4 x i8> %v, i32 3
  %z.3 = zext i8 %v.3 to i32
  %s.3 = shl i32 %z.3, 24
  %x.3 = or i32 %x.2, %s.3
  ret i32 %x.3
}

; ===============

define i32 @tgt(<4 x i8> %v) {
  %x.3 = bitcast <4 x i8> %v to i32
  ret i32 %x.3
}
```

Alive2 proofs (little-endian): [YKdMeg](https://alive2.llvm.org/ce/z/YKdMeg)
Alive2 proofs (big-endian): [vU6iKc](https://alive2.llvm.org/ce/z/vU6iKc)
1 parent c1968fe commit 71d6762

File tree

2 files changed

+1080
-0
lines changed

2 files changed

+1080
-0
lines changed

llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,13 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "InstCombineInternal.h"
14+
#include "llvm/ADT/SmallBitVector.h"
1415
#include "llvm/Analysis/CmpInstAnalysis.h"
1516
#include "llvm/Analysis/FloatingPointPredicateUtils.h"
1617
#include "llvm/Analysis/InstructionSimplify.h"
1718
#include "llvm/IR/ConstantRange.h"
19+
#include "llvm/IR/DerivedTypes.h"
20+
#include "llvm/IR/Instructions.h"
1821
#include "llvm/IR/Intrinsics.h"
1922
#include "llvm/IR/PatternMatch.h"
2023
#include "llvm/Transforms/InstCombine/InstCombiner.h"
@@ -3589,6 +3592,154 @@ static Value *foldOrOfInversions(BinaryOperator &I,
35893592
return nullptr;
35903593
}
35913594

3595+
/// Match \p V as "shufflevector -> bitcast" or "extractelement -> zext -> shl"
3596+
/// patterns, which extract vector elements and pack them in the same relative
3597+
/// positions.
3598+
///
3599+
/// \p Vec is the underlying vector being extracted from.
3600+
/// \p Mask is a bitmask identifying which packed elements are obtained from the
3601+
/// vector.
3602+
/// \p VecOffset is the vector element corresponding to index 0 of the
3603+
/// mask.
3604+
static bool matchSubIntegerPackFromVector(Value *V, Value *&Vec,
3605+
int64_t &VecOffset,
3606+
SmallBitVector &Mask,
3607+
const DataLayout &DL) {
3608+
static const auto m_ConstShlOrSelf = [](const auto &Base, uint64_t &ShlAmt) {
3609+
ShlAmt = 0;
3610+
return m_CombineOr(m_Shl(Base, m_ConstantInt(ShlAmt)), Base);
3611+
};
3612+
3613+
// First try to match extractelement -> zext -> shl
3614+
uint64_t VecIdx, ShlAmt;
3615+
if (match(V, m_ConstShlOrSelf(m_ZExtOrSelf(m_ExtractElt(
3616+
m_Value(Vec), m_ConstantInt(VecIdx))),
3617+
ShlAmt))) {
3618+
auto *VecTy = dyn_cast<FixedVectorType>(Vec->getType());
3619+
if (!VecTy)
3620+
return false;
3621+
auto *EltTy = dyn_cast<IntegerType>(VecTy->getElementType());
3622+
if (!EltTy)
3623+
return false;
3624+
3625+
const unsigned EltBitWidth = EltTy->getBitWidth();
3626+
const unsigned TargetBitWidth = V->getType()->getIntegerBitWidth();
3627+
if (TargetBitWidth % EltBitWidth != 0 || ShlAmt % EltBitWidth != 0)
3628+
return false;
3629+
const unsigned TargetEltWidth = TargetBitWidth / EltBitWidth;
3630+
const unsigned ShlEltAmt = ShlAmt / EltBitWidth;
3631+
3632+
const unsigned MaskIdx =
3633+
DL.isLittleEndian() ? ShlEltAmt : TargetEltWidth - ShlEltAmt - 1;
3634+
3635+
VecOffset = static_cast<int64_t>(VecIdx) - static_cast<int64_t>(MaskIdx);
3636+
Mask.resize(TargetEltWidth);
3637+
Mask.set(MaskIdx);
3638+
return true;
3639+
}
3640+
3641+
// Now try to match a bitcasted subvector.
3642+
Instruction *SrcVecI;
3643+
if (!match(V, m_BitCast(m_Instruction(SrcVecI))))
3644+
return false;
3645+
3646+
auto *SrcTy = dyn_cast<FixedVectorType>(SrcVecI->getType());
3647+
if (!SrcTy)
3648+
return false;
3649+
3650+
Mask.resize(SrcTy->getNumElements());
3651+
3652+
// First check for a subvector obtained from a shufflevector.
3653+
if (isa<ShuffleVectorInst>(SrcVecI)) {
3654+
Constant *ConstVec;
3655+
ArrayRef<int> ShuffleMask;
3656+
if (!match(SrcVecI, m_Shuffle(m_Value(Vec), m_Constant(ConstVec),
3657+
m_Mask(ShuffleMask))))
3658+
return false;
3659+
3660+
auto *VecTy = dyn_cast<FixedVectorType>(Vec->getType());
3661+
if (!VecTy)
3662+
return false;
3663+
3664+
const unsigned NumVecElts = VecTy->getNumElements();
3665+
bool FoundVecOffset = false;
3666+
for (unsigned Idx = 0; Idx < ShuffleMask.size(); ++Idx) {
3667+
if (ShuffleMask[Idx] == PoisonMaskElem)
3668+
return false;
3669+
const unsigned ShuffleIdx = ShuffleMask[Idx];
3670+
if (ShuffleIdx >= NumVecElts) {
3671+
const unsigned ConstIdx = ShuffleIdx - NumVecElts;
3672+
auto *ConstElt =
3673+
dyn_cast<ConstantInt>(ConstVec->getAggregateElement(ConstIdx));
3674+
if (!ConstElt || !ConstElt->isNullValue())
3675+
return false;
3676+
continue;
3677+
}
3678+
3679+
if (FoundVecOffset) {
3680+
if (VecOffset + Idx != ShuffleIdx)
3681+
return false;
3682+
} else {
3683+
if (ShuffleIdx < Idx)
3684+
return false;
3685+
VecOffset = ShuffleIdx - Idx;
3686+
FoundVecOffset = true;
3687+
}
3688+
Mask.set(Idx);
3689+
}
3690+
return FoundVecOffset;
3691+
}
3692+
3693+
// Check for a subvector obtained as an (insertelement V, 0, idx)
3694+
uint64_t InsertIdx;
3695+
if (!match(SrcVecI,
3696+
m_InsertElt(m_Value(Vec), m_Zero(), m_ConstantInt(InsertIdx))))
3697+
return false;
3698+
3699+
auto *VecTy = dyn_cast<FixedVectorType>(Vec->getType());
3700+
if (!VecTy)
3701+
return false;
3702+
VecOffset = 0;
3703+
bool AlreadyInsertedMaskedElt = Mask.test(InsertIdx);
3704+
Mask.set();
3705+
if (!AlreadyInsertedMaskedElt)
3706+
Mask.reset(InsertIdx);
3707+
return true;
3708+
}
3709+
3710+
/// Try to fold the join of two scalar integers whose contents are packed
3711+
/// elements of the same vector.
3712+
static Instruction *foldIntegerPackFromVector(Instruction &I,
3713+
InstCombiner::BuilderTy &Builder,
3714+
const DataLayout &DL) {
3715+
assert(I.getOpcode() == Instruction::Or);
3716+
Value *LhsVec, *RhsVec;
3717+
int64_t LhsVecOffset, RhsVecOffset;
3718+
SmallBitVector Mask;
3719+
if (!matchSubIntegerPackFromVector(I.getOperand(0), LhsVec, LhsVecOffset,
3720+
Mask, DL))
3721+
return nullptr;
3722+
if (!matchSubIntegerPackFromVector(I.getOperand(1), RhsVec, RhsVecOffset,
3723+
Mask, DL))
3724+
return nullptr;
3725+
if (LhsVec != RhsVec || LhsVecOffset != RhsVecOffset)
3726+
return nullptr;
3727+
3728+
// Convert into shufflevector -> bitcast;
3729+
const unsigned ZeroVecIdx =
3730+
cast<FixedVectorType>(LhsVec->getType())->getNumElements();
3731+
SmallVector<int> ShuffleMask(Mask.size(), ZeroVecIdx);
3732+
for (unsigned Idx : Mask.set_bits()) {
3733+
assert(LhsVecOffset + Idx >= 0);
3734+
ShuffleMask[Idx] = LhsVecOffset + Idx;
3735+
}
3736+
3737+
Value *MaskedVec = Builder.CreateShuffleVector(
3738+
LhsVec, Constant::getNullValue(LhsVec->getType()), ShuffleMask,
3739+
I.getName() + ".v");
3740+
return CastInst::Create(Instruction::BitCast, MaskedVec, I.getType());
3741+
}
3742+
35923743
// A decomposition of ((X & Mask) * Factor). The NUW / NSW bools
35933744
// track these properities for preservation. Note that we can decompose
35943745
// equivalent select form of this expression (e.g. (!(X & Mask) ? 0 : Mask *
@@ -3766,6 +3917,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
37663917
if (Instruction *X = foldComplexAndOrPatterns(I, Builder))
37673918
return X;
37683919

3920+
if (Instruction *X = foldIntegerPackFromVector(I, Builder, DL))
3921+
return X;
3922+
37693923
// (A & B) | (C & D) -> A ^ D where A == ~C && B == ~D
37703924
// (A & B) | (C & D) -> A ^ C where A == ~D && B == ~C
37713925
if (Value *V = foldOrOfInversions(I, Builder))

0 commit comments

Comments
 (0)