@@ -253,6 +253,21 @@ static bool isReInterleaveMask(ShuffleVectorInst *SVI, unsigned &Factor,
253253 return false ;
254254}
255255
256+ static Value *getMaskOperand (IntrinsicInst *II) { // Return the operand of II that holds its (wide) mask; callers feed the result to getMask().
257+ switch (II->getIntrinsicID ()) { // The mask's operand index depends on which intrinsic II is.
258+ default : // NOTE(review): the per-caller switches this helper replaces returned false here; callers must now pre-filter II to the four cases below.
259+ llvm_unreachable (" Unexpected intrinsic" );
260+ case Intrinsic::vp_load:
261+ return II->getOperand (1 ); // presumably vp.load is (ptr, mask, evl) -- confirm against LangRef
262+ case Intrinsic::masked_load:
263+ return II->getOperand (2 ); // presumably masked.load is (ptr, align, mask, passthru) -- confirm against LangRef
264+ case Intrinsic::vp_store:
265+ return II->getOperand (2 ); // presumably vp.store is (value, ptr, mask, evl) -- confirm against LangRef
266+ case Intrinsic::masked_store:
267+ return II->getOperand (3 ); // presumably masked.store is (value, ptr, align, mask) -- confirm against LangRef
268+ }
269+ }
270+
256271// Return the corresponded deinterleaved mask, or nullptr if there is no valid
257272// mask.
258273static Value *getMask (Value *WideMask, unsigned Factor,
@@ -268,8 +283,12 @@ bool InterleavedAccessImpl::lowerInterleavedLoad(
268283 if (isa<ScalableVectorType>(Load->getType ()))
269284 return false ;
270285
271- if (auto *LI = dyn_cast<LoadInst>(Load);
272- LI && !LI->isSimple ())
286+ auto *LI = dyn_cast<LoadInst>(Load);
287+ auto *II = dyn_cast<IntrinsicInst>(Load);
288+ if (!LI && !II)
289+ return false ;
290+
291+ if (LI && !LI->isSimple ())
273292 return false ;
274293
275294 // Check if all users of this load are shufflevectors. If we encounter any
@@ -322,7 +341,7 @@ bool InterleavedAccessImpl::lowerInterleavedLoad(
322341 // Holds the corresponding index for each DE-interleave shuffle.
323342 SmallVector<unsigned , 4 > Indices;
324343
325- Type *VecTy = FirstSVI->getType ();
344+ VectorType *VecTy = cast<VectorType>( FirstSVI->getType () );
326345
327346 // Check if other shufflevectors are also DE-interleaved of the same type
328347 // and factor as the first shufflevector.
@@ -360,13 +379,16 @@ bool InterleavedAccessImpl::lowerInterleavedLoad(
360379 replaceBinOpShuffles (BinOpShuffles.getArrayRef (), Shuffles, Load);
361380
362381 Value *Mask = nullptr ;
363- if (auto *VPLoad = dyn_cast<VPIntrinsic>(Load)) {
364- Mask = getMask (VPLoad->getMaskParam (), Factor, cast<VectorType>(VecTy));
382+ if (LI) {
383+ LLVM_DEBUG (dbgs () << " IA: Found an interleaved load: " << *Load << " \n " );
384+ } else {
385+ // Check mask operand. Handle both all-true/false and interleaved mask.
386+ Mask = getMask (getMaskOperand (II), Factor, VecTy);
365387 if (!Mask)
366388 return false ;
367- LLVM_DEBUG ( dbgs () << " IA: Found an interleaved vp.load: " << *Load << " \n " );
368- } else {
369- LLVM_DEBUG ( dbgs () << " IA: Found an interleaved load: " << *Load << " \n " );
389+
390+ LLVM_DEBUG ( dbgs () << " IA: Found an interleaved vp.load or masked.load: "
391+ << *Load << " \n " );
370392 }
371393
372394 // Try to create target specific intrinsics to replace the load and
@@ -483,15 +505,16 @@ bool InterleavedAccessImpl::tryReplaceExtracts(
483505bool InterleavedAccessImpl::lowerInterleavedStore (
484506 Instruction *Store, SmallSetVector<Instruction *, 32 > &DeadInsts) {
485507 Value *StoredValue;
486- if (auto *SI = dyn_cast<StoreInst>(Store)) {
508+ auto *SI = dyn_cast<StoreInst>(Store);
509+ auto *II = dyn_cast<IntrinsicInst>(Store);
510+ if (SI) {
487511 if (!SI->isSimple ())
488512 return false ;
489513 StoredValue = SI->getValueOperand ();
490- } else if (auto *VPStore = dyn_cast<VPIntrinsic>(Store)) {
491- assert (VPStore->getIntrinsicID () == Intrinsic::vp_store);
492- StoredValue = VPStore->getArgOperand (0 );
493514 } else {
494- llvm_unreachable (" unsupported store operation" );
515+ assert (II->getIntrinsicID () == Intrinsic::vp_store ||
516+ II->getIntrinsicID () == Intrinsic::masked_store);
517+ StoredValue = II->getArgOperand (0 );
495518 }
496519
497520 auto *SVI = dyn_cast<ShuffleVectorInst>(StoredValue);
@@ -508,18 +531,18 @@ bool InterleavedAccessImpl::lowerInterleavedStore(
508531 " number of stored element should be a multiple of Factor" );
509532
510533 Value *Mask = nullptr ;
511- if (auto *VPStore = dyn_cast<VPIntrinsic>(Store)) {
534+ if (SI) {
535+ LLVM_DEBUG (dbgs () << " IA: Found an interleaved store: " << *Store << " \n " );
536+ } else {
537+ // Check mask operand. Handle both all-true/false and interleaved mask.
512538 unsigned LaneMaskLen = NumStoredElements / Factor;
513- Mask = getMask (VPStore-> getMaskParam ( ), Factor,
539+ Mask = getMask (getMaskOperand (II ), Factor,
514540 ElementCount::getFixed (LaneMaskLen));
515541 if (!Mask)
516542 return false ;
517543
518- LLVM_DEBUG (dbgs () << " IA: Found an interleaved vp.store: " << *Store
519- << " \n " );
520-
521- } else {
522- LLVM_DEBUG (dbgs () << " IA: Found an interleaved store: " << *Store << " \n " );
544+ LLVM_DEBUG (dbgs () << " IA: Found an interleaved vp.store or masked.store: "
545+ << *Store << " \n " );
523546 }
524547
525548 // Try to create target specific intrinsics to replace the store and
@@ -592,19 +615,7 @@ bool InterleavedAccessImpl::lowerDeinterleaveIntrinsic(
592615 assert (II);
593616
594617 // Check mask operand. Handle both all-true/false and interleaved mask.
595- Value *WideMask;
596- switch (II->getIntrinsicID ()) {
597- default :
598- return false ;
599- case Intrinsic::vp_load:
600- WideMask = II->getOperand (1 );
601- break ;
602- case Intrinsic::masked_load:
603- WideMask = II->getOperand (2 );
604- break ;
605- }
606-
607- Mask = getMask (WideMask, Factor, getDeinterleavedVectorType (DI));
618+ Mask = getMask (getMaskOperand (II), Factor, getDeinterleavedVectorType (DI));
608619 if (!Mask)
609620 return false ;
610621
@@ -642,18 +653,7 @@ bool InterleavedAccessImpl::lowerInterleaveIntrinsic(
642653 Value *Mask = nullptr ;
643654 if (II) {
644655 // Check mask operand. Handle both all-true/false and interleaved mask.
645- Value *WideMask;
646- switch (II->getIntrinsicID ()) {
647- default :
648- return false ;
649- case Intrinsic::vp_store:
650- WideMask = II->getOperand (2 );
651- break ;
652- case Intrinsic::masked_store:
653- WideMask = II->getOperand (3 );
654- break ;
655- }
656- Mask = getMask (WideMask, Factor,
656+ Mask = getMask (getMaskOperand (II), Factor,
657657 cast<VectorType>(InterleaveValues[0 ]->getType ()));
658658 if (!Mask)
659659 return false ;
@@ -687,11 +687,13 @@ bool InterleavedAccessImpl::runOnFunction(Function &F) {
687687 using namespace PatternMatch ;
688688 for (auto &I : instructions (F)) {
689689 if (match (&I, m_CombineOr (m_Load (m_Value ()),
690- m_Intrinsic<Intrinsic::vp_load>())))
690+ m_Intrinsic<Intrinsic::vp_load>())) ||
691+ match (&I, m_Intrinsic<Intrinsic::masked_load>()))
691692 Changed |= lowerInterleavedLoad (&I, DeadInsts);
692693
693694 if (match (&I, m_CombineOr (m_Store (m_Value (), m_Value ()),
694- m_Intrinsic<Intrinsic::vp_store>())))
695+ m_Intrinsic<Intrinsic::vp_store>())) ||
696+ match (&I, m_Intrinsic<Intrinsic::masked_store>()))
695697 Changed |= lowerInterleavedStore (&I, DeadInsts);
696698
697699 if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
0 commit comments