Skip to content

Commit fdace1c

Browse files
authored
[SLP][NFC]Extract SCEVExpander from calculateRtStride, NFC
Make `calculateRtStride` return the SCEV of rt stride value and let the caller expand it where needed.
1 parent 48d445a commit fdace1c

File tree

1 file changed

+25
-31
lines changed

1 file changed

+25
-31
lines changed

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 25 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -6323,17 +6323,11 @@ static bool isReverseOrder(ArrayRef<unsigned> Order) {
63236323
}
63246324

63256325
/// Checks if the provided list of pointers \p Pointers represents the strided
6326-
/// pointers for type ElemTy. If they are not, std::nullopt is returned.
6327-
/// Otherwise, if \p Inst is not specified, just initialized optional value is
6328-
/// returned to show that the pointers represent strided pointers. If \p Inst
6329-
/// specified, the runtime stride is materialized before the given \p Inst.
6330-
/// \returns std::nullopt if the pointers are not pointers with the runtime
6331-
/// stride, nullptr or actual stride value, otherwise.
6332-
static std::optional<Value *>
6333-
calculateRtStride(ArrayRef<Value *> PointerOps, Type *ElemTy,
6334-
const DataLayout &DL, ScalarEvolution &SE,
6335-
SmallVectorImpl<unsigned> &SortedIndices,
6336-
Instruction *Inst = nullptr) {
6326+
/// pointers for type ElemTy. If they are not, nullptr is returned.
6327+
/// Otherwise, SCEV* of the stride value is returned.
6328+
static const SCEV *calculateRtStride(ArrayRef<Value *> PointerOps, Type *ElemTy,
6329+
const DataLayout &DL, ScalarEvolution &SE,
6330+
SmallVectorImpl<unsigned> &SortedIndices) {
63376331
SmallVector<const SCEV *> SCEVs;
63386332
const SCEV *PtrSCEVLowest = nullptr;
63396333
const SCEV *PtrSCEVHighest = nullptr;
@@ -6342,22 +6336,22 @@ calculateRtStride(ArrayRef<Value *> PointerOps, Type *ElemTy,
63426336
for (Value *Ptr : PointerOps) {
63436337
const SCEV *PtrSCEV = SE.getSCEV(Ptr);
63446338
if (!PtrSCEV)
6345-
return std::nullopt;
6339+
return nullptr;
63466340
SCEVs.push_back(PtrSCEV);
63476341
if (!PtrSCEVLowest && !PtrSCEVHighest) {
63486342
PtrSCEVLowest = PtrSCEVHighest = PtrSCEV;
63496343
continue;
63506344
}
63516345
const SCEV *Diff = SE.getMinusSCEV(PtrSCEV, PtrSCEVLowest);
63526346
if (isa<SCEVCouldNotCompute>(Diff))
6353-
return std::nullopt;
6347+
return nullptr;
63546348
if (Diff->isNonConstantNegative()) {
63556349
PtrSCEVLowest = PtrSCEV;
63566350
continue;
63576351
}
63586352
const SCEV *Diff1 = SE.getMinusSCEV(PtrSCEVHighest, PtrSCEV);
63596353
if (isa<SCEVCouldNotCompute>(Diff1))
6360-
return std::nullopt;
6354+
return nullptr;
63616355
if (Diff1->isNonConstantNegative()) {
63626356
PtrSCEVHighest = PtrSCEV;
63636357
continue;
@@ -6366,7 +6360,7 @@ calculateRtStride(ArrayRef<Value *> PointerOps, Type *ElemTy,
63666360
// Dist = PtrSCEVHighest - PtrSCEVLowest;
63676361
const SCEV *Dist = SE.getMinusSCEV(PtrSCEVHighest, PtrSCEVLowest);
63686362
if (isa<SCEVCouldNotCompute>(Dist))
6369-
return std::nullopt;
6363+
return nullptr;
63706364
int Size = DL.getTypeStoreSize(ElemTy);
63716365
auto TryGetStride = [&](const SCEV *Dist,
63726366
const SCEV *Multiplier) -> const SCEV * {
@@ -6387,10 +6381,10 @@ calculateRtStride(ArrayRef<Value *> PointerOps, Type *ElemTy,
63876381
const SCEV *Sz = SE.getConstant(Dist->getType(), Size * (SCEVs.size() - 1));
63886382
Stride = TryGetStride(Dist, Sz);
63896383
if (!Stride)
6390-
return std::nullopt;
6384+
return nullptr;
63916385
}
63926386
if (!Stride || isa<SCEVConstant>(Stride))
6393-
return std::nullopt;
6387+
return nullptr;
63946388
// Iterate through all pointers and check if all distances are
63956389
// unique multiple of Stride.
63966390
using DistOrdPair = std::pair<int64_t, int>;
@@ -6404,28 +6398,28 @@ calculateRtStride(ArrayRef<Value *> PointerOps, Type *ElemTy,
64046398
const SCEV *Diff = SE.getMinusSCEV(PtrSCEV, PtrSCEVLowest);
64056399
const SCEV *Coeff = TryGetStride(Diff, Stride);
64066400
if (!Coeff)
6407-
return std::nullopt;
6401+
return nullptr;
64086402
const auto *SC = dyn_cast<SCEVConstant>(Coeff);
64096403
if (!SC || isa<SCEVCouldNotCompute>(SC))
6410-
return std::nullopt;
6404+
return nullptr;
64116405
if (!SE.getMinusSCEV(PtrSCEV, SE.getAddExpr(PtrSCEVLowest,
64126406
SE.getMulExpr(Stride, SC)))
64136407
->isZero())
6414-
return std::nullopt;
6408+
return nullptr;
64156409
Dist = SC->getAPInt().getZExtValue();
64166410
}
64176411
// If the strides are not the same or repeated, we can't vectorize.
64186412
if ((Dist / Size) * Size != Dist || (Dist / Size) >= SCEVs.size())
6419-
return std::nullopt;
6413+
return nullptr;
64206414
auto Res = Offsets.emplace(Dist, Cnt);
64216415
if (!Res.second)
6422-
return std::nullopt;
6416+
return nullptr;
64236417
// Consecutive order if the inserted element is the last one.
64246418
IsConsecutive = IsConsecutive && std::next(Res.first) == Offsets.end();
64256419
++Cnt;
64266420
}
64276421
if (Offsets.size() != SCEVs.size())
6428-
return std::nullopt;
6422+
return nullptr;
64296423
SortedIndices.clear();
64306424
if (!IsConsecutive) {
64316425
// Fill SortedIndices array only if it is non-consecutive.
@@ -6436,10 +6430,7 @@ calculateRtStride(ArrayRef<Value *> PointerOps, Type *ElemTy,
64366430
++Cnt;
64376431
}
64386432
}
6439-
if (!Inst)
6440-
return nullptr;
6441-
SCEVExpander Expander(SE, DL, "strided-load-vec");
6442-
return Expander.expandCodeFor(Stride, Stride->getType(), Inst);
6433+
return Stride;
64436434
}
64446435

64456436
static std::pair<InstructionCost, InstructionCost>
@@ -19520,11 +19511,14 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
1952019511
return cast<LoadInst>(V)->getPointerOperand();
1952119512
});
1952219513
OrdersType Order;
19523-
std::optional<Value *> Stride =
19524-
calculateRtStride(PointerOps, ScalarTy, *DL, *SE, Order,
19525-
&*Builder.GetInsertPoint());
19514+
const SCEV *StrideSCEV =
19515+
calculateRtStride(PointerOps, ScalarTy, *DL, *SE, Order);
19516+
assert(StrideSCEV && "At this point stride should be known");
19517+
SCEVExpander Expander(*SE, *DL, "strided-load-vec");
19518+
Value *Stride = Expander.expandCodeFor(
19519+
StrideSCEV, StrideSCEV->getType(), &*Builder.GetInsertPoint());
1952619520
Value *NewStride =
19527-
Builder.CreateIntCast(*Stride, StrideTy, /*isSigned=*/true);
19521+
Builder.CreateIntCast(Stride, StrideTy, /*isSigned=*/true);
1952819522
StrideVal = Builder.CreateMul(
1952919523
NewStride,
1953019524
ConstantInt::get(

0 commit comments

Comments
 (0)