Skip to content

Commit 0310170

Browse files
author
Leon Clark
committed
Address comments.
1 parent 6ca4bfa commit 0310170

File tree

1 file changed

+30
-34
lines changed

1 file changed

+30
-34
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 30 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3493,37 +3493,35 @@ bool VectorCombine::shrinkLoadForShuffles(Instruction &I) {
34933493
if (!OldLoad || !OldLoad->isSimple())
34943494
return false;
34953495

3496-
auto *VecTy = dyn_cast<FixedVectorType>(OldLoad->getType());
3497-
if (!VecTy)
3496+
auto *OldLoadTy = dyn_cast<FixedVectorType>(OldLoad->getType());
3497+
if (!OldLoadTy)
34983498
return false;
34993499

3500+
unsigned const OldNumElements = OldLoadTy->getNumElements();
3501+
35003502
// Search all uses of load. If all uses are shufflevector instructions, and
35013503
// the second operands are all poison values, find the minimum and maximum
35023504
// indices of the vector elements referenced by all shuffle masks.
35033505
// Otherwise return `std::nullopt`.
35043506
using IndexRange = std::pair<int, int>;
35053507
auto GetIndexRangeInShuffles = [&]() -> std::optional<IndexRange> {
3506-
IndexRange OutputRange = IndexRange(VecTy->getNumElements(), -1);
3508+
IndexRange OutputRange = IndexRange(OldNumElements, -1);
35073509
for (auto &Use : I.uses()) {
35083510
// Ensure all uses match the required pattern.
35093511
User *Shuffle = Use.getUser();
3510-
Value *Op0 = nullptr;
35113512
ArrayRef<int> Mask;
35123513

3513-
if (!match(Shuffle, m_Shuffle(m_Value(Op0), m_Undef(), m_Mask(Mask))))
3514+
if (!match(Shuffle,
3515+
m_Shuffle(m_Specific(OldLoad), m_Undef(), m_Mask(Mask))))
35143516
return std::nullopt;
35153517

35163518
// Ignore shufflevector instructions that have no uses.
35173519
if (Shuffle->use_empty())
35183520
continue;
35193521

35203522
// Find the min and max indices used by the shufflevector instruction.
3521-
FixedVectorType *Op0Ty = cast<FixedVectorType>(Op0->getType());
3522-
int NumElems = static_cast<int>(Op0Ty->getNumElements());
3523-
35243523
for (int Index : Mask) {
3525-
if (Index >= 0) {
3526-
Index %= NumElems;
3524+
if (Index >= 0 && Index < static_cast<int>(OldNumElements)) {
35273525
OutputRange.first = std::min(Index, OutputRange.first);
35283526
OutputRange.second = std::max(Index, OutputRange.second);
35293527
}
@@ -3538,34 +3536,29 @@ bool VectorCombine::shrinkLoadForShuffles(Instruction &I) {
35383536

35393537
// Get the range of vector elements used by shufflevector instructions.
35403538
if (auto Indices = GetIndexRangeInShuffles()) {
3541-
unsigned OldSize = VecTy->getNumElements();
3542-
unsigned NewSize = Indices->second + 1u;
3539+
unsigned const NewNumElements = Indices->second + 1u;
35433540

35443541
// If the range of vector elements is smaller than the full load, attempt
35453542
// to create a smaller load.
3546-
if (NewSize < OldSize) {
3543+
if (NewNumElements < OldNumElements) {
35473544
auto Builder = IRBuilder(&I);
35483545
Builder.SetCurrentDebugLocation(I.getDebugLoc());
35493546

3550-
// Create new load of smaller vector.
3551-
auto *ElemTy = VecTy->getElementType();
3552-
auto *NewVecTy = FixedVectorType::get(ElemTy, NewSize);
3553-
auto *PtrOp = OldLoad->getPointerOperand();
3554-
auto *NewLoad = cast<LoadInst>(
3555-
Builder.CreateAlignedLoad(NewVecTy, PtrOp, OldLoad->getAlign()));
3556-
NewLoad->copyMetadata(I);
3557-
35583547
// Calculate costs of old and new ops.
3559-
auto OldCost = TTI.getMemoryOpCost(
3548+
Type *ElemTy = OldLoadTy->getElementType();
3549+
FixedVectorType *NewLoadTy = FixedVectorType::get(ElemTy, NewNumElements);
3550+
Value *PtrOp = OldLoad->getPointerOperand();
3551+
3552+
InstructionCost OldCost = TTI.getMemoryOpCost(
35603553
Instruction::Load, OldLoad->getType(), OldLoad->getAlign(),
35613554
OldLoad->getPointerAddressSpace(), CostKind);
3562-
auto NewCost = TTI.getMemoryOpCost(
3563-
Instruction::Load, NewLoad->getType(), NewLoad->getAlign(),
3564-
NewLoad->getPointerAddressSpace(), CostKind);
3555+
InstructionCost NewCost = TTI.getMemoryOpCost(
3556+
Instruction::Load, NewLoadTy, OldLoad->getAlign(),
3557+
OldLoad->getPointerAddressSpace(), CostKind);
35653558

35663559
using UseEntry = std::pair<ShuffleVectorInst *, std::vector<int>>;
35673560
auto NewUses = SmallVector<UseEntry, 4u>();
3568-
auto SizeDiff = OldSize - NewSize;
3561+
auto SizeDiff = OldNumElements - NewNumElements;
35693562

35703563
for (auto &Use : I.uses()) {
35713564
auto *Shuffle = cast<ShuffleVectorInst>(Use.getUser());
@@ -3575,19 +3568,22 @@ bool VectorCombine::shrinkLoadForShuffles(Instruction &I) {
35753568
NewUses.push_back({Shuffle, {}});
35763569
auto &NewMask = NewUses.back().second;
35773570
for (auto Index : OldMask)
3578-
NewMask.push_back(Index >= int(OldSize) ? Index - SizeDiff : Index);
3571+
NewMask.push_back(Index >= int(NewNumElements) ? Index - SizeDiff : Index);
35793572

35803573
// Update costs.
3581-
OldCost += TTI.getShuffleCost(TTI::SK_PermuteSingleSrc, VecTy, OldMask,
3582-
CostKind);
3583-
NewCost += TTI.getShuffleCost(TTI::SK_PermuteSingleSrc, NewVecTy,
3574+
OldCost += TTI.getShuffleCost(TTI::SK_PermuteSingleSrc, OldLoadTy,
3575+
OldMask, CostKind);
3576+
NewCost += TTI.getShuffleCost(TTI::SK_PermuteSingleSrc, NewLoadTy,
35843577
NewMask, CostKind);
35853578
}
35863579

3587-
if (OldCost < NewCost || !NewCost.isValid()) {
3588-
NewLoad->eraseFromParent();
3580+
if (OldCost < NewCost || !NewCost.isValid())
35893581
return false;
3590-
}
3582+
3583+
// Create new load of smaller vector.
3584+
auto *NewLoad = cast<LoadInst>(
3585+
Builder.CreateAlignedLoad(NewLoadTy, PtrOp, OldLoad->getAlign()));
3586+
NewLoad->copyMetadata(I);
35913587

35923588
// Replace all uses.
35933589
for (auto &Use : NewUses) {
@@ -3597,7 +3593,7 @@ bool VectorCombine::shrinkLoadForShuffles(Instruction &I) {
35973593
Builder.SetInsertPoint(Shuffle);
35983594
Builder.SetCurrentDebugLocation(Shuffle->getDebugLoc());
35993595
auto *NewShuffle = Builder.CreateShuffleVector(
3600-
NewLoad, PoisonValue::get(NewVecTy), NewMask);
3596+
NewLoad, PoisonValue::get(NewLoadTy), NewMask);
36013597

36023598
replaceValue(*Shuffle, *NewShuffle);
36033599
}

0 commit comments

Comments
 (0)