diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 3f18bd70539a0..46f1e60e7d2a9 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -2248,7 +2248,6 @@ class BoUpSLP {
   /// Return true if an array of scalar loads can be replaced with a strided
   /// load (with constant stride).
   ///
-  /// TODO:
   /// It is possible that the load gets "widened". Suppose that originally each
   /// load loads `k` bytes and `PointerOps` can be arranged as follows (`%s` is
   /// constant): %b + 0 * %s + 0 %b + 0 * %s + 1 %b + 0 * %s + 2
@@ -6921,36 +6920,105 @@ bool BoUpSLP::isStridedLoad(ArrayRef<Value *> PointerOps, Type *ScalarTy,
 }
 
 bool BoUpSLP::analyzeConstantStrideCandidate(
-    const ArrayRef<Value *> PointerOps, Type *ScalarTy, Align Alignment,
+    ArrayRef<Value *> PointerOps, Type *ElemTy, Align Alignment,
     const SmallVectorImpl<unsigned> &SortedIndices, const int64_t Diff,
     Value *Ptr0, Value *PtrN, StridedPtrInfo &SPtrInfo) const {
-  const size_t Sz = PointerOps.size();
-  if (!isStridedLoad(PointerOps, ScalarTy, Alignment, Diff, Sz))
-    return false;
+  const unsigned Sz = PointerOps.size();
+  SmallVector<int64_t> SortedOffsetsFromBase(Sz);
+  // Go through `PointerOps` in sorted order and record offsets from `Ptr0`.
+  for (unsigned I : seq<unsigned>(Sz)) {
+    Value *Ptr =
+        SortedIndices.empty() ? PointerOps[I] : PointerOps[SortedIndices[I]];
+    SortedOffsetsFromBase[I] =
+        *getPointersDiff(ElemTy, Ptr0, ElemTy, Ptr, *DL, *SE);
+  }
+  assert(SortedOffsetsFromBase.size() > 1 &&
+         "Trying to generate strided load for less than 2 loads");
+  // The code below checks that `SortedOffsetsFromBase` looks as follows:
+  // ```
+  // [
+  //   (e_{0, 0}, e_{0, 1}, ..., e_{0, GroupSize - 1}), // first group
+  //   (e_{1, 0}, e_{1, 1}, ..., e_{1, GroupSize - 1}), // second group
+  //   ...
+  //   (e_{NumGroups - 1, 0}, e_{NumGroups - 1, 1}, ..., e_{NumGroups - 1,
+  //   GroupSize - 1}), // last group
+  // ]
+  // ```
+  // The distance between consecutive elements within each group should all be
+  // the same `StrideWithinGroup`. The distance between the first elements of
+  // consecutive groups should all be the same `StrideBetweenGroups`.
+
+  int64_t StrideWithinGroup =
+      SortedOffsetsFromBase[1] - SortedOffsetsFromBase[0];
+  // Determine size of the first group. Later we will check that all other
+  // groups have the same size.
+  unsigned GroupSize = 1;
+  for (; GroupSize != SortedOffsetsFromBase.size(); ++GroupSize) {
+    if (SortedOffsetsFromBase[GroupSize] -
+            SortedOffsetsFromBase[GroupSize - 1] !=
+        StrideWithinGroup)
+      break;
+  }
+  unsigned VecSz = Sz;
+  Type *ScalarTy = ElemTy;
+  int64_t StrideIntVal = StrideWithinGroup;
+  FixedVectorType *StridedLoadTy = getWidenedType(ScalarTy, VecSz);
 
-  int64_t Stride = Diff / static_cast<int64_t>(Sz - 1);
+  // Quick detour: at this point we can say what the type of strided load would
+  // be if all the checks pass. Check if this type is legal for the target.
+  bool NeedsWidening = Sz != GroupSize;
+  if (NeedsWidening) {
+    if (Sz % GroupSize != 0)
+      return false;
+    VecSz = Sz / GroupSize;
 
-  // Iterate through all pointers and check if all distances are
-  // unique multiple of Dist.
-  SmallSet<int64_t, 4> Dists;
-  for (Value *Ptr : PointerOps) {
-    int64_t Dist = 0;
-    if (Ptr == PtrN)
-      Dist = Diff;
-    else if (Ptr != Ptr0)
-      Dist = *getPointersDiff(ScalarTy, Ptr0, ScalarTy, Ptr, *DL, *SE);
-    // If the strides are not the same or repeated, we can't
-    // vectorize.
-    if (((Dist / Stride) * Stride) != Dist || !Dists.insert(Dist).second)
-      break;
+    if (StrideWithinGroup != 1)
+      return false;
+    unsigned VecSz = Sz / GroupSize;
+    ScalarTy = Type::getIntNTy(SE->getContext(),
+                               DL->getTypeSizeInBits(ElemTy).getFixedValue() *
+                                   GroupSize);
+    StridedLoadTy = getWidenedType(ScalarTy, VecSz);
   }
-  if (Dists.size() == Sz) {
-    Type *StrideTy = DL->getIndexType(Ptr0->getType());
-    SPtrInfo.StrideVal = ConstantInt::get(StrideTy, Stride);
-    SPtrInfo.Ty = getWidenedType(ScalarTy, Sz);
-    return true;
+
+  if (!isStridedLoad(PointerOps, ScalarTy, Alignment, Diff, VecSz))
+    return false;
+
+  if (NeedsWidening) {
+    // Continue with checking the "shape" of `SortedOffsetsFromBase`.
+    // Check that the strides between groups are all the same.
+    unsigned CurrentGroupStartIdx = GroupSize;
+    int64_t StrideBetweenGroups =
+        SortedOffsetsFromBase[GroupSize] - SortedOffsetsFromBase[0];
+    StrideIntVal = StrideBetweenGroups;
+    for (; CurrentGroupStartIdx < Sz; CurrentGroupStartIdx += GroupSize) {
+      if (SortedOffsetsFromBase[CurrentGroupStartIdx] -
+              SortedOffsetsFromBase[CurrentGroupStartIdx - GroupSize] !=
+          StrideBetweenGroups)
+        return false;
+    }
+
+    auto CheckGroup = [&](const unsigned StartIdx, const unsigned GroupSize0,
+                          const int64_t StrideWithinGroup) -> bool {
+      unsigned GroupEndIdx = StartIdx + 1;
+      for (; GroupEndIdx != Sz; ++GroupEndIdx) {
+        if (SortedOffsetsFromBase[GroupEndIdx] -
+                SortedOffsetsFromBase[GroupEndIdx - 1] !=
+            StrideWithinGroup)
+          break;
+      }
+      return GroupEndIdx - StartIdx == GroupSize0;
+    };
+    for (unsigned I = 0; I < Sz; I += GroupSize) {
+      if (!CheckGroup(I, GroupSize, StrideWithinGroup))
+        return false;
+    }
   }
-  return false;
+
+  Type *StrideTy = DL->getIndexType(Ptr0->getType());
+  SPtrInfo.StrideVal = ConstantInt::get(StrideTy, StrideIntVal);
+  SPtrInfo.Ty = StridedLoadTy;
+  return true;
 }
 
 bool BoUpSLP::analyzeRtStrideCandidate(ArrayRef<Value *> PointerOps,
@@ -14972,11 +15040,19 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
       }
       break;
     case TreeEntry::StridedVectorize: {
+      const StridedPtrInfo &SPtrInfo = TreeEntryToStridedPtrInfoMap.at(E);
+      FixedVectorType *StridedLoadTy = SPtrInfo.Ty;
+      assert(StridedLoadTy && "Missing StridedPointerInfo for tree entry.");
       Align CommonAlignment =
          computeCommonAlignment<LoadInst>(UniqueValues.getArrayRef());
       VecLdCost = TTI->getStridedMemoryOpCost(
-          Instruction::Load, VecTy, LI0->getPointerOperand(),
+          Instruction::Load, StridedLoadTy, LI0->getPointerOperand(),
          /*VariableMask=*/false, CommonAlignment, CostKind);
+      if (StridedLoadTy != VecTy)
+        VecLdCost +=
+            TTI->getCastInstrCost(Instruction::BitCast, VecTy, StridedLoadTy,
+                                  getCastContextHint(*E), CostKind);
+
       break;
     }
     case TreeEntry::CompressVectorize: {
@@ -19743,6 +19819,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
                      ? NewLI
                      : ::propagateMetadata(NewLI, E->Scalars);
+      if (StridedLoadTy != VecTy)
+        V = Builder.CreateBitOrPointerCast(V, VecTy);
       V = FinalShuffle(V, E);
       E->VectorizedValue = V;
       ++NumVectorInstructions;
diff --git a/llvm/test/Transforms/SLPVectorizer/RISCV/basic-strided-loads.ll b/llvm/test/Transforms/SLPVectorizer/RISCV/basic-strided-loads.ll
index f8229b3555653..f3f9191a6fdc7 100644
--- a/llvm/test/Transforms/SLPVectorizer/RISCV/basic-strided-loads.ll
+++ b/llvm/test/Transforms/SLPVectorizer/RISCV/basic-strided-loads.ll
@@ -1,6 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
-; RUN: opt -mtriple=riscv64 -mattr=+m,+v -passes=slp-vectorizer -S < %s | FileCheck %s
+; RUN: opt -mtriple=riscv64 -mattr=+m,+v,+unaligned-vector-mem -passes=slp-vectorizer -S < %s | FileCheck %s
 
 define void @const_stride_1_no_reordering(ptr %pl, ptr %ps) {
 ; CHECK-LABEL: define void @const_stride_1_no_reordering(
@@ -621,21 +621,9 @@ define void @constant_stride_widen_no_reordering(ptr %pl, i64 %stride, ptr %ps)
 ; CHECK-LABEL: define void @constant_stride_widen_no_reordering(
 ; CHECK-SAME: ptr [[PL:%.*]], i64 [[STRIDE:%.*]], ptr [[PS:%.*]]) #[[ATTR0]] {
 ; CHECK-NEXT:    [[GEP_L0:%.*]] = getelementptr inbounds i8, ptr [[PL]], i64 0
-; CHECK-NEXT:    [[GEP_L4:%.*]] = getelementptr inbounds i8, ptr [[PL]], i64 100
-; CHECK-NEXT:    [[GEP_L8:%.*]] = getelementptr inbounds i8, ptr [[PL]], i64 200
-; CHECK-NEXT:    [[GEP_L12:%.*]] = getelementptr inbounds i8, ptr [[PL]], i64 300
 ; CHECK-NEXT:    [[GEP_S0:%.*]] = getelementptr inbounds i8, ptr [[PS]], i64 0
-; CHECK-NEXT:    [[TMP1:%.*]] = load <4 x i8>, ptr [[GEP_L0]], align 1
-; CHECK-NEXT:    [[TMP2:%.*]] = load <4 x i8>, ptr [[GEP_L4]], align 1
-; CHECK-NEXT:    [[TMP3:%.*]] = load <4 x i8>, ptr [[GEP_L8]], align 1
-; CHECK-NEXT:    [[TMP4:%.*]] = load <4 x i8>, ptr [[GEP_L12]], align 1
-; CHECK-NEXT:    [[TMP5:%.*]] = shufflevector <4 x i8> [[TMP1]], <4 x i8> poison, <16 x i32>
-; CHECK-NEXT:    [[TMP6:%.*]] = shufflevector <4 x i8> [[TMP2]], <4 x i8> poison, <16 x i32>
-; CHECK-NEXT:    [[TMP7:%.*]] = shufflevector <4 x i8> [[TMP1]], <4 x i8> [[TMP2]], <16 x i32>
-; CHECK-NEXT:    [[TMP11:%.*]] = shufflevector <4 x i8> [[TMP3]], <4 x i8> poison, <16 x i32>
-; CHECK-NEXT:    [[TMP9:%.*]] = shufflevector <16 x i8> [[TMP7]], <16 x i8> [[TMP11]], <16 x i32>
-; CHECK-NEXT:    [[TMP10:%.*]] = shufflevector <4 x i8> [[TMP4]], <4 x i8> poison, <16 x i32>
-; CHECK-NEXT:    [[TMP8:%.*]] = shufflevector <16 x i8> [[TMP9]], <16 x i8> [[TMP10]], <16 x i32>
+; CHECK-NEXT:    [[TMP1:%.*]] = call <4 x i32> @llvm.experimental.vp.strided.load.v4i32.p0.i64(ptr align 1 [[GEP_L0]], i64 100, <4 x i1> splat (i1 true), i32 4)
+; CHECK-NEXT:    [[TMP8:%.*]] = bitcast <4 x i32> [[TMP1]] to <16 x i8>
 ; CHECK-NEXT:    store <16 x i8> [[TMP8]], ptr [[GEP_S0]], align 1
 ; CHECK-NEXT:    ret void
 ;
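Note (not part of the patch): the standalone sketch below only illustrates the offset "shape" check that `analyzeConstantStrideCandidate` performs before widening — sorted element offsets must form equal-sized runs of consecutive elements, with a single constant stride between the first element of each run. The names `matchWidenedStride` and `WidenedStrideInfo` are hypothetical and exist only for this illustration; the real code additionally calls `isStridedLoad` for target legality and builds the widened integer element type, and the non-widened (single-group) case is handled separately.

```
// Sketch, assuming offsets are already sorted and measured in scalar elements
// from the base pointer. Mirrors the grouping logic of the patch; not LLVM code.
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

struct WidenedStrideInfo {
  std::size_t GroupSize;    // number of scalars folded into one widened element
  int64_t StrideInElements; // distance between group leaders, in scalars
};

static std::optional<WidenedStrideInfo>
matchWidenedStride(const std::vector<int64_t> &Offsets) {
  assert(Offsets.size() > 1 && "need at least two loads");
  const std::size_t Sz = Offsets.size();

  // Size of the first group: run of offsets that are consecutive (stride 1).
  std::size_t GroupSize = 1;
  while (GroupSize < Sz && Offsets[GroupSize] - Offsets[GroupSize - 1] == 1)
    ++GroupSize;
  // Nothing to widen, or the groups would not cover all loads evenly.
  if (GroupSize == Sz || Sz % GroupSize != 0)
    return std::nullopt;

  // All group leaders must be equidistant.
  const int64_t StrideBetweenGroups = Offsets[GroupSize] - Offsets[0];
  for (std::size_t I = GroupSize; I < Sz; I += GroupSize)
    if (Offsets[I] - Offsets[I - GroupSize] != StrideBetweenGroups)
      return std::nullopt;

  // Every group must be a run of GroupSize consecutive offsets.
  for (std::size_t I = 0; I < Sz; I += GroupSize)
    for (std::size_t J = I + 1; J < I + GroupSize; ++J)
      if (Offsets[J] - Offsets[J - 1] != 1)
        return std::nullopt;

  return WidenedStrideInfo{GroupSize, StrideBetweenGroups};
}

int main() {
  // 16 i8 loads at %base + {0..3, 100..103, 200..203, 300..303}, as in the
  // @constant_stride_widen_no_reordering test: GroupSize = 4, stride = 100.
  std::vector<int64_t> Offsets;
  for (int64_t G = 0; G < 4; ++G)
    for (int64_t E = 0; E < 4; ++E)
      Offsets.push_back(G * 100 + E);
  auto Info = matchWidenedStride(Offsets);
  assert(Info && Info->GroupSize == 4 && Info->StrideInElements == 100);
  return 0;
}
```

For that test, the four groups of four consecutive i8 loads with a 100-byte group stride are exactly why the patch now emits a `<4 x i32>` strided load (i32 = 4 widened i8 elements) followed by a bitcast back to `<16 x i8>`, instead of four scalar vector loads plus shuffles.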