@@ -630,23 +630,41 @@ getVectorDeinterleaveFactor(IntrinsicInst *II,
630630 return true ;
631631}
632632
633- // Return nullptr if the value corresponds to a all-true mask. Otherwise,
634- // return the value that is corresponded to a deinterleaved mask.
635- static Value *getMask (Value *WideMask, unsigned Factor) {
633+ // Return the corresponding deinterleaved mask, or nullptr if there is no valid
634+ // mask.
635+ static Value *getMask (Value *WideMask, unsigned Factor,
636+ VectorType *LeafValueTy) {
637+ Value *MaskVal = nullptr ;
638+
636639 using namespace llvm ::PatternMatch;
637640 if (auto *IMI = dyn_cast<IntrinsicInst>(WideMask)) {
638641 SmallVector<Value *, 8 > Operands;
639642 SmallVector<Instruction *, 8 > DeadInsts;
640643 if (getVectorInterleaveFactor (IMI, Operands, DeadInsts)) {
641644 assert (!Operands.empty ());
642645 if (Operands.size () == Factor && llvm::all_equal (Operands))
643- return Operands[0 ];
646+ MaskVal = Operands[0 ];
644647 }
645648 }
646- if (match (WideMask, m_AllOnes ()))
647- return WideMask;
648649
649- return nullptr ;
650+ if (match (WideMask, m_AllOnes ())) {
651+ // Scale the vector length.
652+ ElementCount OrigEC =
653+ cast<VectorType>(WideMask->getType ())->getElementCount ();
654+ MaskVal =
655+ ConstantVector::getSplat (OrigEC.divideCoefficientBy (Factor),
656+ cast<Constant>(WideMask)->getSplatValue ());
657+ }
658+
659+ if (MaskVal) {
660+ // Check that the mask is an i1 vector whose length matches the leaf values.
661+ auto *MaskTy = cast<VectorType>(MaskVal->getType ());
662+ if (!MaskTy->getElementType ()->isIntegerTy (/* Bitwidth=*/ 1 ) ||
663+ MaskTy->getElementCount () != LeafValueTy->getElementCount ())
664+ return nullptr ;
665+ }
666+
667+ return MaskVal;
650668}
651669
652670bool InterleavedAccessImpl::lowerDeinterleaveIntrinsic (
@@ -668,7 +686,8 @@ bool InterleavedAccessImpl::lowerDeinterleaveIntrinsic(
668686 return false ;
669687 // Check mask operand. Handle both all-true and interleaved mask.
670688 Value *WideMask = VPLoad->getOperand (1 );
671- Value *Mask = getMask (WideMask, Factor);
689+ Value *Mask = getMask (WideMask, Factor,
690+ cast<VectorType>(DeinterleaveValues[0 ]->getType ()));
672691 if (!Mask)
673692 return false ;
674693
@@ -720,7 +739,8 @@ bool InterleavedAccessImpl::lowerInterleaveIntrinsic(
720739 return false ;
721740
722741 Value *WideMask = VPStore->getOperand (2 );
723- Value *Mask = getMask (WideMask, Factor);
742+ Value *Mask = getMask (WideMask, Factor,
743+ cast<VectorType>(InterleaveValues[0 ]->getType ()));
724744 if (!Mask)
725745 return false ;
726746
0 commit comments