Skip to content

Commit 97eb9ed

Browse files
author
Mikhail Gudim
committed
[SLPVectorizer] Change arguments of 'isStridedLoad' (NFC)
This is needed to reduce the diff for the future work on widening strided loads. Also, with this change we'll be able to re-use this for the case when each pointer represents a start of a group of contiguous loads.
1 parent 885cb59 commit 97eb9ed

File tree

1 file changed

+13
-28
lines changed

1 file changed

+13
-28
lines changed

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 13 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2234,10 +2234,9 @@ class BoUpSLP {
22342234
/// TODO: If load combining is allowed in the IR optimizer, this analysis
22352235
/// may not be necessary.
22362236
bool isLoadCombineCandidate(ArrayRef<Value *> Stores) const;
2237-
bool isStridedLoad(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
2238-
ArrayRef<unsigned> Order, const TargetTransformInfo &TTI,
2239-
const DataLayout &DL, ScalarEvolution &SE,
2240-
const int64_t Diff, StridedPtrInfo &SPtrInfo) const;
2237+
bool isStridedLoad(ArrayRef<Value *> PointerOps, Type *ScalarTy,
2238+
Align Alignment, int64_t Diff, Value *Ptr0, Value *PtrN,
2239+
StridedPtrInfo &SPtrInfo) const;
22412240

22422241
/// Checks if the given array of loads can be represented as a vectorized,
22432242
/// scatter or just simple gather.
@@ -6817,13 +6816,10 @@ isMaskedLoadCompress(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
68176816
/// 4. Any pointer operand is an instruction with the users outside of the
68186817
/// current graph (for masked gathers extra extractelement instructions
68196818
/// might be required).
6820-
bool BoUpSLP::isStridedLoad(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
6821-
ArrayRef<unsigned> Order,
6822-
const TargetTransformInfo &TTI,
6823-
const DataLayout &DL, ScalarEvolution &SE,
6824-
const int64_t Diff,
6825-
StridedPtrInfo &SPtrInfo) const {
6826-
const size_t Sz = VL.size();
6819+
bool BoUpSLP::isStridedLoad(ArrayRef<Value *> PointerOps, Type *ScalarTy,
6820+
Align Alignment, int64_t Diff, Value *Ptr0,
6821+
Value *PtrN, StridedPtrInfo &SPtrInfo) const {
6822+
const size_t Sz = PointerOps.size();
68276823
if (Diff % (Sz - 1) != 0)
68286824
return false;
68296825

@@ -6835,7 +6831,6 @@ bool BoUpSLP::isStridedLoad(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
68356831
});
68366832

68376833
const uint64_t AbsoluteDiff = std::abs(Diff);
6838-
Type *ScalarTy = VL.front()->getType();
68396834
auto *VecTy = getWidenedType(ScalarTy, Sz);
68406835
if (IsAnyPointerUsedOutGraph ||
68416836
(AbsoluteDiff > Sz &&
@@ -6846,20 +6841,9 @@ bool BoUpSLP::isStridedLoad(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
68466841
int64_t Stride = Diff / static_cast<int64_t>(Sz - 1);
68476842
if (Diff != Stride * static_cast<int64_t>(Sz - 1))
68486843
return false;
6849-
Align Alignment =
6850-
cast<LoadInst>(Order.empty() ? VL.front() : VL[Order.front()])
6851-
->getAlign();
6852-
if (!TTI.isLegalStridedLoadStore(VecTy, Alignment))
6844+
if (!TTI->isLegalStridedLoadStore(VecTy, Alignment))
68536845
return false;
6854-
Value *Ptr0;
6855-
Value *PtrN;
6856-
if (Order.empty()) {
6857-
Ptr0 = PointerOps.front();
6858-
PtrN = PointerOps.back();
6859-
} else {
6860-
Ptr0 = PointerOps[Order.front()];
6861-
PtrN = PointerOps[Order.back()];
6862-
}
6846+
68636847
// Iterate through all pointers and check if all distances are
68646848
// unique multiple of Dist.
68656849
SmallSet<int64_t, 4> Dists;
@@ -6868,14 +6852,14 @@ bool BoUpSLP::isStridedLoad(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
68686852
if (Ptr == PtrN)
68696853
Dist = Diff;
68706854
else if (Ptr != Ptr0)
6871-
Dist = *getPointersDiff(ScalarTy, Ptr0, ScalarTy, Ptr, DL, SE);
6855+
Dist = *getPointersDiff(ScalarTy, Ptr0, ScalarTy, Ptr, *DL, *SE);
68726856
// If the strides are not the same or repeated, we can't
68736857
// vectorize.
68746858
if (((Dist / Stride) * Stride) != Dist || !Dists.insert(Dist).second)
68756859
break;
68766860
}
68776861
if (Dists.size() == Sz) {
6878-
Type *StrideTy = DL.getIndexType(Ptr0->getType());
6862+
Type *StrideTy = DL->getIndexType(Ptr0->getType());
68796863
SPtrInfo.StrideVal = ConstantInt::get(StrideTy, Stride);
68806864
SPtrInfo.Ty = getWidenedType(ScalarTy, Sz);
68816865
return true;
@@ -6964,7 +6948,8 @@ BoUpSLP::LoadsState BoUpSLP::canVectorizeLoads(
69646948
cast<Instruction>(V), UserIgnoreList);
69656949
}))
69666950
return LoadsState::CompressVectorize;
6967-
if (isStridedLoad(VL, PointerOps, Order, *TTI, *DL, *SE, *Diff, SPtrInfo))
6951+
if (isStridedLoad(PointerOps, ScalarTy, CommonAlignment, *Diff, Ptr0, PtrN,
6952+
SPtrInfo))
69686953
return LoadsState::StridedVectorize;
69696954
}
69706955
if (!TTI->isLegalMaskedGather(VecTy, CommonAlignment) ||

0 commit comments

Comments
 (0)