@@ -3130,21 +3130,66 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
31303130 return MadeChanges;
31313131}
31323132
3133+ // / For a given chain of patterns of the following form:
3134+ // /
3135+ // / ```
3136+ // / %1 = shufflevector <n x ty1> %0, <n x ty1> poison <n x ty2> mask
3137+ // /
3138+ // / %2 = tail call <n x ty1> llvm.<umin/umax/smin/smax>(<n x ty1> %0, <n x
3139+ // / ty1> %1)
3140+ // / OR
3141+ // / %2 = add/mul/or/and/xor <n x ty1> %0, %1
3142+ // /
3143+ // / %3 = shufflevector <n x ty1> %2, <n x ty1> poison <n x ty2> mask
3144+ // / ...
3145+ // / ...
3146+ // / %(i - 1) = tail call <n x ty1> llvm.<umin/umax/smin/smax>(<n x ty1> %(i -
3147+ // / 3), <n x ty1> %(i - 2))
3148+ // / OR
3149+ // / %(i - 1) = add/mul/or/and/xor <n x ty1> %(i - 3), %(i - 2)
3150+ // /
3151+ // / %(i) = extractelement <n x ty1> %(i - 1), 0
3152+ // / ```
3153+ // /
3154+ // / Where:
3155+ // / `mask` follows a partition pattern:
3156+ // /
3157+ // / Ex:
3158+ // / [n = 8, p = poison]
3159+ // /
3160+ // / 4 5 6 7 | p p p p
3161+ // / 2 3 | p p p p p p
3162+ // / 1 | p p p p p p p
3163+ // /
3164+ // / For powers of 2, there's a consistent pattern, but for other cases
3165+ // / the parity of the current half value at each step decides the
3166+ // / next partition half (see `ExpectedParityMask` for more logical details
3167+ // / in generalising this).
3168+ // /
3169+ // / Ex:
3170+ // / [n = 6]
3171+ // /
3172+ // / 3 4 5 | p p p
3173+ // / 1 2 | p p p p
3174+ // / 1 | p p p p p
31333175bool VectorCombine::foldShuffleChainsToReduce (Instruction &I) {
3134- auto *EEI = dyn_cast<ExtractElementInst>(&I);
3135- if (!EEI)
3136- return false ;
3137-
3176+ // Going bottom-up for the pattern.
31383177 std::queue<Value *> InstWorklist;
3139- Value *InitEEV = nullptr ;
3140- Intrinsic::ID CommonOp = 0 ;
3178+ InstructionCost OrigCost = 0 ;
3179+
3180+ // Common instruction operation after each shuffle op.
3181+ std::optional<unsigned int > CommonCallOp = std::nullopt ;
3182+ std::optional<Instruction::BinaryOps> CommonBinOp = std::nullopt ;
31413183
3142- bool IsFirstCallInst = true ;
3143- bool ShouldBeCallInst = true ;
3184+ bool IsFirstCallOrBinInst = true ;
3185+ bool ShouldBeCallOrBinInst = true ;
31443186
3187+ // This stores the last used instructions for shuffle/common op.
3188+ //
3189+ // PrevVecV[2] stores the first vector from extract element instruction,
3190+ // while PrevVecV[0] / PrevVecV[1] store the last two simultaneous
3191+ // instructions from either shuffle/common op.
31453192 SmallVector<Value *, 3 > PrevVecV (3 , nullptr );
3146- int64_t ShuffleMaskHalf = -1 , ExpectedShuffleMaskHalf = 1 ;
3147- int64_t VecSize = -1 ;
31483193
31493194 Value *VecOp;
31503195 if (!match (&I, m_ExtractElt (m_Value (VecOp), m_Zero ())))
@@ -3154,141 +3199,186 @@ bool VectorCombine::foldShuffleChainsToReduce(Instruction &I) {
31543199 if (!FVT)
31553200 return false ;
31563201
3157- VecSize = FVT->getNumElements ();
3158- if (VecSize < 2 || (VecSize % 2 ) != 0 )
3202+ int64_t VecSize = FVT->getNumElements ();
3203+ if (VecSize < 2 )
31593204 return false ;
31603205
3161- ShuffleMaskHalf = 1 ;
3162- PrevVecV[2 ] = VecOp;
3163- InitEEV = EEI;
3206+ // Number of levels would be ~log2(n), considering we always partition
3207+ // by half for this fold pattern.
3208+ unsigned int NumLevels = Log2_64_Ceil (VecSize), VisitedCnt = 0 ;
3209+ int64_t ShuffleMaskHalf = 1 , ExpectedParityMask = 0 ;
3210+
3211+ // This is how we generalise for all element sizes.
3212+ // At each step, if vector size is odd, we need non-poison
3213+ // values to cover the dominant half so we don't miss out on any element.
3214+ //
3215+ // This mask will help us retrieve this as we go from bottom to top:
3216+ //
3217+ // Mask Set -> N = N * 2 - 1
3218+ // Mask Unset -> N = N * 2
3219+ for (int Cur = VecSize, Mask = NumLevels - 1 ; Cur > 1 ;
3220+ Cur = (Cur + 1 ) / 2 , --Mask) {
3221+ if (Cur & 1 )
3222+ ExpectedParityMask |= (1ll << Mask);
3223+ }
31643224
3225+ PrevVecV[2 ] = VecOp;
31653226 InstWorklist.push (PrevVecV[2 ]);
31663227
31673228 while (!InstWorklist.empty ()) {
3168- Value *V = InstWorklist.front ();
3229+ Value *CI = InstWorklist.front ();
31693230 InstWorklist.pop ();
31703231
3171- auto *CI = dyn_cast<Instruction>(V);
3172- if (!CI)
3173- return false ;
3174-
3175- if (auto *CallI = dyn_cast<CallInst>(CI)) {
3176- if (!ShouldBeCallInst || !PrevVecV[2 ])
3232+ if (auto *II = dyn_cast<IntrinsicInst>(CI)) {
3233+ if (!ShouldBeCallOrBinInst)
31773234 return false ;
31783235
3179- if (!IsFirstCallInst &&
3236+ if (!IsFirstCallOrBinInst &&
31803237 any_of (PrevVecV, [](Value *VecV) { return VecV == nullptr ; }))
31813238 return false ;
31823239
3183- if (CallI != (IsFirstCallInst ? PrevVecV[2 ] : PrevVecV[0 ]))
3240+ // For the first found call/bin op, the vector has to come from the
3241+ // extract element op.
3242+ if (II != (IsFirstCallOrBinInst ? PrevVecV[2 ] : PrevVecV[0 ]))
31843243 return false ;
3185- IsFirstCallInst = false ;
3244+ IsFirstCallOrBinInst = false ;
31863245
3187- auto *II = dyn_cast<IntrinsicInst>(CallI);
3188- if (!II)
3189- return false ;
3190-
3191- if (!CommonOp)
3192- CommonOp = II->getIntrinsicID ();
3193- if (II->getIntrinsicID () != CommonOp)
3246+ if (!CommonCallOp)
3247+ CommonCallOp = II->getIntrinsicID ();
3248+ if (II->getIntrinsicID () != *CommonCallOp)
31943249 return false ;
31953250
31963251 switch (II->getIntrinsicID ()) {
31973252 case Intrinsic::umin:
31983253 case Intrinsic::umax:
31993254 case Intrinsic::smin:
32003255 case Intrinsic::smax: {
3201- auto *Op0 = CallI ->getOperand (0 );
3202- auto *Op1 = CallI ->getOperand (1 );
3256+ auto *Op0 = II ->getOperand (0 );
3257+ auto *Op1 = II ->getOperand (1 );
32033258 PrevVecV[0 ] = Op0;
32043259 PrevVecV[1 ] = Op1;
32053260 break ;
32063261 }
32073262 default :
32083263 return false ;
32093264 }
3210- ShouldBeCallInst ^= 1 ;
3265+ ShouldBeCallOrBinInst ^= 1 ;
3266+
3267+ IntrinsicCostAttributes ICA (
3268+ *CommonCallOp, II->getType (),
3269+ {PrevVecV[0 ]->getType (), PrevVecV[1 ]->getType ()});
3270+ OrigCost += TTI.getIntrinsicInstrCost (ICA, CostKind);
32113271
3272+ // We may need a swap here since it can be (a, b) or (b, a)
3273+ // and accordingly change as we go up.
32123274 if (!isa<ShuffleVectorInst>(PrevVecV[1 ]))
32133275 std::swap (PrevVecV[0 ], PrevVecV[1 ]);
32143276 InstWorklist.push (PrevVecV[1 ]);
32153277 InstWorklist.push (PrevVecV[0 ]);
3216- } else if (auto *SVInst = dyn_cast<ShuffleVectorInst>(CI)) {
3217- if (ShouldBeCallInst ||
3278+ } else if (auto *BinOp = dyn_cast<BinaryOperator>(CI)) {
3279+ // Similar logic for bin ops.
3280+
3281+ if (!ShouldBeCallOrBinInst)
3282+ return false ;
3283+
3284+ if (!IsFirstCallOrBinInst &&
32183285 any_of (PrevVecV, [](Value *VecV) { return VecV == nullptr ; }))
32193286 return false ;
32203287
3221- if (SVInst != PrevVecV[1 ])
3288+ if (BinOp != (IsFirstCallOrBinInst ? PrevVecV[2 ] : PrevVecV[0 ]))
3289+ return false ;
3290+ IsFirstCallOrBinInst = false ;
3291+
3292+ if (!CommonBinOp)
3293+ CommonBinOp = BinOp->getOpcode ();
3294+
3295+ if (BinOp->getOpcode () != *CommonBinOp)
32223296 return false ;
32233297
3224- auto *ShuffleVec = SVInst->getOperand (0 );
3225- if (!ShuffleVec || ShuffleVec != PrevVecV[0 ])
3298+ switch (*CommonBinOp) {
3299+ case BinaryOperator::Add:
3300+ case BinaryOperator::Mul:
3301+ case BinaryOperator::Or:
3302+ case BinaryOperator::And:
3303+ case BinaryOperator::Xor: {
3304+ auto *Op0 = BinOp->getOperand (0 );
3305+ auto *Op1 = BinOp->getOperand (1 );
3306+ PrevVecV[0 ] = Op0;
3307+ PrevVecV[1 ] = Op1;
3308+ break ;
3309+ }
3310+ default :
32263311 return false ;
3312+ }
3313+ ShouldBeCallOrBinInst ^= 1 ;
32273314
3228- SmallVector< int > CurMask;
3229- SVInst-> getShuffleMask (CurMask );
3315+ OrigCost +=
3316+ TTI. getArithmeticInstrCost (*CommonBinOp, BinOp-> getType (), CostKind );
32303317
3231- if (ShuffleMaskHalf != ExpectedShuffleMaskHalf)
3318+ if (!isa<ShuffleVectorInst>(PrevVecV[1 ]))
3319+ std::swap (PrevVecV[0 ], PrevVecV[1 ]);
3320+ InstWorklist.push (PrevVecV[1 ]);
3321+ InstWorklist.push (PrevVecV[0 ]);
3322+ } else if (auto *SVInst = dyn_cast<ShuffleVectorInst>(CI)) {
3323+ // We shouldn't have any null values in the previous vectors,
3324+ // if so, there was a mismatch in pattern.
3325+ if (ShouldBeCallOrBinInst ||
3326+ any_of (PrevVecV, [](Value *VecV) { return VecV == nullptr ; }))
3327+ return false ;
3328+
3329+ if (SVInst != PrevVecV[1 ])
3330+ return false ;
3331+
3332+ ArrayRef<int > CurMask;
3333+ if (!match (SVInst, m_Shuffle (m_Specific (PrevVecV[0 ]), m_Poison (),
3334+ m_Mask (CurMask))))
32323335 return false ;
3233- ExpectedShuffleMaskHalf *= 2 ;
32343336
3337+ // Subtract the parity mask when checking the condition.
32353338 for (int Mask = 0 , MaskSize = CurMask.size (); Mask != MaskSize; ++Mask) {
3236- if (Mask < ShuffleMaskHalf && CurMask[Mask] != ShuffleMaskHalf + Mask)
3339+ if (Mask < ShuffleMaskHalf &&
3340+ CurMask[Mask] != ShuffleMaskHalf + Mask - (ExpectedParityMask & 1 ))
32373341 return false ;
32383342 if (Mask >= ShuffleMaskHalf && CurMask[Mask] != -1 )
32393343 return false ;
32403344 }
3345+
3346+ // Update mask values.
32413347 ShuffleMaskHalf *= 2 ;
3242- if (ExpectedShuffleMaskHalf == VecSize)
3348+ ShuffleMaskHalf -= (ExpectedParityMask & 1 );
3349+ ExpectedParityMask >>= 1 ;
3350+
3351+ OrigCost += TTI.getShuffleCost (TargetTransformInfo::SK_PermuteSingleSrc,
3352+ SVInst->getType (), SVInst->getType (),
3353+ CurMask, CostKind);
3354+
3355+ VisitedCnt += 1 ;
3356+ if (!ExpectedParityMask && VisitedCnt == NumLevels)
32433357 break ;
3244- ShouldBeCallInst ^= 1 ;
3358+
3359+ ShouldBeCallOrBinInst ^= 1 ;
32453360 } else {
32463361 return false ;
32473362 }
32483363 }
32493364
3250- if (ShouldBeCallInst)
3365+ // Pattern should end with a shuffle op.
3366+ if (ShouldBeCallOrBinInst)
32513367 return false ;
32523368
3253- assert (VecSize != -1 && ExpectedShuffleMaskHalf == VecSize &&
3254- " Expected Match for Vector Size and Mask Half" );
3369+ assert (VecSize != -1 && " Expected Match for Vector Size" );
32553370
32563371 Value *FinalVecV = PrevVecV[0 ];
3257- auto *FinalVecVTy = dyn_cast<FixedVectorType>(FinalVecV->getType ());
3258-
3259- if (!InitEEV || !FinalVecV)
3372+ if (!FinalVecV)
32603373 return false ;
32613374
3262- assert ( FinalVecVTy && " Expected non-null value for Vector Type " );
3375+ auto * FinalVecVTy = cast<FixedVectorType>(FinalVecV-> getType () );
32633376
3264- Intrinsic::ID ReducedOp = 0 ;
3265- switch (CommonOp) {
3266- case Intrinsic::umin:
3267- ReducedOp = Intrinsic::vector_reduce_umin;
3268- break ;
3269- case Intrinsic::umax:
3270- ReducedOp = Intrinsic::vector_reduce_umax;
3271- break ;
3272- case Intrinsic::smin:
3273- ReducedOp = Intrinsic::vector_reduce_smin;
3274- break ;
3275- case Intrinsic::smax:
3276- ReducedOp = Intrinsic::vector_reduce_smax;
3277- break ;
3278- default :
3377+ Intrinsic::ID ReducedOp =
3378+ (CommonCallOp ? getMinMaxReductionIntrinsicID (*CommonCallOp)
3379+ : getReductionForBinop (*CommonBinOp));
3380+ if (!ReducedOp)
32793381 return false ;
3280- }
3281-
3282- InstructionCost OrigCost = 0 ;
3283- unsigned int NumLevels = Log2_64 (VecSize);
3284-
3285- for (unsigned int Level = 0 ; Level < NumLevels; ++Level) {
3286- OrigCost += TTI.getShuffleCost (TargetTransformInfo::SK_PermuteSingleSrc,
3287- FinalVecVTy, FinalVecVTy);
3288- OrigCost += TTI.getArithmeticInstrCost (Instruction::ICmp, FinalVecVTy);
3289- }
3290- OrigCost += TTI.getVectorInstrCost (Instruction::ExtractElement, FinalVecVTy,
3291- CostKind, 0 );
32923382
32933383 IntrinsicCostAttributes ICA (ReducedOp, FinalVecVTy, {FinalVecV});
32943384 InstructionCost NewCost = TTI.getIntrinsicInstrCost (ICA, CostKind);
@@ -3298,7 +3388,7 @@ bool VectorCombine::foldShuffleChainsToReduce(Instruction &I) {
32983388
32993389 auto *ReducedResult =
33003390 Builder.CreateIntrinsic (ReducedOp, {FinalVecV->getType ()}, {FinalVecV});
3301- replaceValue (*InitEEV , *ReducedResult);
3391+ replaceValue (I , *ReducedResult);
33023392
33033393 return true ;
33043394}
@@ -4391,8 +4481,8 @@ bool VectorCombine::run() {
43914481 return true ;
43924482 break ;
43934483 case Instruction::ExtractElement:
4394- MadeChange |= foldShuffleChainsToReduce (I);
4395- break ;
4484+ if ( foldShuffleChainsToReduce (I))
4485+ return true ;
43964486 case Instruction::ICmp:
43974487 case Instruction::FCmp:
43984488 if (foldExtractExtract (I))
0 commit comments