
Commit ec982fa

[SLPVectorizer] Change arguments of 'isStridedLoad' (NFC) (#160401)
This is needed to reduce the diff for future work on widening strided loads. Also, with this change we'll be able to reuse the function for the case when each pointer represents the start of a group of contiguous loads.
1 parent 0e14973 commit ec982fa
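
As a rough illustration of the check being reorganized in the diff below (not part of this commit or of the LLVM sources), here is a small self-contained sketch of the distance/stride validation that isStridedLoad performs, with plain integer offsets standing in for the pointer operands and SCEV-based pointer differences. The name isStridedOffsets, the simplified types, and the zero-stride guard are hypothetical additions for the sketch only.

// Illustrative only: a simplified, standalone version of the stride check
// in BoUpSLP::isStridedLoad. Offsets[i] is the element distance of pointer i
// from the first pointer; Diff is the distance of the last pointer from the
// first one.
#include <cstdint>
#include <iostream>
#include <set>
#include <vector>

static bool isStridedOffsets(const std::vector<int64_t> &Offsets,
                             int64_t Diff) {
  const size_t Sz = Offsets.size();
  if (Sz < 2 || Diff % static_cast<int64_t>(Sz - 1) != 0)
    return false;
  int64_t Stride = Diff / static_cast<int64_t>(Sz - 1);
  if (Stride == 0) // Guard for the sketch; the real code only gets here with
    return false;  // a nonzero Diff.
  // Every distance must be a multiple of Stride and must be unique,
  // mirroring the Dists-set loop in isStridedLoad.
  std::set<int64_t> Dists;
  for (int64_t Dist : Offsets)
    if ((Dist / Stride) * Stride != Dist || !Dists.insert(Dist).second)
      return false;
  return Dists.size() == Sz;
}

int main() {
  std::cout << isStridedOffsets({0, 3, 6, 9}, 9) << '\n'; // 1: stride of 3
  std::cout << isStridedOffsets({0, 2, 3, 9}, 9) << '\n'; // 0: not strided
}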

File tree: 1 file changed, +16 −28 lines

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 16 additions & 28 deletions
@@ -2241,10 +2241,9 @@ class BoUpSLP {
   /// TODO: If load combining is allowed in the IR optimizer, this analysis
   /// may not be necessary.
   bool isLoadCombineCandidate(ArrayRef<Value *> Stores) const;
-  bool isStridedLoad(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
-                     ArrayRef<unsigned> Order, const TargetTransformInfo &TTI,
-                     const DataLayout &DL, ScalarEvolution &SE,
-                     const int64_t Diff, StridedPtrInfo &SPtrInfo) const;
+  bool isStridedLoad(ArrayRef<Value *> PointerOps, Type *ScalarTy,
+                     Align Alignment, const int64_t Diff, Value *Ptr0,
+                     Value *PtrN, StridedPtrInfo &SPtrInfo) const;
 
   /// Checks if the given array of loads can be represented as a vectorized,
   /// scatter or just simple gather.
@@ -6824,13 +6823,10 @@ isMaskedLoadCompress(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
 /// 4. Any pointer operand is an instruction with the users outside of the
 ///    current graph (for masked gathers extra extractelement instructions
 ///    might be required).
-bool BoUpSLP::isStridedLoad(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
-                            ArrayRef<unsigned> Order,
-                            const TargetTransformInfo &TTI,
-                            const DataLayout &DL, ScalarEvolution &SE,
-                            const int64_t Diff,
-                            StridedPtrInfo &SPtrInfo) const {
-  const size_t Sz = VL.size();
+bool BoUpSLP::isStridedLoad(ArrayRef<Value *> PointerOps, Type *ScalarTy,
+                            Align Alignment, const int64_t Diff, Value *Ptr0,
+                            Value *PtrN, StridedPtrInfo &SPtrInfo) const {
+  const size_t Sz = PointerOps.size();
   if (Diff % (Sz - 1) != 0)
     return false;
 
@@ -6842,7 +6838,6 @@ bool BoUpSLP::isStridedLoad(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
   });
 
   const uint64_t AbsoluteDiff = std::abs(Diff);
-  Type *ScalarTy = VL.front()->getType();
   auto *VecTy = getWidenedType(ScalarTy, Sz);
   if (IsAnyPointerUsedOutGraph ||
       (AbsoluteDiff > Sz &&
@@ -6853,20 +6848,9 @@ bool BoUpSLP::isStridedLoad(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
   int64_t Stride = Diff / static_cast<int64_t>(Sz - 1);
   if (Diff != Stride * static_cast<int64_t>(Sz - 1))
     return false;
-  Align Alignment =
-      cast<LoadInst>(Order.empty() ? VL.front() : VL[Order.front()])
-          ->getAlign();
-  if (!TTI.isLegalStridedLoadStore(VecTy, Alignment))
+  if (!TTI->isLegalStridedLoadStore(VecTy, Alignment))
     return false;
-  Value *Ptr0;
-  Value *PtrN;
-  if (Order.empty()) {
-    Ptr0 = PointerOps.front();
-    PtrN = PointerOps.back();
-  } else {
-    Ptr0 = PointerOps[Order.front()];
-    PtrN = PointerOps[Order.back()];
-  }
+
   // Iterate through all pointers and check if all distances are
   // unique multiple of Dist.
   SmallSet<int64_t, 4> Dists;
@@ -6875,14 +6859,14 @@ bool BoUpSLP::isStridedLoad(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
     if (Ptr == PtrN)
       Dist = Diff;
     else if (Ptr != Ptr0)
-      Dist = *getPointersDiff(ScalarTy, Ptr0, ScalarTy, Ptr, DL, SE);
+      Dist = *getPointersDiff(ScalarTy, Ptr0, ScalarTy, Ptr, *DL, *SE);
     // If the strides are not the same or repeated, we can't
     // vectorize.
     if (((Dist / Stride) * Stride) != Dist || !Dists.insert(Dist).second)
       break;
   }
   if (Dists.size() == Sz) {
-    Type *StrideTy = DL.getIndexType(Ptr0->getType());
+    Type *StrideTy = DL->getIndexType(Ptr0->getType());
     SPtrInfo.StrideVal = ConstantInt::get(StrideTy, Stride);
     SPtrInfo.Ty = getWidenedType(ScalarTy, Sz);
     return true;
@@ -6971,7 +6955,11 @@ BoUpSLP::LoadsState BoUpSLP::canVectorizeLoads(
                                  cast<Instruction>(V), UserIgnoreList);
         }))
       return LoadsState::CompressVectorize;
-    if (isStridedLoad(VL, PointerOps, Order, *TTI, *DL, *SE, *Diff, SPtrInfo))
+    Align Alignment =
+        cast<LoadInst>(Order.empty() ? VL.front() : VL[Order.front()])
+            ->getAlign();
+    if (isStridedLoad(PointerOps, ScalarTy, Alignment, *Diff, Ptr0, PtrN,
+                      SPtrInfo))
       return LoadsState::StridedVectorize;
   }
   if (!TTI->isLegalMaskedGather(VecTy, CommonAlignment) ||
