//===----------------------------------------------------------------------===//

 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/BitVector.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVector.h"
@@ -100,11 +101,11 @@ class InterleavedAccessImpl {
   unsigned MaxFactor = 0u;

   /// Transform an interleaved load into target specific intrinsics.
-  bool lowerInterleavedLoad(LoadInst *LI,
+  bool lowerInterleavedLoad(Instruction *LoadOp,
                             SmallSetVector<Instruction *, 32> &DeadInsts);

   /// Transform an interleaved store into target specific intrinsics.
-  bool lowerInterleavedStore(StoreInst *SI,
+  bool lowerInterleavedStore(Instruction *StoreOp,
                              SmallSetVector<Instruction *, 32> &DeadInsts);

   /// Transform a load and a deinterleave intrinsic into target specific
@@ -131,7 +132,7 @@ class InterleavedAccessImpl {
   /// made.
   bool replaceBinOpShuffles(ArrayRef<ShuffleVectorInst *> BinOpShuffles,
                             SmallVectorImpl<ShuffleVectorInst *> &Shuffles,
-                            LoadInst *LI);
+                            Instruction *LI);
 };

 class InterleavedAccess : public FunctionPass {
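The two hunks above widen the lowering entry points from LoadInst*/StoreInst* to Instruction* so plain loads/stores and the llvm.vp.load/llvm.vp.store intrinsics can share one code path. A minimal sketch of the dispatch idiom this enables; getLoadPointer is a hypothetical helper name, not part of the patch:

#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/Support/ErrorHandling.h"
using namespace llvm;

// Hypothetical helper: recover the pointer operand from either flavor of
// load that can now reach lowerInterleavedLoad.
static Value *getLoadPointer(Instruction *LoadOp) {
  if (auto *LI = dyn_cast<LoadInst>(LoadOp))
    return LI->getPointerOperand();
  if (auto *VPLoad = dyn_cast<VPIntrinsic>(LoadOp))
    return VPLoad->getMemoryPointerParam();
  llvm_unreachable("unsupported load operation");
}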
@@ -250,10 +251,23 @@ static bool isReInterleaveMask(ShuffleVectorInst *SVI, unsigned &Factor,
 }

 bool InterleavedAccessImpl::lowerInterleavedLoad(
-    LoadInst *LI, SmallSetVector<Instruction *, 32> &DeadInsts) {
-  if (!LI->isSimple() || isa<ScalableVectorType>(LI->getType()))
+    Instruction *LoadOp, SmallSetVector<Instruction *, 32> &DeadInsts) {
+  if (isa<ScalableVectorType>(LoadOp->getType()))
     return false;

+  if (auto *LI = dyn_cast<LoadInst>(LoadOp)) {
+    if (!LI->isSimple())
+      return false;
+  } else if (auto *VPLoad = dyn_cast<VPIntrinsic>(LoadOp)) {
+    assert(VPLoad->getIntrinsicID() == Intrinsic::vp_load);
+    // Require a constant mask and evl.
+    if (!isa<ConstantVector>(VPLoad->getArgOperand(1)) ||
+        !isa<ConstantInt>(VPLoad->getArgOperand(2)))
+      return false;
+  } else {
+    llvm_unreachable("unsupported load operation");
+  }
+
   // Check if all users of this load are shufflevectors. If we encounter any
   // users that are extractelement instructions or binary operators, we save
   // them to later check if they can be modified to extract from one of the
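The gating above only admits a vp.load whose mask and EVL are compile-time constants; anything dynamic bails out before the shuffle analysis. A standalone restatement of that predicate (isCandidateVPLoad is an illustrative name, not in the patch), with the operand layout spelled out:

#include "llvm/IR/Constants.h"
#include "llvm/IR/IntrinsicInst.h"
using namespace llvm;

// vp.load(ptr, mask, evl): argument 1 is the lane mask, argument 2 the
// explicit vector length. Both must be constants for the pass to reason
// about which lanes the deinterleaving shuffles may touch.
static bool isCandidateVPLoad(const VPIntrinsic *VPLoad) {
  return VPLoad->getIntrinsicID() == Intrinsic::vp_load &&
         isa<ConstantVector>(VPLoad->getArgOperand(1)) &&
         isa<ConstantInt>(VPLoad->getArgOperand(2));
}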
@@ -265,7 +279,7 @@ bool InterleavedAccessImpl::lowerInterleavedLoad(
   // binop are the same load.
   SmallSetVector<ShuffleVectorInst *, 4> BinOpShuffles;

-  for (auto *User : LI->users()) {
+  for (auto *User : LoadOp->users()) {
     auto *Extract = dyn_cast<ExtractElementInst>(User);
     if (Extract && isa<ConstantInt>(Extract->getIndexOperand())) {
       Extracts.push_back(Extract);
@@ -294,13 +308,31 @@ bool InterleavedAccessImpl::lowerInterleavedLoad(
   unsigned Factor, Index;

   unsigned NumLoadElements =
-      cast<FixedVectorType>(LI->getType())->getNumElements();
+      cast<FixedVectorType>(LoadOp->getType())->getNumElements();
   auto *FirstSVI = Shuffles.size() > 0 ? Shuffles[0] : BinOpShuffles[0];
   // Check if the first shufflevector is DE-interleave shuffle.
   if (!isDeInterleaveMask(FirstSVI->getShuffleMask(), Factor, Index, MaxFactor,
                           NumLoadElements))
     return false;

+  // If this is a vp.load, record its mask (NOT shuffle mask).
+  BitVector MaskedIndices(NumLoadElements);
+  if (auto *VPLoad = dyn_cast<VPIntrinsic>(LoadOp)) {
+    auto *Mask = cast<ConstantVector>(VPLoad->getArgOperand(1));
+    assert(cast<FixedVectorType>(Mask->getType())->getNumElements() ==
+           NumLoadElements);
+    if (auto *Splat = Mask->getSplatValue()) {
+      // All-zeros mask, bail out early.
+      if (Splat->isZeroValue())
+        return false;
+    } else {
+      for (unsigned i = 0U; i < NumLoadElements; ++i) {
+        if (Mask->getAggregateElement(i)->isZeroValue())
+          MaskedIndices.set(i);
+      }
+    }
+  }
+
   // Holds the corresponding index for each DE-interleave shuffle.
   SmallVector<unsigned, 4> Indices;
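The block above folds the vp.load mask into a BitVector of disabled lanes, treating splats specially: an all-zeros splat means no lane is ever loaded, so the transform gives up, while any other splat (all-ones) leaves MaskedIndices empty. A hedged sketch of the same recording step with a hypothetical signature; e.g. an <8 x i1> mask of <1,1,1,1,1,1,0,0> yields bits {6, 7}:

#include "llvm/ADT/BitVector.h"
#include "llvm/IR/Constants.h"
#include <optional>
using namespace llvm;

// Illustrative rewrite of the recording loop: returns the masked-off lanes,
// or std::nullopt when the mask is all zeros and lowering should bail.
static std::optional<BitVector> recordMaskedLanes(const Constant *Mask,
                                                  unsigned NumElts) {
  BitVector MaskedIndices(NumElts);
  if (const Constant *Splat = Mask->getSplatValue()) {
    if (Splat->isZeroValue())
      return std::nullopt; // nothing is loaded at all
    return MaskedIndices;  // all-ones splat: every lane is live
  }
  for (unsigned I = 0; I < NumElts; ++I)
    if (Mask->getAggregateElement(I)->isZeroValue())
      MaskedIndices.set(I);
  return MaskedIndices;
}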
@@ -327,9 +359,9 @@ bool InterleavedAccessImpl::lowerInterleavedLoad(

     assert(Shuffle->getShuffleMask().size() <= NumLoadElements);

-    if (cast<Instruction>(Shuffle->getOperand(0))->getOperand(0) == LI)
+    if (cast<Instruction>(Shuffle->getOperand(0))->getOperand(0) == LoadOp)
       Indices.push_back(Index);
-    if (cast<Instruction>(Shuffle->getOperand(0))->getOperand(1) == LI)
+    if (cast<Instruction>(Shuffle->getOperand(0))->getOperand(1) == LoadOp)
       Indices.push_back(Index);
   }

@@ -339,25 +371,61 @@ bool InterleavedAccessImpl::lowerInterleavedLoad(
     return false;

   bool BinOpShuffleChanged =
-      replaceBinOpShuffles(BinOpShuffles.getArrayRef(), Shuffles, LI);
+      replaceBinOpShuffles(BinOpShuffles.getArrayRef(), Shuffles, LoadOp);
+
+  // Check if we extract only the unmasked elements.
+  if (MaskedIndices.any()) {
+    if (any_of(Shuffles, [&](const auto *Shuffle) {
+          ArrayRef<int> ShuffleMask = Shuffle->getShuffleMask();
+          for (int Idx : ShuffleMask) {
+            if (Idx < 0)
+              continue;
+            if (MaskedIndices.test(unsigned(Idx)))
+              return true;
+          }
+          return false;
+        })) {
+      LLVM_DEBUG(dbgs() << "IA: trying to extract a masked element through "
+                        << "shufflevector\n");
+      return false;
+    }
+  }
+  // Check if we extract only the elements within evl.
+  if (auto *VPLoad = dyn_cast<VPIntrinsic>(LoadOp)) {
+    uint64_t EVL = cast<ConstantInt>(VPLoad->getArgOperand(2))->getZExtValue();
+    if (any_of(Shuffles, [&](const auto *Shuffle) {
+          ArrayRef<int> ShuffleMask = Shuffle->getShuffleMask();
+          for (int Idx : ShuffleMask) {
+            if (Idx < 0)
+              continue;
+            if (unsigned(Idx) >= EVL)
+              return true;
+          }
+          return false;
+        })) {
+      LLVM_DEBUG(
+          dbgs() << "IA: trying to extract an element out of EVL range\n");
+      return false;
+    }
+  }

-  LLVM_DEBUG(dbgs() << "IA: Found an interleaved load: " << *LI << "\n");
+  LLVM_DEBUG(dbgs() << "IA: Found an interleaved load: " << *LoadOp << "\n");

   // Try to create target specific intrinsics to replace the load and shuffles.
-  if (!TLI->lowerInterleavedLoad(LI, Shuffles, Indices, Factor)) {
+  if (!TLI->lowerInterleavedLoad(LoadOp, Shuffles, Indices, Factor)) {
     // If Extracts is not empty, tryReplaceExtracts made changes earlier.
     return !Extracts.empty() || BinOpShuffleChanged;
   }

   DeadInsts.insert_range(Shuffles);

-  DeadInsts.insert(LI);
+  DeadInsts.insert(LoadOp);
   return true;
 }

 bool InterleavedAccessImpl::replaceBinOpShuffles(
     ArrayRef<ShuffleVectorInst *> BinOpShuffles,
-    SmallVectorImpl<ShuffleVectorInst *> &Shuffles, LoadInst *LI) {
+    SmallVectorImpl<ShuffleVectorInst *> &Shuffles, Instruction *LoadOp) {
   for (auto *SVI : BinOpShuffles) {
     BinaryOperator *BI = cast<BinaryOperator>(SVI->getOperand(0));
     Type *BIOp0Ty = BI->getOperand(0)->getType();
@@ -380,9 +448,9 @@ bool InterleavedAccessImpl::replaceBinOpShuffles(
                       << "\n  With    : " << *NewSVI1 << "\n    And   : "
                       << *NewSVI2 << "\n    And   : " << *NewBI << "\n");
     RecursivelyDeleteTriviallyDeadInstructions(SVI);
-    if (NewSVI1->getOperand(0) == LI)
+    if (NewSVI1->getOperand(0) == LoadOp)
       Shuffles.push_back(NewSVI1);
-    if (NewSVI2->getOperand(0) == LI)
+    if (NewSVI2->getOperand(0) == LoadOp)
       Shuffles.push_back(NewSVI2);
   }

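Both screens above reject the transform rather than attempt a partial lowering: a deinterleaving shuffle may only read lanes that are unmasked and below the EVL. The EVL half as a standalone predicate, plus a worked example; the name is illustrative, not from the patch:

#include "llvm/ADT/ArrayRef.h"
#include <cstdint>

// True if any shuffle element reads a lane the vp.load never produced.
static bool readsLaneAtOrBeyondEVL(llvm::ArrayRef<int> ShuffleMask,
                                   uint64_t EVL) {
  for (int Idx : ShuffleMask)
    if (Idx >= 0 && uint64_t(Idx) >= EVL)
      return true;
  return false;
}

// Factor-2 deinterleave of an <8 x i32> vp.load with EVL = 6:
//   even lanes <0, 2, 4, 6>: index 6 >= 6 -> rejected
//   odd  lanes <1, 3, 5, 7>: index 7 >= 6 -> rejected
// With EVL = 8 (and an all-ones mask) both shuffles stay in range and
// lowering proceeds as for a plain load.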
@@ -454,27 +522,79 @@ bool InterleavedAccessImpl::tryReplaceExtracts(
 }

 bool InterleavedAccessImpl::lowerInterleavedStore(
-    StoreInst *SI, SmallSetVector<Instruction *, 32> &DeadInsts) {
-  if (!SI->isSimple())
-    return false;
+    Instruction *StoreOp, SmallSetVector<Instruction *, 32> &DeadInsts) {
+  Value *StoredValue;
+  if (auto *SI = dyn_cast<StoreInst>(StoreOp)) {
+    if (!SI->isSimple())
+      return false;
+    StoredValue = SI->getValueOperand();
+  } else if (auto *VPStore = dyn_cast<VPIntrinsic>(StoreOp)) {
+    assert(VPStore->getIntrinsicID() == Intrinsic::vp_store);
+    // Require a constant mask and evl.
+    if (!isa<ConstantVector>(VPStore->getArgOperand(2)) ||
+        !isa<ConstantInt>(VPStore->getArgOperand(3)))
+      return false;
+    StoredValue = VPStore->getArgOperand(0);
+  } else {
+    llvm_unreachable("unsupported store operation");
+  }

-  auto *SVI = dyn_cast<ShuffleVectorInst>(SI->getValueOperand());
+  auto *SVI = dyn_cast<ShuffleVectorInst>(StoredValue);
   if (!SVI || !SVI->hasOneUse() || isa<ScalableVectorType>(SVI->getType()))
     return false;

+  unsigned NumStoredElements =
+      cast<FixedVectorType>(SVI->getType())->getNumElements();
+  // If this is a vp.store, record its mask (NOT shuffle mask).
+  BitVector MaskedIndices(NumStoredElements);
+  if (auto *VPStore = dyn_cast<VPIntrinsic>(StoreOp)) {
+    auto *Mask = cast<ConstantVector>(VPStore->getArgOperand(2));
+    assert(cast<FixedVectorType>(Mask->getType())->getNumElements() ==
+           NumStoredElements);
+    if (auto *Splat = Mask->getSplatValue()) {
+      // All-zeros mask, bail out early.
+      if (Splat->isZeroValue())
+        return false;
+    } else {
+      for (unsigned i = 0U; i < NumStoredElements; ++i) {
+        if (Mask->getAggregateElement(i)->isZeroValue())
+          MaskedIndices.set(i);
+      }
+    }
+  }
+
   // Check if the shufflevector is RE-interleave shuffle.
   unsigned Factor;
   if (!isReInterleaveMask(SVI, Factor, MaxFactor))
     return false;

-  LLVM_DEBUG(dbgs() << "IA: Found an interleaved store: " << *SI << "\n");
+  // Check if we store only the unmasked elements.
+  if (MaskedIndices.any()) {
+    if (any_of(SVI->getShuffleMask(), [&](int Idx) {
+          return Idx >= 0 && MaskedIndices.test(unsigned(Idx));
+        })) {
+      LLVM_DEBUG(dbgs() << "IA: trying to store a masked element\n");
+      return false;
+    }
+  }
+  // Check if we store only the elements within evl.
+  if (auto *VPStore = dyn_cast<VPIntrinsic>(StoreOp)) {
+    uint64_t EVL = cast<ConstantInt>(VPStore->getArgOperand(3))->getZExtValue();
+    if (any_of(SVI->getShuffleMask(),
+               [&](int Idx) { return Idx >= 0 && unsigned(Idx) >= EVL; })) {
+      LLVM_DEBUG(dbgs() << "IA: trying to store an element out of EVL range\n");
+      return false;
+    }
+  }
+
+  LLVM_DEBUG(dbgs() << "IA: Found an interleaved store: " << *StoreOp << "\n");

   // Try to create target specific intrinsics to replace the store and shuffle.
-  if (!TLI->lowerInterleavedStore(SI, SVI, Factor))
+  if (!TLI->lowerInterleavedStore(StoreOp, SVI, Factor))
     return false;

   // Already have a new target specific interleaved store. Erase the old store.
-  DeadInsts.insert(SI);
+  DeadInsts.insert(StoreOp);
   DeadInsts.insert(SVI);
   return true;
 }
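The store path mirrors the load path: the re-interleave shuffle's mask values are checked against the recorded masked-off lanes and against the EVL before TLI lowering is attempted. A hedged worked example using the code's own any_of test; the helper name is illustrative:

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/STLExtras.h"

// True if the shuffle uses any mask value whose bit is set in MaskedIndices,
// i.e. the vp.store would drop part of what the shuffle interleaves.
static bool touchesMaskedLane(llvm::ArrayRef<int> ShuffleMask,
                              const llvm::BitVector &MaskedIndices) {
  return llvm::any_of(ShuffleMask, [&](int Idx) {
    return Idx >= 0 && MaskedIndices.test(unsigned(Idx));
  });
}

// Factor-2 re-interleave <0, 4, 1, 5, 2, 6, 3, 7> of two <4 x i32> halves,
// vp.store mask <1, 1, 1, 1, 1, 1, 0, 0> -> MaskedIndices = {6, 7}.
// Values 6 and 7 occur in the shuffle mask, so the candidate is rejected;
// with an all-ones mask (MaskedIndices empty) it would be lowered.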
@@ -766,12 +886,15 @@ bool InterleavedAccessImpl::runOnFunction(Function &F) {
   SmallSetVector<Instruction *, 32> DeadInsts;
   bool Changed = false;

+  using namespace PatternMatch;
   for (auto &I : instructions(F)) {
-    if (auto *LI = dyn_cast<LoadInst>(&I))
-      Changed |= lowerInterleavedLoad(LI, DeadInsts);
+    if (match(&I, m_CombineOr(m_Load(m_Value()),
+                              m_Intrinsic<Intrinsic::vp_load>())))
+      Changed |= lowerInterleavedLoad(&I, DeadInsts);

-    if (auto *SI = dyn_cast<StoreInst>(&I))
-      Changed |= lowerInterleavedStore(SI, DeadInsts);
+    if (match(&I, m_CombineOr(m_Store(m_Value(), m_Value()),
+                              m_Intrinsic<Intrinsic::vp_store>())))
+      Changed |= lowerInterleavedStore(&I, DeadInsts);

     if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
       // At present, we only have intrinsics to represent (de)interleaving
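runOnFunction now selects candidates with PatternMatch instead of bare dyn_casts, so one condition covers both the plain and the VP flavor of each access. A minimal sketch of the same matcher used standalone; isInterleaveLoadCandidate is an illustrative name, not part of the patch:

#include "llvm/IR/Instruction.h"
#include "llvm/IR/PatternMatch.h"
using namespace llvm;
using namespace llvm::PatternMatch;

// Matches either a plain `load` or a `llvm.vp.load` call; the callee then
// re-derives the exact flavor with dyn_cast, as lowerInterleavedLoad does.
static bool isInterleaveLoadCandidate(Instruction &I) {
  return match(&I, m_CombineOr(m_Load(m_Value()),
                               m_Intrinsic<Intrinsic::vp_load>()));
}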