@@ -2915,10 +2915,21 @@ static Value *interleaveVectors(IRBuilderBase &Builder, ArrayRef<Value *> Vals,
29152915 // Scalable vectors cannot use arbitrary shufflevectors (only splats), so
29162916 // must use intrinsics to interleave.
29172917 if (VecTy->isScalableTy ()) {
2918- VectorType *WideVecTy = VectorType::getDoubleElementsVectorType (VecTy);
2919- return Builder.CreateIntrinsic (WideVecTy, Intrinsic::vector_interleave2,
2920- Vals,
2921- /* FMFSource=*/ nullptr , Name);
2918+ assert (isPowerOf2_32 (Factor) && " Unsupported interleave factor for "
2919+ " scalable vectors, must be power of 2" );
2920+ SmallVector<Value *> InterleavingValues (Vals);
2921+ // When interleaving, the number of values will be shrunk until we have the
2922+ // single final interleaved value.
2923+ auto *InterleaveTy = cast<VectorType>(InterleavingValues[0 ]->getType ());
2924+ for (unsigned Midpoint = Factor / 2 ; Midpoint > 0 ; Midpoint /= 2 ) {
2925+ InterleaveTy = VectorType::getDoubleElementsVectorType (InterleaveTy);
2926+ for (unsigned I = 0 ; I < Midpoint; ++I)
2927+ InterleavingValues[I] = Builder.CreateIntrinsic (
2928+ InterleaveTy, Intrinsic::vector_interleave2,
2929+ {InterleavingValues[I], InterleavingValues[Midpoint + I]},
2930+ /* FMFSource=*/ nullptr , Name);
2931+ }
2932+ return InterleavingValues[0 ];
29222933 }
29232934
29242935 // Fixed length. Start by concatenating all vectors into a wide vector.
@@ -3004,15 +3015,11 @@ void VPInterleaveRecipe::execute(VPTransformState &State) {
30043015 &InterleaveFactor](Value *MaskForGaps) -> Value * {
30053016 if (State.VF .isScalable ()) {
30063017 assert (!MaskForGaps && " Interleaved groups with gaps are not supported." );
3007- assert (InterleaveFactor == 2 &&
3018+ assert (isPowerOf2_32 ( InterleaveFactor) &&
30083019 " Unsupported deinterleave factor for scalable vectors" );
30093020 auto *ResBlockInMask = State.get (BlockInMask);
3010- SmallVector<Value *, 2 > Ops = {ResBlockInMask, ResBlockInMask};
3011- auto *MaskTy = VectorType::get (State.Builder .getInt1Ty (),
3012- State.VF .getKnownMinValue () * 2 , true );
3013- return State.Builder .CreateIntrinsic (
3014- MaskTy, Intrinsic::vector_interleave2, Ops,
3015- /* FMFSource=*/ nullptr , " interleaved.mask" );
3021+ SmallVector<Value *> Ops (InterleaveFactor, ResBlockInMask);
3022+ return interleaveVectors (State.Builder , Ops, " interleaved.mask" );
30163023 }
30173024
30183025 if (!BlockInMask)
@@ -3052,22 +3059,48 @@ void VPInterleaveRecipe::execute(VPTransformState &State) {
30523059 ArrayRef<VPValue *> VPDefs = definedValues ();
30533060 const DataLayout &DL = State.CFG .PrevBB ->getDataLayout ();
30543061 if (VecTy->isScalableTy ()) {
3055- assert (InterleaveFactor == 2 &&
3062+ assert (isPowerOf2_32 ( InterleaveFactor) &&
30563063 " Unsupported deinterleave factor for scalable vectors" );
30573064
3058- // Scalable vectors cannot use arbitrary shufflevectors (only splats),
3059- // so must use intrinsics to deinterleave.
3060- Value *DI = State.Builder .CreateIntrinsic (
3061- Intrinsic::vector_deinterleave2, VecTy, NewLoad,
3062- /* FMFSource=*/ nullptr , " strided.vec" );
3063- unsigned J = 0 ;
3064- for (unsigned I = 0 ; I < InterleaveFactor; ++I) {
3065- Instruction *Member = Group->getMember (I);
3065+ // Scalable vectors cannot use arbitrary shufflevectors (only splats),
3066+ // so must use intrinsics to deinterleave.
3067+ SmallVector<Value *> DeinterleavedValues (InterleaveFactor);
3068+ DeinterleavedValues[0 ] = NewLoad;
3069+ // The available deinterleave intrinsic only supports a factor of 2, so
3070+ // for InterleaveFactor > 2 we deinterleave repeatedly, one level per
3071+ // loop iteration. For InterleaveFactor == 2 the loop exits after its
3072+ // first iteration.
3073+ // When deinterleaving, the number of values will double until we
3074+ // have "InterleaveFactor".
3075+ for (unsigned NumVectors = 1 ; NumVectors < InterleaveFactor;
3076+ NumVectors *= 2 ) {
3077+ // Deinterleave the elements within the vector
3078+ SmallVector<Value *> TempDeinterleavedValues (NumVectors);
3079+ for (unsigned I = 0 ; I < NumVectors; ++I) {
3080+ auto *DiTy = DeinterleavedValues[I]->getType ();
3081+ TempDeinterleavedValues[I] = State.Builder .CreateIntrinsic (
3082+ Intrinsic::vector_deinterleave2, DiTy, DeinterleavedValues[I],
3083+ /* FMFSource=*/ nullptr , " strided.vec" );
3084+ }
3085+ // Extract the deinterleaved values:
3086+ for (unsigned I = 0 ; I < 2 ; ++I)
3087+ for (unsigned J = 0 ; J < NumVectors; ++J)
3088+ DeinterleavedValues[NumVectors * I + J] =
3089+ State.Builder .CreateExtractValue (TempDeinterleavedValues[J], I);
3090+ }
30663091
3067- if (!Member)
3092+ #ifndef NDEBUG
3093+ for (Value *Val : DeinterleavedValues)
3094+ assert (Val && " NULL Deinterleaved Value" );
3095+ #endif
3096+ for (unsigned I = 0 , J = 0 ; I < InterleaveFactor; ++I) {
3097+ Instruction *Member = Group->getMember (I);
3098+ Value *StridedVec = DeinterleavedValues[I];
3099+ if (!Member) {
3100+ // This value is not needed as it's not used
3101+ cast<Instruction>(StridedVec)->eraseFromParent ();
30683102 continue ;
3069-
3070- Value *StridedVec = State.Builder .CreateExtractValue (DI, I);
3103+ }
30713104 // If this member has different type, cast the result type.
30723105 if (Member->getType () != ScalarTy) {
30733106 VectorType *OtherVTy = VectorType::get (Member->getType (), State.VF );
0 commit comments