Skip to content

Commit 76450c9

Browse files
author
Leon Clark
committed
Address comments.
1 parent 016e9a5 commit 76450c9

File tree

1 file changed

+30
-34
lines changed

1 file changed

+30
-34
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 30 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3701,37 +3701,35 @@ bool VectorCombine::shrinkLoadForShuffles(Instruction &I) {
37013701
if (!OldLoad || !OldLoad->isSimple())
37023702
return false;
37033703

3704-
auto *VecTy = dyn_cast<FixedVectorType>(OldLoad->getType());
3705-
if (!VecTy)
3704+
auto *OldLoadTy = dyn_cast<FixedVectorType>(OldLoad->getType());
3705+
if (!OldLoadTy)
37063706
return false;
37073707

3708+
unsigned const OldNumElements = OldLoadTy->getNumElements();
3709+
37083710
// Search all uses of load. If all uses are shufflevector instructions, and
37093711
// the second operands are all poison values, find the minimum and maximum
37103712
// indices of the vector elements referenced by all shuffle masks.
37113713
// Otherwise return `std::nullopt`.
37123714
using IndexRange = std::pair<int, int>;
37133715
auto GetIndexRangeInShuffles = [&]() -> std::optional<IndexRange> {
3714-
IndexRange OutputRange = IndexRange(VecTy->getNumElements(), -1);
3716+
IndexRange OutputRange = IndexRange(OldNumElements, -1);
37153717
for (auto &Use : I.uses()) {
37163718
// Ensure all uses match the required pattern.
37173719
User *Shuffle = Use.getUser();
3718-
Value *Op0 = nullptr;
37193720
ArrayRef<int> Mask;
37203721

3721-
if (!match(Shuffle, m_Shuffle(m_Value(Op0), m_Undef(), m_Mask(Mask))))
3722+
if (!match(Shuffle,
3723+
m_Shuffle(m_Specific(OldLoad), m_Undef(), m_Mask(Mask))))
37223724
return std::nullopt;
37233725

37243726
// Ignore shufflevector instructions that have no uses.
37253727
if (Shuffle->use_empty())
37263728
continue;
37273729

37283730
// Find the min and max indices used by the shufflevector instruction.
3729-
FixedVectorType *Op0Ty = cast<FixedVectorType>(Op0->getType());
3730-
int NumElems = static_cast<int>(Op0Ty->getNumElements());
3731-
37323731
for (int Index : Mask) {
3733-
if (Index >= 0) {
3734-
Index %= NumElems;
3732+
if (Index >= 0 && Index < static_cast<int>(OldNumElements)) {
37353733
OutputRange.first = std::min(Index, OutputRange.first);
37363734
OutputRange.second = std::max(Index, OutputRange.second);
37373735
}
@@ -3746,34 +3744,29 @@ bool VectorCombine::shrinkLoadForShuffles(Instruction &I) {
37463744

37473745
// Get the range of vector elements used by shufflevector instructions.
37483746
if (auto Indices = GetIndexRangeInShuffles()) {
3749-
unsigned OldSize = VecTy->getNumElements();
3750-
unsigned NewSize = Indices->second + 1u;
3747+
unsigned const NewNumElements = Indices->second + 1u;
37513748

37523749
// If the range of vector elements is smaller than the full load, attempt
37533750
// to create a smaller load.
3754-
if (NewSize < OldSize) {
3751+
if (NewNumElements < OldNumElements) {
37553752
auto Builder = IRBuilder(&I);
37563753
Builder.SetCurrentDebugLocation(I.getDebugLoc());
37573754

3758-
// Create new load of smaller vector.
3759-
auto *ElemTy = VecTy->getElementType();
3760-
auto *NewVecTy = FixedVectorType::get(ElemTy, NewSize);
3761-
auto *PtrOp = OldLoad->getPointerOperand();
3762-
auto *NewLoad = cast<LoadInst>(
3763-
Builder.CreateAlignedLoad(NewVecTy, PtrOp, OldLoad->getAlign()));
3764-
NewLoad->copyMetadata(I);
3765-
37663755
// Calculate costs of old and new ops.
3767-
auto OldCost = TTI.getMemoryOpCost(
3756+
Type *ElemTy = OldLoadTy->getElementType();
3757+
FixedVectorType *NewLoadTy = FixedVectorType::get(ElemTy, NewNumElements);
3758+
Value *PtrOp = OldLoad->getPointerOperand();
3759+
3760+
InstructionCost OldCost = TTI.getMemoryOpCost(
37683761
Instruction::Load, OldLoad->getType(), OldLoad->getAlign(),
37693762
OldLoad->getPointerAddressSpace(), CostKind);
3770-
auto NewCost = TTI.getMemoryOpCost(
3771-
Instruction::Load, NewLoad->getType(), NewLoad->getAlign(),
3772-
NewLoad->getPointerAddressSpace(), CostKind);
3763+
InstructionCost NewCost = TTI.getMemoryOpCost(
3764+
Instruction::Load, NewLoadTy, OldLoad->getAlign(),
3765+
OldLoad->getPointerAddressSpace(), CostKind);
37733766

37743767
using UseEntry = std::pair<ShuffleVectorInst *, std::vector<int>>;
37753768
auto NewUses = SmallVector<UseEntry, 4u>();
3776-
auto SizeDiff = OldSize - NewSize;
3769+
auto SizeDiff = OldNumElements - NewNumElements;
37773770

37783771
for (auto &Use : I.uses()) {
37793772
auto *Shuffle = cast<ShuffleVectorInst>(Use.getUser());
@@ -3783,19 +3776,22 @@ bool VectorCombine::shrinkLoadForShuffles(Instruction &I) {
37833776
NewUses.push_back({Shuffle, {}});
37843777
auto &NewMask = NewUses.back().second;
37853778
for (auto Index : OldMask)
3786-
NewMask.push_back(Index >= int(OldSize) ? Index - SizeDiff : Index);
3779+
NewMask.push_back(Index >= int(NewNumElements) ? Index - SizeDiff : Index);
37873780

37883781
// Update costs.
3789-
OldCost += TTI.getShuffleCost(TTI::SK_PermuteSingleSrc, VecTy, OldMask,
3790-
CostKind);
3791-
NewCost += TTI.getShuffleCost(TTI::SK_PermuteSingleSrc, NewVecTy,
3782+
OldCost += TTI.getShuffleCost(TTI::SK_PermuteSingleSrc, OldLoadTy,
3783+
OldMask, CostKind);
3784+
NewCost += TTI.getShuffleCost(TTI::SK_PermuteSingleSrc, NewLoadTy,
37923785
NewMask, CostKind);
37933786
}
37943787

3795-
if (OldCost < NewCost || !NewCost.isValid()) {
3796-
NewLoad->eraseFromParent();
3788+
if (OldCost < NewCost || !NewCost.isValid())
37973789
return false;
3798-
}
3790+
3791+
// Create new load of smaller vector.
3792+
auto *NewLoad = cast<LoadInst>(
3793+
Builder.CreateAlignedLoad(NewLoadTy, PtrOp, OldLoad->getAlign()));
3794+
NewLoad->copyMetadata(I);
37993795

38003796
// Replace all uses.
38013797
for (auto &Use : NewUses) {
@@ -3805,7 +3801,7 @@ bool VectorCombine::shrinkLoadForShuffles(Instruction &I) {
38053801
Builder.SetInsertPoint(Shuffle);
38063802
Builder.SetCurrentDebugLocation(Shuffle->getDebugLoc());
38073803
auto *NewShuffle = Builder.CreateShuffleVector(
3808-
NewLoad, PoisonValue::get(NewVecTy), NewMask);
3804+
NewLoad, PoisonValue::get(NewLoadTy), NewMask);
38093805

38103806
replaceValue(*Shuffle, *NewShuffle);
38113807
}

0 commit comments

Comments
 (0)