Skip to content

Commit 93c9684

Browse files
authored
[VectorCombine] New folding pattern for extract/binop/shuffle chains (llvm#145232)
Resolves llvm#144654
Part of llvm#143088

This adds a new `foldShuffleChainsToReduce` fold for horizontal reductions written as extract/binop/shuffle chains, such as:

```llvm
define i16 @test_reduce_v8i16(<8 x i16> %a0) local_unnamed_addr #0 {
  %1 = shufflevector <8 x i16> %a0, <8 x i16> poison, <8 x i32> <i32 4, i32 5, i32 6, i32 7, i32 poison, i32 poison, i32 poison, i32 poison>
  %2 = tail call <8 x i16> @llvm.umin.v8i16(<8 x i16> %a0, <8 x i16> %1)
  %3 = shufflevector <8 x i16> %2, <8 x i16> poison, <8 x i32> <i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
  %4 = tail call <8 x i16> @llvm.umin.v8i16(<8 x i16> %2, <8 x i16> %3)
  %5 = shufflevector <8 x i16> %4, <8 x i16> poison, <8 x i32> <i32 1, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
  %6 = tail call <8 x i16> @llvm.umin.v8i16(<8 x i16> %4, <8 x i16> %5)
  %7 = extractelement <8 x i16> %6, i64 0
  ret i16 %7
}
```

...which can be reduced to a single `llvm.vector.reduce.umin.v8i16(%a0)` intrinsic call. The same transformation is applied to the other supported ops when the cost model permits it.
1 parent 3054e06 commit 93c9684

File tree

5 files changed

+678
-0
lines changed

5 files changed

+678
-0
lines changed

llvm/include/llvm/Transforms/Utils/LoopUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,9 @@ LLVM_ABI bool canSinkOrHoistInst(Instruction &I, AAResults *AA,
371371
/// Returns the llvm.vector.reduce intrinsic that corresponds to the recurrence
372372
/// kind.
373373
LLVM_ABI constexpr Intrinsic::ID getReductionIntrinsicID(RecurKind RK);
374+
/// Returns the llvm.vector.reduce min/max intrinsic ID that corresponds to the
375+
/// element-wise min/max intrinsic \p IID (umin, umax, smin or smax).
376+
LLVM_ABI Intrinsic::ID getMinMaxReductionIntrinsicID(Intrinsic::ID IID);
374377

375378
/// Returns the arithmetic instruction opcode used when expanding a reduction.
376379
LLVM_ABI unsigned getArithmeticReductionInstruction(Intrinsic::ID RdxID);

llvm/lib/Transforms/Utils/LoopUtils.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -956,6 +956,21 @@ constexpr Intrinsic::ID llvm::getReductionIntrinsicID(RecurKind RK) {
956956
}
957957
}
958958

959+
/// Translate an element-wise integer min/max intrinsic ID into the ID of the
/// llvm.vector.reduce.* intrinsic that performs the same operation across the
/// lanes of a vector. Only smin/smax/umin/umax are handled; any other ID is
/// treated as unreachable.
Intrinsic::ID llvm::getMinMaxReductionIntrinsicID(Intrinsic::ID IID) {
  switch (IID) {
  case Intrinsic::smax:
    return Intrinsic::vector_reduce_smax;
  case Intrinsic::smin:
    return Intrinsic::vector_reduce_smin;
  case Intrinsic::umax:
    return Intrinsic::vector_reduce_umax;
  case Intrinsic::umin:
    return Intrinsic::vector_reduce_umin;
  default:
    llvm_unreachable("Unexpected intrinsic id");
  }
}
973+
959974
// This is the inverse to getReductionForBinop
960975
unsigned llvm::getArithmeticReductionInstruction(Intrinsic::ID RdxID) {
961976
switch (RdxID) {

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ class VectorCombine {
135135
bool foldShuffleOfIntrinsics(Instruction &I);
136136
bool foldShuffleToIdentity(Instruction &I);
137137
bool foldShuffleFromReductions(Instruction &I);
138+
bool foldShuffleChainsToReduce(Instruction &I);
138139
bool foldCastFromReductions(Instruction &I);
139140
bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
140141
bool foldInterleaveIntrinsics(Instruction &I);
@@ -3136,6 +3137,267 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
31363137
return MadeChanges;
31373138
}
31383139

3140+
/// For a given chain of patterns of the following form:
3141+
///
3142+
/// ```
3143+
/// %1 = shufflevector <n x ty1> %0, <n x ty1> poison <n x ty2> mask
3144+
///
3145+
/// %2 = tail call <n x ty1> llvm.<umin/umax/smin/smax>(<n x ty1> %0, <n x
3146+
/// ty1> %1)
3147+
/// OR
3148+
/// %2 = add/mul/or/and/xor <n x ty1> %0, %1
3149+
///
3150+
/// %3 = shufflevector <n x ty1> %2, <n x ty1> poison <n x ty2> mask
3151+
/// ...
3152+
/// ...
3153+
/// %(i - 1) = tail call <n x ty1> llvm.<umin/umax/smin/smax>(<n x ty1> %(i -
3154+
/// 3), <n x ty1> %(i - 2)
3155+
/// OR
3156+
/// %(i - 1) = add/mul/or/and/xor <n x ty1> %(i - 3), %(i - 2)
3157+
///
3158+
/// %(i) = extractelement <n x ty1> %(i - 1), 0
3159+
/// ```
3160+
///
3161+
/// Where:
3162+
/// `mask` follows a partition pattern:
3163+
///
3164+
/// Ex:
3165+
/// [n = 8, p = poison]
3166+
///
3167+
/// 4 5 6 7 | p p p p
3168+
/// 2 3 | p p p p p p
3169+
/// 1 | p p p p p p p
3170+
///
3171+
/// For powers of 2, there's a consistent pattern, but for other cases
3172+
/// the parity of the current half value at each step decides the
3173+
/// next partition half (see `ExpectedParityMask` for more logical details
3174+
/// in generalising this).
3175+
///
3176+
/// Ex:
3177+
/// [n = 6]
3178+
///
3179+
/// 3 4 5 | p p p
3180+
/// 1 2 | p p p p
3181+
/// 1 | p p p p p
3182+
bool VectorCombine::foldShuffleChainsToReduce(Instruction &I) {
3183+
// Going bottom-up for the pattern.
3184+
std::queue<Value *> InstWorklist;
3185+
InstructionCost OrigCost = 0;
3186+
3187+
// Common instruction operation after each shuffle op.
3188+
std::optional<unsigned int> CommonCallOp = std::nullopt;
3189+
std::optional<Instruction::BinaryOps> CommonBinOp = std::nullopt;
3190+
3191+
bool IsFirstCallOrBinInst = true;
3192+
bool ShouldBeCallOrBinInst = true;
3193+
3194+
// This stores the last used instructions for shuffle/common op.
3195+
//
3196+
// PrevVecV[0] / PrevVecV[1] store the last two simultaneous
3197+
// instructions from either shuffle/common op.
3198+
SmallVector<Value *, 2> PrevVecV(2, nullptr);
3199+
3200+
Value *VecOpEE;
3201+
if (!match(&I, m_ExtractElt(m_Value(VecOpEE), m_Zero())))
3202+
return false;
3203+
3204+
auto *FVT = dyn_cast<FixedVectorType>(VecOpEE->getType());
3205+
if (!FVT)
3206+
return false;
3207+
3208+
int64_t VecSize = FVT->getNumElements();
3209+
if (VecSize < 2)
3210+
return false;
3211+
3212+
// Number of levels would be ~log2(n), considering we always partition
3213+
// by half for this fold pattern.
3214+
unsigned int NumLevels = Log2_64_Ceil(VecSize), VisitedCnt = 0;
3215+
int64_t ShuffleMaskHalf = 1, ExpectedParityMask = 0;
3216+
3217+
// This is how we generalise for all element sizes.
3218+
// At each step, if vector size is odd, we need non-poison
3219+
// values to cover the dominant half so we don't miss out on any element.
3220+
//
3221+
// This mask will help us retrieve this as we go from bottom to top:
3222+
//
3223+
// Mask Set -> N = N * 2 - 1
3224+
// Mask Unset -> N = N * 2
3225+
for (int Cur = VecSize, Mask = NumLevels - 1; Cur > 1;
3226+
Cur = (Cur + 1) / 2, --Mask) {
3227+
if (Cur & 1)
3228+
ExpectedParityMask |= (1ll << Mask);
3229+
}
3230+
3231+
InstWorklist.push(VecOpEE);
3232+
3233+
while (!InstWorklist.empty()) {
3234+
Value *CI = InstWorklist.front();
3235+
InstWorklist.pop();
3236+
3237+
if (auto *II = dyn_cast<IntrinsicInst>(CI)) {
3238+
if (!ShouldBeCallOrBinInst)
3239+
return false;
3240+
3241+
if (!IsFirstCallOrBinInst &&
3242+
any_of(PrevVecV, [](Value *VecV) { return VecV == nullptr; }))
3243+
return false;
3244+
3245+
// For the first found call/bin op, the vector has to come from the
3246+
// extract element op.
3247+
if (II != (IsFirstCallOrBinInst ? VecOpEE : PrevVecV[0]))
3248+
return false;
3249+
IsFirstCallOrBinInst = false;
3250+
3251+
if (!CommonCallOp)
3252+
CommonCallOp = II->getIntrinsicID();
3253+
if (II->getIntrinsicID() != *CommonCallOp)
3254+
return false;
3255+
3256+
switch (II->getIntrinsicID()) {
3257+
case Intrinsic::umin:
3258+
case Intrinsic::umax:
3259+
case Intrinsic::smin:
3260+
case Intrinsic::smax: {
3261+
auto *Op0 = II->getOperand(0);
3262+
auto *Op1 = II->getOperand(1);
3263+
PrevVecV[0] = Op0;
3264+
PrevVecV[1] = Op1;
3265+
break;
3266+
}
3267+
default:
3268+
return false;
3269+
}
3270+
ShouldBeCallOrBinInst ^= 1;
3271+
3272+
IntrinsicCostAttributes ICA(
3273+
*CommonCallOp, II->getType(),
3274+
{PrevVecV[0]->getType(), PrevVecV[1]->getType()});
3275+
OrigCost += TTI.getIntrinsicInstrCost(ICA, CostKind);
3276+
3277+
// We may need a swap here since it can be (a, b) or (b, a)
3278+
// and accordingly change as we go up.
3279+
if (!isa<ShuffleVectorInst>(PrevVecV[1]))
3280+
std::swap(PrevVecV[0], PrevVecV[1]);
3281+
InstWorklist.push(PrevVecV[1]);
3282+
InstWorklist.push(PrevVecV[0]);
3283+
} else if (auto *BinOp = dyn_cast<BinaryOperator>(CI)) {
3284+
// Similar logic for bin ops.
3285+
3286+
if (!ShouldBeCallOrBinInst)
3287+
return false;
3288+
3289+
if (!IsFirstCallOrBinInst &&
3290+
any_of(PrevVecV, [](Value *VecV) { return VecV == nullptr; }))
3291+
return false;
3292+
3293+
if (BinOp != (IsFirstCallOrBinInst ? VecOpEE : PrevVecV[0]))
3294+
return false;
3295+
IsFirstCallOrBinInst = false;
3296+
3297+
if (!CommonBinOp)
3298+
CommonBinOp = BinOp->getOpcode();
3299+
3300+
if (BinOp->getOpcode() != *CommonBinOp)
3301+
return false;
3302+
3303+
switch (*CommonBinOp) {
3304+
case BinaryOperator::Add:
3305+
case BinaryOperator::Mul:
3306+
case BinaryOperator::Or:
3307+
case BinaryOperator::And:
3308+
case BinaryOperator::Xor: {
3309+
auto *Op0 = BinOp->getOperand(0);
3310+
auto *Op1 = BinOp->getOperand(1);
3311+
PrevVecV[0] = Op0;
3312+
PrevVecV[1] = Op1;
3313+
break;
3314+
}
3315+
default:
3316+
return false;
3317+
}
3318+
ShouldBeCallOrBinInst ^= 1;
3319+
3320+
OrigCost +=
3321+
TTI.getArithmeticInstrCost(*CommonBinOp, BinOp->getType(), CostKind);
3322+
3323+
if (!isa<ShuffleVectorInst>(PrevVecV[1]))
3324+
std::swap(PrevVecV[0], PrevVecV[1]);
3325+
InstWorklist.push(PrevVecV[1]);
3326+
InstWorklist.push(PrevVecV[0]);
3327+
} else if (auto *SVInst = dyn_cast<ShuffleVectorInst>(CI)) {
3328+
// We shouldn't have any null values in the previous vectors,
3329+
// is so, there was a mismatch in pattern.
3330+
if (ShouldBeCallOrBinInst ||
3331+
any_of(PrevVecV, [](Value *VecV) { return VecV == nullptr; }))
3332+
return false;
3333+
3334+
if (SVInst != PrevVecV[1])
3335+
return false;
3336+
3337+
ArrayRef<int> CurMask;
3338+
if (!match(SVInst, m_Shuffle(m_Specific(PrevVecV[0]), m_Poison(),
3339+
m_Mask(CurMask))))
3340+
return false;
3341+
3342+
// Subtract the parity mask when checking the condition.
3343+
for (int Mask = 0, MaskSize = CurMask.size(); Mask != MaskSize; ++Mask) {
3344+
if (Mask < ShuffleMaskHalf &&
3345+
CurMask[Mask] != ShuffleMaskHalf + Mask - (ExpectedParityMask & 1))
3346+
return false;
3347+
if (Mask >= ShuffleMaskHalf && CurMask[Mask] != -1)
3348+
return false;
3349+
}
3350+
3351+
// Update mask values.
3352+
ShuffleMaskHalf *= 2;
3353+
ShuffleMaskHalf -= (ExpectedParityMask & 1);
3354+
ExpectedParityMask >>= 1;
3355+
3356+
OrigCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc,
3357+
SVInst->getType(), SVInst->getType(),
3358+
CurMask, CostKind);
3359+
3360+
VisitedCnt += 1;
3361+
if (!ExpectedParityMask && VisitedCnt == NumLevels)
3362+
break;
3363+
3364+
ShouldBeCallOrBinInst ^= 1;
3365+
} else {
3366+
return false;
3367+
}
3368+
}
3369+
3370+
// Pattern should end with a shuffle op.
3371+
if (ShouldBeCallOrBinInst)
3372+
return false;
3373+
3374+
assert(VecSize != -1 && "Expected Match for Vector Size");
3375+
3376+
Value *FinalVecV = PrevVecV[0];
3377+
if (!FinalVecV)
3378+
return false;
3379+
3380+
auto *FinalVecVTy = cast<FixedVectorType>(FinalVecV->getType());
3381+
3382+
Intrinsic::ID ReducedOp =
3383+
(CommonCallOp ? getMinMaxReductionIntrinsicID(*CommonCallOp)
3384+
: getReductionForBinop(*CommonBinOp));
3385+
if (!ReducedOp)
3386+
return false;
3387+
3388+
IntrinsicCostAttributes ICA(ReducedOp, FinalVecVTy, {FinalVecV});
3389+
InstructionCost NewCost = TTI.getIntrinsicInstrCost(ICA, CostKind);
3390+
3391+
if (NewCost >= OrigCost)
3392+
return false;
3393+
3394+
auto *ReducedResult =
3395+
Builder.CreateIntrinsic(ReducedOp, {FinalVecV->getType()}, {FinalVecV});
3396+
replaceValue(I, *ReducedResult);
3397+
3398+
return true;
3399+
}
3400+
31393401
/// Determine if its more efficient to fold:
31403402
/// reduce(trunc(x)) -> trunc(reduce(x)).
31413403
/// reduce(sext(x)) -> sext(reduce(x)).
@@ -4223,6 +4485,9 @@ bool VectorCombine::run() {
42234485
if (foldCastFromReductions(I))
42244486
return true;
42254487
break;
4488+
case Instruction::ExtractElement:
4489+
if (foldShuffleChainsToReduce(I))
4490+
return true;
42264491
case Instruction::ICmp:
42274492
case Instruction::FCmp:
42284493
if (foldExtractExtract(I))

0 commit comments

Comments
 (0)