Commit d381ca0

Cherry-picking 71d6762
1 parent 0e5718a commit d381ca0

File tree

2 files changed: +1080 -0 lines changed


llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp

Lines changed: 154 additions & 0 deletions
@@ -11,9 +11,12 @@
 //===----------------------------------------------------------------------===//
 
 #include "InstCombineInternal.h"
+#include "llvm/ADT/SmallBitVector.h"
 #include "llvm/Analysis/CmpInstAnalysis.h"
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/IR/ConstantRange.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/Transforms/InstCombine/InstCombiner.h"
@@ -3649,6 +3652,154 @@ Value *InstCombinerImpl::reassociateDisjointOr(Value *LHS, Value *RHS) {
   return nullptr;
 }
 
+/// Match \p V as "shufflevector -> bitcast" or "extractelement -> zext -> shl"
+/// patterns, which extract vector elements and pack them in the same relative
+/// positions.
+///
+/// \p Vec is the underlying vector being extracted from.
+/// \p Mask is a bitmask identifying which packed elements are obtained from the
+/// vector.
+/// \p VecOffset is the vector element corresponding to index 0 of the
+/// mask.
+static bool matchSubIntegerPackFromVector(Value *V, Value *&Vec,
+                                          int64_t &VecOffset,
+                                          SmallBitVector &Mask,
+                                          const DataLayout &DL) {
+  static const auto m_ConstShlOrSelf = [](const auto &Base, uint64_t &ShlAmt) {
+    ShlAmt = 0;
+    return m_CombineOr(m_Shl(Base, m_ConstantInt(ShlAmt)), Base);
+  };
+
+  // First try to match extractelement -> zext -> shl
+  uint64_t VecIdx, ShlAmt;
+  if (match(V, m_ConstShlOrSelf(m_ZExtOrSelf(m_ExtractElt(
+                                    m_Value(Vec), m_ConstantInt(VecIdx))),
+                                ShlAmt))) {
+    auto *VecTy = dyn_cast<FixedVectorType>(Vec->getType());
+    if (!VecTy)
+      return false;
+    auto *EltTy = dyn_cast<IntegerType>(VecTy->getElementType());
+    if (!EltTy)
+      return false;
+
+    const unsigned EltBitWidth = EltTy->getBitWidth();
+    const unsigned TargetBitWidth = V->getType()->getIntegerBitWidth();
+    if (TargetBitWidth % EltBitWidth != 0 || ShlAmt % EltBitWidth != 0)
+      return false;
+    const unsigned TargetEltWidth = TargetBitWidth / EltBitWidth;
+    const unsigned ShlEltAmt = ShlAmt / EltBitWidth;
+
+    const unsigned MaskIdx =
+        DL.isLittleEndian() ? ShlEltAmt : TargetEltWidth - ShlEltAmt - 1;
+
+    VecOffset = static_cast<int64_t>(VecIdx) - static_cast<int64_t>(MaskIdx);
+    Mask.resize(TargetEltWidth);
+    Mask.set(MaskIdx);
+    return true;
+  }
+
+  // Now try to match a bitcasted subvector.
+  Instruction *SrcVecI;
+  if (!match(V, m_BitCast(m_Instruction(SrcVecI))))
+    return false;
+
+  auto *SrcTy = dyn_cast<FixedVectorType>(SrcVecI->getType());
+  if (!SrcTy)
+    return false;
+
+  Mask.resize(SrcTy->getNumElements());
+
+  // First check for a subvector obtained from a shufflevector.
+  if (isa<ShuffleVectorInst>(SrcVecI)) {
+    Constant *ConstVec;
+    ArrayRef<int> ShuffleMask;
+    if (!match(SrcVecI, m_Shuffle(m_Value(Vec), m_Constant(ConstVec),
+                                  m_Mask(ShuffleMask))))
+      return false;
+
+    auto *VecTy = dyn_cast<FixedVectorType>(Vec->getType());
+    if (!VecTy)
+      return false;
+
+    const unsigned NumVecElts = VecTy->getNumElements();
+    bool FoundVecOffset = false;
+    for (unsigned Idx = 0; Idx < ShuffleMask.size(); ++Idx) {
+      if (ShuffleMask[Idx] == PoisonMaskElem)
+        return false;
+      const unsigned ShuffleIdx = ShuffleMask[Idx];
+      if (ShuffleIdx >= NumVecElts) {
+        const unsigned ConstIdx = ShuffleIdx - NumVecElts;
+        auto *ConstElt =
+            dyn_cast<ConstantInt>(ConstVec->getAggregateElement(ConstIdx));
+        if (!ConstElt || !ConstElt->isNullValue())
+          return false;
+        continue;
+      }
+
+      if (FoundVecOffset) {
+        if (VecOffset + Idx != ShuffleIdx)
+          return false;
+      } else {
+        if (ShuffleIdx < Idx)
+          return false;
+        VecOffset = ShuffleIdx - Idx;
+        FoundVecOffset = true;
+      }
+      Mask.set(Idx);
+    }
+    return FoundVecOffset;
+  }
+
+  // Check for a subvector obtained as an (insertelement V, 0, idx)
+  uint64_t InsertIdx;
+  if (!match(SrcVecI,
+             m_InsertElt(m_Value(Vec), m_Zero(), m_ConstantInt(InsertIdx))))
+    return false;
+
+  auto *VecTy = dyn_cast<FixedVectorType>(Vec->getType());
+  if (!VecTy)
+    return false;
+  VecOffset = 0;
+  bool AlreadyInsertedMaskedElt = Mask.test(InsertIdx);
+  Mask.set();
+  if (!AlreadyInsertedMaskedElt)
+    Mask.reset(InsertIdx);
+  return true;
+}
+
+/// Try to fold the join of two scalar integers whose contents are packed
+/// elements of the same vector.
+static Instruction *foldIntegerPackFromVector(Instruction &I,
+                                              InstCombiner::BuilderTy &Builder,
+                                              const DataLayout &DL) {
+  assert(I.getOpcode() == Instruction::Or);
+  Value *LhsVec, *RhsVec;
+  int64_t LhsVecOffset, RhsVecOffset;
+  SmallBitVector Mask;
+  if (!matchSubIntegerPackFromVector(I.getOperand(0), LhsVec, LhsVecOffset,
+                                     Mask, DL))
+    return nullptr;
+  if (!matchSubIntegerPackFromVector(I.getOperand(1), RhsVec, RhsVecOffset,
+                                     Mask, DL))
+    return nullptr;
+  if (LhsVec != RhsVec || LhsVecOffset != RhsVecOffset)
+    return nullptr;
+
+  // Convert into shufflevector -> bitcast;
+  const unsigned ZeroVecIdx =
+      cast<FixedVectorType>(LhsVec->getType())->getNumElements();
+  SmallVector<int> ShuffleMask(Mask.size(), ZeroVecIdx);
+  for (unsigned Idx : Mask.set_bits()) {
+    assert(LhsVecOffset + Idx >= 0);
+    ShuffleMask[Idx] = LhsVecOffset + Idx;
+  }
+
+  Value *MaskedVec = Builder.CreateShuffleVector(
+      LhsVec, Constant::getNullValue(LhsVec->getType()), ShuffleMask,
+      I.getName() + ".v");
+  return CastInst::Create(Instruction::BitCast, MaskedVec, I.getType());
+}
+
 // FIXME: We use commutative matchers (m_c_*) for some, but not all, matches
 // here. We should standardize that construct where it is needed or choose some
 // other way to ensure that commutated variants of patterns are not missed.
@@ -3678,6 +3829,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
   if (Instruction *X = foldComplexAndOrPatterns(I, Builder))
     return X;
 
+  if (Instruction *X = foldIntegerPackFromVector(I, Builder, DL))
+    return X;
+
   // (A & B) | (C & D) -> A ^ D where A == ~C && B == ~D
   // (A & B) | (C & D) -> A ^ C where A == ~D && B == ~C
   if (Value *V = foldOrOfInversions(I, Builder))
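
For illustration, here is a minimal hand-written sketch of the kind of IR the new fold targets, assuming a little-endian data layout; the function and value names are invented for this example and are not taken from the commit or its tests. Both operands of the 'or' pack lanes of the same <2 x i8> vector into an i16:

  define i16 @pack_v2i8(<2 x i8> %v) {
    %e0 = extractelement <2 x i8> %v, i64 0
    %e1 = extractelement <2 x i8> %v, i64 1
    %z0 = zext i8 %e0 to i16
    %z1 = zext i8 %e1 to i16
    %hi = shl i16 %z1, 8
    %r = or i16 %z0, %hi
    ret i16 %r
  }

Each 'or' operand matches matchSubIntegerPackFromVector against the same vector (%v) with the same offset (0), so foldIntegerPackFromVector would rewrite the 'or' into roughly:

  %r.v = shufflevector <2 x i8> %v, <2 x i8> zeroinitializer, <2 x i32> <i32 0, i32 1>
  %r = bitcast <2 x i8> %r.v to i16

that is, a shuffle that keeps the masked lanes (lanes not covered by the mask pull zeros from the second operand) followed by a bitcast to the scalar type.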
