@@ -135,6 +135,7 @@ class VectorCombine {
135135 bool foldShuffleOfIntrinsics (Instruction &I);
136136 bool foldShuffleToIdentity (Instruction &I);
137137 bool foldShuffleFromReductions (Instruction &I);
138+ bool foldShuffleChainsToReduce (Instruction &I);
138139 bool foldCastFromReductions (Instruction &I);
139140 bool foldSelectShuffle (Instruction &I, bool FromReduction = false );
140141 bool foldInterleaveIntrinsics (Instruction &I);
@@ -3136,6 +3137,267 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
31363137 return MadeChanges;
31373138}
31383139
/// For a given chain of patterns of the following form:
///
/// ```
/// %1 = shufflevector <n x ty1> %0, <n x ty1> poison <n x ty2> mask
///
/// %2 = tail call <n x ty1> llvm.<umin/umax/smin/smax>(<n x ty1> %0, <n x
/// ty1> %1)
/// OR
/// %2 = add/mul/or/and/xor <n x ty1> %0, %1
///
/// %3 = shufflevector <n x ty1> %2, <n x ty1> poison <n x ty2> mask
/// ...
/// ...
/// %(i - 1) = tail call <n x ty1> llvm.<umin/umax/smin/smax>(<n x ty1> %(i -
/// 3), <n x ty1> %(i - 2)
/// OR
/// %(i - 1) = add/mul/or/and/xor <n x ty1> %(i - 3), %(i - 2)
///
/// %(i) = extractelement <n x ty1> %(i - 1), 0
/// ```
///
/// Where:
/// `mask` follows a partition pattern:
///
/// Ex:
/// [n = 8, p = poison]
///
/// 4 5 6 7 | p p p p
/// 2 3 | p p p p p p
/// 1 | p p p p p p p
///
/// For powers of 2, there's a consistent pattern, but for other cases
/// the parity of the current half value at each step decides the
/// next partition half (see `ExpectedParityMask` for more logical details
/// in generalising this).
///
/// Ex:
/// [n = 6]
///
/// 3 4 5 | p p p
/// 1 2 | p p p p
/// 1 | p p p p p
///
/// When the whole chain matches, it is replaced by a single
/// llvm.vector.reduce.* intrinsic if TTI reports that as cheaper.
bool VectorCombine::foldShuffleChainsToReduce(Instruction &I) {
  // Going bottom-up for the pattern: start from the final extractelement
  // and walk the use-def chain upwards, alternating between a call/binop
  // level and a shufflevector level.
  std::queue<Value *> InstWorklist;
  InstructionCost OrigCost = 0;

  // Common instruction operation after each shuffle op. Exactly one of
  // these gets set at the first call/binop level; every subsequent level
  // must repeat the same operation.
  std::optional<unsigned int> CommonCallOp = std::nullopt;
  std::optional<Instruction::BinaryOps> CommonBinOp = std::nullopt;

  bool IsFirstCallOrBinInst = true;
  bool ShouldBeCallOrBinInst = true;

  // This stores the last used instructions for shuffle/common op.
  //
  // PrevVecV[0] / PrevVecV[1] store the last two simultaneous
  // instructions from either shuffle/common op.
  SmallVector<Value *, 2> PrevVecV(2, nullptr);

  // The chain must terminate in `extractelement <vec>, 0`.
  Value *VecOpEE;
  if (!match(&I, m_ExtractElt(m_Value(VecOpEE), m_Zero())))
    return false;

  auto *FVT = dyn_cast<FixedVectorType>(VecOpEE->getType());
  if (!FVT)
    return false;

  int64_t VecSize = FVT->getNumElements();
  if (VecSize < 2)
    return false;

  // Number of levels would be ~log2(n), considering we always partition
  // by half for this fold pattern.
  unsigned int NumLevels = Log2_64_Ceil(VecSize), VisitedCnt = 0;
  int64_t ShuffleMaskHalf = 1, ExpectedParityMask = 0;

  // This is how we generalise for all element sizes.
  // At each step, if vector size is odd, we need non-poison
  // values to cover the dominant half so we don't miss out on any element.
  //
  // This mask will help us retrieve this as we go from bottom to top:
  //
  // Mask Set -> N = N * 2 - 1
  // Mask Unset -> N = N * 2
  for (int Cur = VecSize, Mask = NumLevels - 1; Cur > 1;
       Cur = (Cur + 1) / 2, --Mask) {
    if (Cur & 1)
      ExpectedParityMask |= (1ll << Mask);
  }

  InstWorklist.push(VecOpEE);

  while (!InstWorklist.empty()) {
    Value *CI = InstWorklist.front();
    InstWorklist.pop();

    if (auto *II = dyn_cast<IntrinsicInst>(CI)) {
      // Call level: only valid when we expected a call/binop here.
      if (!ShouldBeCallOrBinInst)
        return false;

      if (!IsFirstCallOrBinInst &&
          any_of(PrevVecV, [](Value *VecV) { return VecV == nullptr; }))
        return false;

      // For the first found call/bin op, the vector has to come from the
      // extract element op.
      if (II != (IsFirstCallOrBinInst ? VecOpEE : PrevVecV[0]))
        return false;
      IsFirstCallOrBinInst = false;

      if (!CommonCallOp)
        CommonCallOp = II->getIntrinsicID();
      if (II->getIntrinsicID() != *CommonCallOp)
        return false;

      // Only integer min/max intrinsics participate in this fold.
      switch (II->getIntrinsicID()) {
      case Intrinsic::umin:
      case Intrinsic::umax:
      case Intrinsic::smin:
      case Intrinsic::smax: {
        auto *Op0 = II->getOperand(0);
        auto *Op1 = II->getOperand(1);
        PrevVecV[0] = Op0;
        PrevVecV[1] = Op1;
        break;
      }
      default:
        return false;
      }
      // Next level up must be a shuffle.
      ShouldBeCallOrBinInst ^= 1;

      // Accumulate the cost of the original chain as we go.
      IntrinsicCostAttributes ICA(
          *CommonCallOp, II->getType(),
          {PrevVecV[0]->getType(), PrevVecV[1]->getType()});
      OrigCost += TTI.getIntrinsicInstrCost(ICA, CostKind);

      // We may need a swap here since it can be (a, b) or (b, a)
      // and accordingly change as we go up.
      if (!isa<ShuffleVectorInst>(PrevVecV[1]))
        std::swap(PrevVecV[0], PrevVecV[1]);
      InstWorklist.push(PrevVecV[1]);
      InstWorklist.push(PrevVecV[0]);
    } else if (auto *BinOp = dyn_cast<BinaryOperator>(CI)) {
      // Similar logic for bin ops.

      if (!ShouldBeCallOrBinInst)
        return false;

      if (!IsFirstCallOrBinInst &&
          any_of(PrevVecV, [](Value *VecV) { return VecV == nullptr; }))
        return false;

      if (BinOp != (IsFirstCallOrBinInst ? VecOpEE : PrevVecV[0]))
        return false;
      IsFirstCallOrBinInst = false;

      if (!CommonBinOp)
        CommonBinOp = BinOp->getOpcode();

      if (BinOp->getOpcode() != *CommonBinOp)
        return false;

      // Only commutative/associative bin ops that have a matching
      // vector reduction are accepted.
      switch (*CommonBinOp) {
      case BinaryOperator::Add:
      case BinaryOperator::Mul:
      case BinaryOperator::Or:
      case BinaryOperator::And:
      case BinaryOperator::Xor: {
        auto *Op0 = BinOp->getOperand(0);
        auto *Op1 = BinOp->getOperand(1);
        PrevVecV[0] = Op0;
        PrevVecV[1] = Op1;
        break;
      }
      default:
        return false;
      }
      ShouldBeCallOrBinInst ^= 1;

      OrigCost +=
          TTI.getArithmeticInstrCost(*CommonBinOp, BinOp->getType(), CostKind);

      // Same operand-order normalisation as for the intrinsic case.
      if (!isa<ShuffleVectorInst>(PrevVecV[1]))
        std::swap(PrevVecV[0], PrevVecV[1]);
      InstWorklist.push(PrevVecV[1]);
      InstWorklist.push(PrevVecV[0]);
    } else if (auto *SVInst = dyn_cast<ShuffleVectorInst>(CI)) {
      // We shouldn't have any null values in the previous vectors,
      // if so, there was a mismatch in pattern.
      if (ShouldBeCallOrBinInst ||
          any_of(PrevVecV, [](Value *VecV) { return VecV == nullptr; }))
        return false;

      if (SVInst != PrevVecV[1])
        return false;

      // The shuffle must be a single-source shuffle of the other
      // call/binop operand with a poison second vector.
      ArrayRef<int> CurMask;
      if (!match(SVInst, m_Shuffle(m_Specific(PrevVecV[0]), m_Poison(),
                                   m_Mask(CurMask))))
        return false;

      // Subtract the parity mask when checking the condition.
      for (int Mask = 0, MaskSize = CurMask.size(); Mask != MaskSize; ++Mask) {
        if (Mask < ShuffleMaskHalf &&
            CurMask[Mask] != ShuffleMaskHalf + Mask - (ExpectedParityMask & 1))
          return false;
        if (Mask >= ShuffleMaskHalf && CurMask[Mask] != -1)
          return false;
      }

      // Update mask values.
      ShuffleMaskHalf *= 2;
      ShuffleMaskHalf -= (ExpectedParityMask & 1);
      ExpectedParityMask >>= 1;

      OrigCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc,
                                     SVInst->getType(), SVInst->getType(),
                                     CurMask, CostKind);

      // Stop once all ~log2(n) levels have been seen and the parity
      // mask is exhausted.
      VisitedCnt += 1;
      if (!ExpectedParityMask && VisitedCnt == NumLevels)
        break;

      ShouldBeCallOrBinInst ^= 1;
    } else {
      return false;
    }
  }

  // Pattern should end with a shuffle op.
  if (ShouldBeCallOrBinInst)
    return false;

  // NOTE(review): VecSize comes from getNumElements() and is >= 2 here,
  // so this assert is vacuous — consider removing or strengthening it.
  assert(VecSize != -1 && "Expected Match for Vector Size");

  Value *FinalVecV = PrevVecV[0];
  if (!FinalVecV)
    return false;

  auto *FinalVecVTy = cast<FixedVectorType>(FinalVecV->getType());

  // Map the common per-level op to its reduction intrinsic.
  Intrinsic::ID ReducedOp =
      (CommonCallOp ? getMinMaxReductionIntrinsicID(*CommonCallOp)
                    : getReductionForBinop(*CommonBinOp));
  if (!ReducedOp)
    return false;

  // Only fold when the single reduction is strictly cheaper than the
  // accumulated cost of the original shuffle/op chain.
  IntrinsicCostAttributes ICA(ReducedOp, FinalVecVTy, {FinalVecV});
  InstructionCost NewCost = TTI.getIntrinsicInstrCost(ICA, CostKind);

  if (NewCost >= OrigCost)
    return false;

  auto *ReducedResult =
      Builder.CreateIntrinsic(ReducedOp, {FinalVecV->getType()}, {FinalVecV});
  replaceValue(I, *ReducedResult);

  return true;
}
3400+
31393401// / Determine if its more efficient to fold:
31403402// / reduce(trunc(x)) -> trunc(reduce(x)).
31413403// / reduce(sext(x)) -> sext(reduce(x)).
@@ -4223,6 +4485,9 @@ bool VectorCombine::run() {
42234485 if (foldCastFromReductions (I))
42244486 return true ;
42254487 break ;
4488+ case Instruction::ExtractElement:
4489+ if (foldShuffleChainsToReduce (I))
4490+ return true ;
42264491 case Instruction::ICmp:
42274492 case Instruction::FCmp:
42284493 if (foldExtractExtract (I))
0 commit comments