@@ -5923,9 +5923,9 @@ static bool isMaskedLoadCompress(
59235923 // Check for very large distances between elements.
59245924 if (*Diff / Sz >= MaxRegSize / 8)
59255925 return false;
5926- Align CommonAlignment = computeCommonAlignment<LoadInst>(VL);
59275926 LoadVecTy = getWidenedType(ScalarTy, *Diff + 1);
59285927 auto *LI = cast<LoadInst>(Order.empty() ? VL.front() : VL[Order.front()]);
5928+ Align CommonAlignment = LI->getAlign();
59295929 IsMasked = !isSafeToLoadUnconditionally(
59305930 Ptr0, LoadVecTy, CommonAlignment, DL,
59315931 cast<LoadInst>(Order.empty() ? VL.back() : VL[Order.back()]), &AC, &DT,
@@ -5964,26 +5964,28 @@ static bool isMaskedLoadCompress(
59645964 TTI.getMaskedMemoryOpCost(Instruction::Load, LoadVecTy, CommonAlignment,
59655965 LI->getPointerAddressSpace(), CostKind);
59665966 } else {
5967- CommonAlignment = LI->getAlign();
59685967 LoadCost =
59695968 TTI.getMemoryOpCost(Instruction::Load, LoadVecTy, CommonAlignment,
59705969 LI->getPointerAddressSpace(), CostKind);
59715970 }
5972- if (IsStrided) {
5971+ if (IsStrided && !IsMasked ) {
59735972 // Check for potential segmented(interleaved) loads.
5974- if (TTI.isLegalInterleavedAccessType(LoadVecTy, CompressMask[1],
5973+ auto *AlignedLoadVecTy = getWidenedType(
5974+ ScalarTy, getFullVectorNumberOfElements(TTI, ScalarTy, *Diff + 1));
5975+ if (TTI.isLegalInterleavedAccessType(AlignedLoadVecTy, CompressMask[1],
59755976 CommonAlignment,
59765977 LI->getPointerAddressSpace())) {
59775978 InstructionCost InterleavedCost =
59785979 VectorGEPCost + TTI.getInterleavedMemoryOpCost(
5979- Instruction::Load, LoadVecTy, CompressMask[1] ,
5980- std::nullopt, CommonAlignment,
5980+ Instruction::Load, AlignedLoadVecTy ,
5981+ CompressMask[1], std::nullopt, CommonAlignment,
59815982 LI->getPointerAddressSpace(), CostKind, IsMasked);
59825983 if (!Mask.empty())
59835984 InterleavedCost += ::getShuffleCost(TTI, TTI::SK_PermuteSingleSrc,
59845985 VecTy, Mask, CostKind);
59855986 if (InterleavedCost < GatherCost) {
59865987 InterleaveFactor = CompressMask[1];
5988+ LoadVecTy = AlignedLoadVecTy;
59875989 return true;
59885990 }
59895991 }
@@ -6001,6 +6003,24 @@ static bool isMaskedLoadCompress(
60016003 return TotalVecCost < GatherCost;
60026004}
60036005
6006+ /// Checks if the \p VL can be transformed to a (masked)load + compress or
6007+ /// (masked) interleaved load.
6008+ static bool
6009+ isMaskedLoadCompress(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
6010+ ArrayRef<unsigned> Order, const TargetTransformInfo &TTI,
6011+ const DataLayout &DL, ScalarEvolution &SE,
6012+ AssumptionCache &AC, const DominatorTree &DT,
6013+ const TargetLibraryInfo &TLI,
6014+ const function_ref<bool(Value *)> AreAllUsersVectorized) {
6015+ bool IsMasked;
6016+ unsigned InterleaveFactor;
6017+ SmallVector<int> CompressMask;
6018+ VectorType *LoadVecTy;
6019+ return isMaskedLoadCompress(VL, PointerOps, Order, TTI, DL, SE, AC, DT, TLI,
6020+ AreAllUsersVectorized, IsMasked, InterleaveFactor,
6021+ CompressMask, LoadVecTy);
6022+ }
6023+
60046024/// Checks if strided loads can be generated out of \p VL loads with pointers \p
60056025/// PointerOps:
60066026/// 1. Target with strided load support is detected.
@@ -6137,6 +6157,12 @@ BoUpSLP::canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
61376157 // Check that the sorted loads are consecutive.
61386158 if (static_cast<unsigned>(*Diff) == Sz - 1)
61396159 return LoadsState::Vectorize;
6160+ if (isMaskedLoadCompress(VL, PointerOps, Order, *TTI, *DL, *SE, *AC, *DT,
6161+ *TLI, [&](Value *V) {
6162+ return areAllUsersVectorized(
6163+ cast<Instruction>(V), UserIgnoreList);
6164+ }))
6165+ return LoadsState::CompressVectorize;
61406166 // Simple check if not a strided access - clear order.
61416167 bool IsPossibleStrided = *Diff % (Sz - 1) == 0;
61426168 // Try to generate strided load node.
@@ -6150,18 +6176,6 @@ BoUpSLP::canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
61506176 isStridedLoad(VL, PointerOps, Order, *TTI, *DL, *SE,
61516177 IsAnyPointerUsedOutGraph, *Diff))
61526178 return LoadsState::StridedVectorize;
6153- bool IsMasked;
6154- unsigned InterleaveFactor;
6155- SmallVector<int> CompressMask;
6156- VectorType *LoadVecTy;
6157- if (isMaskedLoadCompress(
6158- VL, PointerOps, Order, *TTI, *DL, *SE, *AC, *DT, *TLI,
6159- [&](Value *V) {
6160- return areAllUsersVectorized(cast<Instruction>(V),
6161- UserIgnoreList);
6162- },
6163- IsMasked, InterleaveFactor, CompressMask, LoadVecTy))
6164- return LoadsState::CompressVectorize;
61656179 }
61666180 if (!TTI->isLegalMaskedGather(VecTy, CommonAlignment) ||
61676181 TTI->forceScalarizeMaskedGather(VecTy, CommonAlignment))
@@ -13439,11 +13453,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
1343913453 assert(IsVectorized && "Expected to be vectorized");
1344013454 CompressEntryToData.try_emplace(E, CompressMask, LoadVecTy,
1344113455 InterleaveFactor, IsMasked);
13442- Align CommonAlignment;
13443- if (IsMasked)
13444- CommonAlignment = computeCommonAlignment<LoadInst>(VL);
13445- else
13446- CommonAlignment = LI0->getAlign();
13456+ Align CommonAlignment = LI0->getAlign();
1344713457 if (InterleaveFactor) {
1344813458 VecLdCost = TTI->getInterleavedMemoryOpCost(
1344913459 Instruction::Load, LoadVecTy, InterleaveFactor, std::nullopt,
@@ -18049,14 +18059,11 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
1804918059 PointerOps[I] = cast<LoadInst>(V)->getPointerOperand();
1805018060 auto [CompressMask, LoadVecTy, InterleaveFactor, IsMasked] =
1805118061 CompressEntryToData.at(E);
18052- Align CommonAlignment;
18053- if (IsMasked)
18054- CommonAlignment = computeCommonAlignment<LoadInst>(E->Scalars);
18055- else
18056- CommonAlignment = LI->getAlign();
18062+ Align CommonAlignment = LI->getAlign();
1805718063 if (IsMasked) {
18064+ unsigned VF = getNumElements(LoadVecTy);
1805818065 SmallVector<Constant *> MaskValues(
18059- getNumElements(LoadVecTy) / getNumElements(LI->getType()),
18066+ VF / getNumElements(LI->getType()),
1806018067 ConstantInt::getFalse(VecTy->getContext()));
1806118068 for (int I : CompressMask)
1806218069 MaskValues[I] = ConstantInt::getTrue(VecTy->getContext());
0 commit comments