@@ -268,13 +268,19 @@ static Value *getMaskOperand(IntrinsicInst *II) {
268268 }
269269}
270270
271- // Return the corresponded deinterleaved mask, or nullptr if there is no valid
272- // mask.
273- static Value *getMask (Value *WideMask, unsigned Factor,
274- ElementCount LeafValueEC);
275-
276- static Value *getMask (Value *WideMask, unsigned Factor,
277- VectorType *LeafValueTy) {
271+ // Return a pair of
272+ // (1) The corresponded deinterleaved mask, or nullptr if there is no valid
273+ // mask.
274+ // (2) Some mask effectively skips a certain field, this element contains
275+ // the factor after taking such contraction into consideration. Note that
276+ // currently we only support skipping trailing fields. So if the "nominal"
277+ // factor was 5, you cannot only skip field 1 and 2, but you can skip field 3
278+ // and 4.
279+ static std::pair<Value *, unsigned > getMask (Value *WideMask, unsigned Factor,
280+ ElementCount LeafValueEC);
281+
282+ static std::pair<Value *, unsigned > getMask (Value *WideMask, unsigned Factor,
283+ VectorType *LeafValueTy) {
278284 return getMask (WideMask, Factor, LeafValueTy->getElementCount ());
279285}
280286
@@ -379,22 +385,25 @@ bool InterleavedAccessImpl::lowerInterleavedLoad(
379385 replaceBinOpShuffles (BinOpShuffles.getArrayRef (), Shuffles, Load);
380386
381387 Value *Mask = nullptr ;
388+ unsigned MaskFactor = Factor;
382389 if (LI) {
383390 LLVM_DEBUG (dbgs () << " IA: Found an interleaved load: " << *Load << " \n " );
384391 } else {
385392 // Check mask operand. Handle both all-true/false and interleaved mask.
386- Mask = getMask (getMaskOperand (II), Factor, VecTy);
393+ std::tie ( Mask, MaskFactor) = getMask (getMaskOperand (II), Factor, VecTy);
387394 if (!Mask)
388395 return false ;
389396
390397 LLVM_DEBUG (dbgs () << " IA: Found an interleaved vp.load or masked.load: "
391398 << *Load << " \n " );
399+ LLVM_DEBUG (dbgs () << " IA: With nominal factor " << Factor
400+ << " and mask factor " << MaskFactor << " \n " );
392401 }
393402
394403 // Try to create target specific intrinsics to replace the load and
395404 // shuffles.
396405 if (!TLI->lowerInterleavedLoad (cast<Instruction>(Load), Mask, Shuffles,
397- Indices, Factor))
406+ Indices, Factor, MaskFactor ))
398407 // If Extracts is not empty, tryReplaceExtracts made changes earlier.
399408 return !Extracts.empty () || BinOpShuffleChanged;
400409
@@ -536,8 +545,8 @@ bool InterleavedAccessImpl::lowerInterleavedStore(
536545 } else {
537546 // Check mask operand. Handle both all-true/false and interleaved mask.
538547 unsigned LaneMaskLen = NumStoredElements / Factor;
539- Mask = getMask (getMaskOperand (II), Factor,
540- ElementCount::getFixed (LaneMaskLen));
548+ std::tie ( Mask, std::ignore) = getMask (getMaskOperand (II), Factor,
549+ ElementCount::getFixed (LaneMaskLen));
541550 if (!Mask)
542551 return false ;
543552
@@ -556,34 +565,57 @@ bool InterleavedAccessImpl::lowerInterleavedStore(
556565 return true ;
557566}
558567
559- static Value *getMask (Value *WideMask, unsigned Factor,
560- ElementCount LeafValueEC) {
568+ static std::pair< Value *, unsigned > getMask (Value *WideMask, unsigned Factor,
569+ ElementCount LeafValueEC) {
561570 if (auto *IMI = dyn_cast<IntrinsicInst>(WideMask)) {
562571 if (unsigned F = getInterleaveIntrinsicFactor (IMI->getIntrinsicID ());
563572 F && F == Factor && llvm::all_equal (IMI->args ())) {
564- return IMI->getArgOperand (0 );
573+ return { IMI->getArgOperand (0 ), Factor} ;
565574 }
566575 }
567576
568577 if (auto *ConstMask = dyn_cast<Constant>(WideMask)) {
569578 if (auto *Splat = ConstMask->getSplatValue ())
570579 // All-ones or all-zeros mask.
571- return ConstantVector::getSplat (LeafValueEC, Splat);
580+ return { ConstantVector::getSplat (LeafValueEC, Splat), Factor} ;
572581
573582 if (LeafValueEC.isFixed ()) {
574583 unsigned LeafMaskLen = LeafValueEC.getFixedValue ();
584+ // First, check if the mask completely skips some of the factors / fields.
585+ APInt FactorMask (Factor, 0 );
586+ FactorMask.setAllBits ();
587+ for (unsigned F = 0U ; F < Factor; ++F) {
588+ unsigned Idx;
589+ for (Idx = 0U ; Idx < LeafMaskLen; ++Idx) {
590+ Constant *C = ConstMask->getAggregateElement (F + Idx * Factor);
591+ if (!C->isZeroValue ())
592+ break ;
593+ }
594+ // All mask bits on this field are zero, skipping it.
595+ if (Idx >= LeafMaskLen)
596+ FactorMask.clearBit (F);
597+ }
598+ // We currently only support skipping "trailing" factors / fields. So
599+ // given the original factor being 4, we can skip fields 2 and 3, but we
600+ // cannot only skip fields 1 and 2. If FactorMask does not match such
601+ // pattern, reset it.
602+ if (!FactorMask.isMask ())
603+ FactorMask.setAllBits ();
604+
575605 SmallVector<Constant *, 8 > LeafMask (LeafMaskLen, nullptr );
576606 // If this is a fixed-length constant mask, each lane / leaf has to
577607 // use the same mask. This is done by checking if every group with Factor
578608 // number of elements in the interleaved mask has homogeneous values.
579609 for (unsigned Idx = 0U ; Idx < LeafMaskLen * Factor; ++Idx) {
610+ if (!FactorMask[Idx % Factor])
611+ continue ;
580612 Constant *C = ConstMask->getAggregateElement (Idx);
581613 if (LeafMask[Idx / Factor] && LeafMask[Idx / Factor] != C)
582- return nullptr ;
614+ return { nullptr , Factor} ;
583615 LeafMask[Idx / Factor] = C;
584616 }
585617
586- return ConstantVector::get (LeafMask);
618+ return { ConstantVector::get (LeafMask), FactorMask. popcount ()} ;
587619 }
588620 }
589621
@@ -603,12 +635,13 @@ static Value *getMask(Value *WideMask, unsigned Factor,
603635 auto *LeafMaskTy =
604636 VectorType::get (Type::getInt1Ty (SVI->getContext ()), LeafValueEC);
605637 IRBuilder<> Builder (SVI);
606- return Builder.CreateExtractVector (LeafMaskTy, SVI->getOperand (0 ),
607- uint64_t (0 ));
638+ return {Builder.CreateExtractVector (LeafMaskTy, SVI->getOperand (0 ),
639+ uint64_t (0 )),
640+ Factor};
608641 }
609642 }
610643
611- return nullptr ;
644+ return { nullptr , Factor} ;
612645}
613646
614647bool InterleavedAccessImpl::lowerDeinterleaveIntrinsic (
@@ -639,7 +672,8 @@ bool InterleavedAccessImpl::lowerDeinterleaveIntrinsic(
639672 return false ;
640673
641674 // Check mask operand. Handle both all-true/false and interleaved mask.
642- Mask = getMask (getMaskOperand (II), Factor, getDeinterleavedVectorType (DI));
675+ std::tie (Mask, std::ignore) =
676+ getMask (getMaskOperand (II), Factor, getDeinterleavedVectorType (DI));
643677 if (!Mask)
644678 return false ;
645679
@@ -680,8 +714,9 @@ bool InterleavedAccessImpl::lowerInterleaveIntrinsic(
680714 II->getIntrinsicID () != Intrinsic::vp_store)
681715 return false ;
682716 // Check mask operand. Handle both all-true/false and interleaved mask.
683- Mask = getMask (getMaskOperand (II), Factor,
684- cast<VectorType>(InterleaveValues[0 ]->getType ()));
717+ std::tie (Mask, std::ignore) =
718+ getMask (getMaskOperand (II), Factor,
719+ cast<VectorType>(InterleaveValues[0 ]->getType ()));
685720 if (!Mask)
686721 return false ;
687722
0 commit comments