|
97 | 97 | #include <cctype> |
98 | 98 | #include <cstdint> |
99 | 99 | #include <cstdlib> |
| 100 | +#include <deque> |
100 | 101 | #include <iterator> |
101 | 102 | #include <limits> |
102 | 103 | #include <optional> |
@@ -18010,12 +18011,14 @@ bool AArch64TargetLowering::lowerInterleavedStore(Instruction *Store, |
18010 | 18011 | ShuffleVectorInst *SVI, |
18011 | 18012 | unsigned Factor, |
18012 | 18013 | const APInt &GapMask) const { |
18013 | | - |
18014 | | - assert(Factor >= 2 && Factor <= getMaxSupportedInterleaveFactor() && |
18015 | | - "Invalid interleave factor"); |
| 18014 | + assert(Factor >= 2 && "Invalid interleave factor"); |
18016 | 18015 | auto *SI = dyn_cast<StoreInst>(Store); |
18017 | 18016 | if (!SI) |
18018 | 18017 | return false; |
| 18018 | + |
| 18019 | + if (Factor > getMaxSupportedInterleaveFactor()) |
| 18020 | + return lowerInterleavedStoreWithShuffle(SI, SVI, Factor); |
| 18021 | + |
18019 | 18022 | assert(!LaneMask && GapMask.popcount() == Factor && |
18020 | 18023 | "Unexpected mask on store"); |
18021 | 18024 |
|
@@ -18161,6 +18164,146 @@ bool AArch64TargetLowering::lowerInterleavedStore(Instruction *Store, |
18161 | 18164 | return true; |
18162 | 18165 | } |
18163 | 18166 |
|
| 18167 | +/// If the interleaved vector elements are greater than supported MaxFactor, |
| 18168 | +/// interleaving the data with additional shuffles can be used to |
| 18169 | +/// achieve the same. |
| 18170 | +/// |
| 18171 | +/// Consider the following data with 8 interleaves which are shuffled to store |
| 18172 | +/// stN instructions. Data needs to be stored in this order: |
| 18173 | +/// [v0, v1, v2, v3, v4, v5, v6, v7] |
| 18174 | +/// |
| 18175 | +/// v0 v4 v2 v6 v1 v5 v3 v7 |
| 18176 | +/// | | | | | | | | |
| 18177 | +/// \ / \ / \ / \ / |
| 18178 | +/// [zip v0,v4] [zip v2,v6] [zip v1,v5] [zip v3,v7] ==> stN = 4 |
| 18179 | +/// | | | | |
| 18180 | +/// \ / \ / |
| 18181 | +/// \ / \ / |
| 18182 | +/// \ / \ / |
| 18183 | +/// [zip [v0,v2,v4,v6]] [zip [v1,v3,v5,v7]] ==> stN = 2 |
| 18184 | +/// |
| 18185 | +/// For stN = 4, upper half of interleaved data V0, V1, V2, V3 is stored |
| 18186 | +/// with one st4 instruction. Lower half, i.e, V4, V5, V6, V7 is stored with |
| 18187 | +/// another st4. |
| 18188 | +/// |
| 18189 | +/// For stN = 2, upper half of interleaved data V0, V1 is stored |
| 18190 | +/// with one st2 instruction. Second set V2, V3 is stored with another st2. |
| 18191 | +/// Total of 4 st2's are required here. |
| 18192 | +bool AArch64TargetLowering::lowerInterleavedStoreWithShuffle( |
| 18193 | + StoreInst *SI, ShuffleVectorInst *SVI, unsigned Factor) const { |
| 18194 | + unsigned MaxSupportedFactor = getMaxSupportedInterleaveFactor(); |
| 18195 | + |
| 18196 | + auto *VecTy = cast<FixedVectorType>(SVI->getType()); |
| 18197 | + assert(VecTy->getNumElements() % Factor == 0 && "Invalid interleaved store"); |
| 18198 | + |
| 18199 | + unsigned LaneLen = VecTy->getNumElements() / Factor; |
| 18200 | + Type *EltTy = VecTy->getElementType(); |
| 18201 | + auto *SubVecTy = FixedVectorType::get(EltTy, Factor); |
| 18202 | + |
| 18203 | + const DataLayout &DL = SI->getModule()->getDataLayout(); |
| 18204 | + bool UseScalable; |
| 18205 | + |
| 18206 | + // Skip if we do not have NEON and skip illegal vector types. We can |
| 18207 | + // "legalize" wide vector types into multiple interleaved accesses as long as |
| 18208 | + // the vector types are divisible by 128. |
| 18209 | + if (!Subtarget->hasNEON() || |
| 18210 | + !isLegalInterleavedAccessType(SubVecTy, DL, UseScalable)) |
| 18211 | + return false; |
| 18212 | + |
| 18213 | + if (UseScalable) |
| 18214 | + return false; |
| 18215 | + |
| 18216 | + std::deque<Value *> Shuffles; |
| 18217 | + Shuffles.push_back(SVI); |
| 18218 | + unsigned ConcatLevel = Factor; |
| 18219 | + // Getting all the interleaved operands. |
| 18220 | + while (ConcatLevel > 1) { |
| 18221 | + unsigned InterleavedOperands = Shuffles.size(); |
| 18222 | + for (unsigned Ops = 0; Ops < InterleavedOperands; Ops++) { |
| 18223 | + auto *V = Shuffles.front(); |
| 18224 | + Shuffles.pop_front(); |
| 18225 | + if (isa<ConstantAggregateZero, UndefValue>(V)) { |
| 18226 | + VectorType *Ty = cast<VectorType>(V->getType()); |
| 18227 | + auto *HalfTy = VectorType::getHalfElementsVectorType(Ty); |
| 18228 | + Value *SplitValue = nullptr; |
| 18229 | + if (isa<ConstantAggregateZero>(V)) |
| 18230 | + SplitValue = ConstantAggregateZero::get(HalfTy); |
| 18231 | + else if (isa<PoisonValue>(V)) |
| 18232 | + SplitValue = PoisonValue::get(HalfTy); |
| 18233 | + else if (isa<UndefValue>(V)) |
| 18234 | + SplitValue = UndefValue::get(HalfTy); |
| 18235 | + Shuffles.push_back(SplitValue); |
| 18236 | + Shuffles.push_back(SplitValue); |
| 18237 | + continue; |
| 18238 | + } |
| 18239 | + |
| 18240 | + ShuffleVectorInst *SFL = dyn_cast<ShuffleVectorInst>(V); |
| 18241 | + if (!SFL) |
| 18242 | + return false; |
| 18243 | + if (SVI != SFL && !SFL->isConcat()) |
| 18244 | + return false; |
| 18245 | + |
| 18246 | + Value *Op0 = SFL->getOperand(0); |
| 18247 | + Value *Op1 = SFL->getOperand(1); |
| 18248 | + |
| 18249 | + Shuffles.push_back(dyn_cast<Value>(Op0)); |
| 18250 | + Shuffles.push_back(dyn_cast<Value>(Op1)); |
| 18251 | + } |
| 18252 | + ConcatLevel >>= 1; |
| 18253 | + } |
| 18254 | + |
| 18255 | + IRBuilder<> Builder(SI); |
| 18256 | + auto Mask = createInterleaveMask(LaneLen, 2); |
| 18257 | + SmallVector<int, 16> UpperHalfMask(LaneLen), LowerHalfMask(LaneLen); |
| 18258 | + for (unsigned Idx = 0; Idx < LaneLen; Idx++) { |
| 18259 | + LowerHalfMask[Idx] = Mask[Idx]; |
| 18260 | + UpperHalfMask[Idx] = Mask[Idx + LaneLen]; |
| 18261 | + } |
| 18262 | + |
| 18263 | + unsigned InterleaveFactor = Factor >> 1; |
| 18264 | + while (InterleaveFactor >= MaxSupportedFactor) { |
| 18265 | + std::deque<Value *> ShufflesIntermediate; |
| 18266 | + ShufflesIntermediate.resize(Factor); |
| 18267 | + for (unsigned Idx = 0; Idx < Factor; Idx += (InterleaveFactor * 2)) { |
| 18268 | + for (unsigned GroupIdx = 0; GroupIdx < InterleaveFactor; GroupIdx++) { |
| 18269 | + auto *Shuffle = Builder.CreateShuffleVector( |
| 18270 | + Shuffles[Idx + GroupIdx], |
| 18271 | + Shuffles[Idx + GroupIdx + InterleaveFactor], LowerHalfMask); |
| 18272 | + ShufflesIntermediate[Idx + GroupIdx] = Shuffle; |
| 18273 | + Shuffle = Builder.CreateShuffleVector( |
| 18274 | + Shuffles[Idx + GroupIdx], |
| 18275 | + Shuffles[Idx + GroupIdx + InterleaveFactor], UpperHalfMask); |
| 18276 | + ShufflesIntermediate[Idx + GroupIdx + InterleaveFactor] = Shuffle; |
| 18277 | + } |
| 18278 | + } |
| 18279 | + Shuffles = ShufflesIntermediate; |
| 18280 | + InterleaveFactor >>= 1; |
| 18281 | + } |
| 18282 | + |
| 18283 | + Type *PtrTy = SI->getPointerOperandType(); |
| 18284 | + auto *STVTy = FixedVectorType::get(SubVecTy->getElementType(), LaneLen); |
| 18285 | + |
| 18286 | + Value *BaseAddr = SI->getPointerOperand(); |
| 18287 | + Function *StNFunc = getStructuredStoreFunction( |
| 18288 | + SI->getModule(), MaxSupportedFactor, UseScalable, STVTy, PtrTy); |
| 18289 | + for (unsigned N = 0; N < (Factor / MaxSupportedFactor); N++) { |
| 18290 | + SmallVector<Value *, 5> Ops; |
| 18291 | + for (unsigned OpIdx = 0; OpIdx < MaxSupportedFactor; OpIdx++) |
| 18292 | + Ops.push_back(Shuffles[N * MaxSupportedFactor + OpIdx]); |
| 18293 | + |
| 18294 | + if (N > 0) { |
| 18295 | + // We will compute the pointer operand of each store from the original |
| 18296 | + // base address using GEPs. Cast the base address to a pointer to the |
| 18297 | + // scalar element type. |
| 18298 | + BaseAddr = Builder.CreateConstGEP1_32( |
| 18299 | + SubVecTy->getElementType(), BaseAddr, LaneLen * MaxSupportedFactor); |
| 18300 | + } |
| 18301 | + Ops.push_back(Builder.CreateBitCast(BaseAddr, PtrTy)); |
| 18302 | + Builder.CreateCall(StNFunc, Ops); |
| 18303 | + } |
| 18304 | + return true; |
| 18305 | +} |
| 18306 | + |
18164 | 18307 | bool AArch64TargetLowering::lowerDeinterleaveIntrinsicToLoad( |
18165 | 18308 | Instruction *Load, Value *Mask, IntrinsicInst *DI) const { |
18166 | 18309 | const unsigned Factor = getDeinterleaveIntrinsicFactor(DI->getIntrinsicID()); |
|
0 commit comments