Skip to content

Commit 0c91e97

Browse files
authored
[VectorCombine] Refine cost model and decision logic in foldSelectShuffle (llvm#146694)
After PR llvm#136329, shuffle indices may differ, which can cause the existing cost-based logic to miss optimisation opportunities for binop/shuffle sequences. This patch improves the cost model in foldSelectShuffle to more accurately assess costs, recognising when certain duplicate shuffles do not require actual instructions. Additionally, in break-even cases, this change introduces a check for whether the pattern ultimately feeds into a vector reduction, allowing the transform to proceed when it is likely to be profitable overall.
1 parent c77a2a2 commit 0c91e97

File tree

2 files changed

+147
-27
lines changed

2 files changed

+147
-27
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 120 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3174,6 +3174,55 @@ bool VectorCombine::foldCastFromReductions(Instruction &I) {
31743174
return true;
31753175
}
31763176

3177+
/// Returns true if this ShuffleVectorInst eventually feeds into a
3178+
/// vector reduction intrinsic (e.g., vector_reduce_add) by only following
3179+
/// chains of shuffles and binary operators (in any combination/order).
3180+
/// The search does not go deeper than the given Depth.
3181+
static bool feedsIntoVectorReduction(ShuffleVectorInst *SVI) {
3182+
constexpr unsigned MaxVisited = 32;
3183+
SmallPtrSet<Instruction *, 8> Visited;
3184+
SmallVector<Instruction *, 4> WorkList;
3185+
bool FoundReduction = false;
3186+
3187+
WorkList.push_back(SVI);
3188+
while (!WorkList.empty()) {
3189+
Instruction *I = WorkList.pop_back_val();
3190+
for (User *U : I->users()) {
3191+
auto *UI = cast<Instruction>(U);
3192+
if (!UI || !Visited.insert(UI).second)
3193+
continue;
3194+
if (Visited.size() > MaxVisited)
3195+
return false;
3196+
if (auto *II = dyn_cast<IntrinsicInst>(UI)) {
3197+
// More than one reduction reached
3198+
if (FoundReduction)
3199+
return false;
3200+
switch (II->getIntrinsicID()) {
3201+
case Intrinsic::vector_reduce_add:
3202+
case Intrinsic::vector_reduce_mul:
3203+
case Intrinsic::vector_reduce_and:
3204+
case Intrinsic::vector_reduce_or:
3205+
case Intrinsic::vector_reduce_xor:
3206+
case Intrinsic::vector_reduce_smin:
3207+
case Intrinsic::vector_reduce_smax:
3208+
case Intrinsic::vector_reduce_umin:
3209+
case Intrinsic::vector_reduce_umax:
3210+
FoundReduction = true;
3211+
continue;
3212+
default:
3213+
return false;
3214+
}
3215+
}
3216+
3217+
if (!isa<BinaryOperator>(UI) && !isa<ShuffleVectorInst>(UI))
3218+
return false;
3219+
3220+
WorkList.emplace_back(UI);
3221+
}
3222+
}
3223+
return FoundReduction;
3224+
}
3225+
31773226
/// This method looks for groups of shuffles acting on binops, of the form:
31783227
/// %x = shuffle ...
31793228
/// %y = shuffle ...
@@ -3416,15 +3465,80 @@ bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) {
34163465
TTI.getShuffleCost(TTI::SK_PermuteTwoSrc, VT, VT, Mask, CostKind);
34173466
};
34183467

3468+
unsigned ElementSize = VT->getElementType()->getPrimitiveSizeInBits();
3469+
unsigned MaxVectorSize =
3470+
TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector);
3471+
unsigned MaxElementsInVector = MaxVectorSize / ElementSize;
3472+
// When there are multiple shufflevector operations on the same input,
3473+
// especially when the vector length is larger than the register size,
3474+
// identical shuffle patterns may occur across different groups of elements.
3475+
// To avoid overestimating the cost by counting these repeated shuffles more
3476+
// than once, we only account for unique shuffle patterns. This adjustment
3477+
// prevents inflated costs in the cost model for wide vectors split into
3478+
// several register-sized groups.
3479+
std::set<SmallVector<int, 4>> UniqueShuffles;
3480+
auto AddShuffleMaskAdjustedCost = [&](InstructionCost C, ArrayRef<int> Mask) {
3481+
// Compute the cost for performing the shuffle over the full vector.
3482+
auto ShuffleCost =
3483+
TTI.getShuffleCost(TTI::SK_PermuteTwoSrc, VT, VT, Mask, CostKind);
3484+
unsigned NumFullVectors = Mask.size() / MaxElementsInVector;
3485+
if (NumFullVectors < 2)
3486+
return C + ShuffleCost;
3487+
SmallVector<int, 4> SubShuffle(MaxElementsInVector);
3488+
unsigned NumUniqueGroups = 0;
3489+
unsigned NumGroups = Mask.size() / MaxElementsInVector;
3490+
// For each group of MaxElementsInVector contiguous elements,
3491+
// collect their shuffle pattern and insert into the set of unique patterns.
3492+
for (unsigned I = 0; I < NumFullVectors; ++I) {
3493+
for (unsigned J = 0; J < MaxElementsInVector; ++J)
3494+
SubShuffle[J] = Mask[MaxElementsInVector * I + J];
3495+
if (UniqueShuffles.insert(SubShuffle).second)
3496+
NumUniqueGroups += 1;
3497+
}
3498+
return C + ShuffleCost * NumUniqueGroups / NumGroups;
3499+
};
3500+
auto AddShuffleAdjustedCost = [&](InstructionCost C, Instruction *I) {
3501+
auto *SV = dyn_cast<ShuffleVectorInst>(I);
3502+
if (!SV)
3503+
return C;
3504+
SmallVector<int, 16> Mask;
3505+
SV->getShuffleMask(Mask);
3506+
return AddShuffleMaskAdjustedCost(C, Mask);
3507+
};
3508+
// Check that input consists of ShuffleVectors applied to the same input
3509+
auto AllShufflesHaveSameOperands =
3510+
[](SmallPtrSetImpl<Instruction *> &InputShuffles) {
3511+
if (InputShuffles.size() < 2)
3512+
return false;
3513+
ShuffleVectorInst *FirstSV =
3514+
dyn_cast<ShuffleVectorInst>(*InputShuffles.begin());
3515+
if (!FirstSV)
3516+
return false;
3517+
3518+
Value *In0 = FirstSV->getOperand(0), *In1 = FirstSV->getOperand(1);
3519+
return std::all_of(
3520+
std::next(InputShuffles.begin()), InputShuffles.end(),
3521+
[&](Instruction *I) {
3522+
ShuffleVectorInst *SV = dyn_cast<ShuffleVectorInst>(I);
3523+
return SV && SV->getOperand(0) == In0 && SV->getOperand(1) == In1;
3524+
});
3525+
};
3526+
34193527
// Get the costs of the shuffles + binops before and after with the new
34203528
// shuffle masks.
34213529
InstructionCost CostBefore =
34223530
TTI.getArithmeticInstrCost(Op0->getOpcode(), VT, CostKind) +
34233531
TTI.getArithmeticInstrCost(Op1->getOpcode(), VT, CostKind);
34243532
CostBefore += std::accumulate(Shuffles.begin(), Shuffles.end(),
34253533
InstructionCost(0), AddShuffleCost);
3426-
CostBefore += std::accumulate(InputShuffles.begin(), InputShuffles.end(),
3427-
InstructionCost(0), AddShuffleCost);
3534+
if (AllShufflesHaveSameOperands(InputShuffles)) {
3535+
UniqueShuffles.clear();
3536+
CostBefore += std::accumulate(InputShuffles.begin(), InputShuffles.end(),
3537+
InstructionCost(0), AddShuffleAdjustedCost);
3538+
} else {
3539+
CostBefore += std::accumulate(InputShuffles.begin(), InputShuffles.end(),
3540+
InstructionCost(0), AddShuffleCost);
3541+
}
34283542

34293543
// The new binops will be unused for lanes past the used shuffle lengths.
34303544
// These types attempt to get the correct cost for that from the target.
@@ -3435,8 +3549,9 @@ bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) {
34353549
InstructionCost CostAfter =
34363550
TTI.getArithmeticInstrCost(Op0->getOpcode(), Op0SmallVT, CostKind) +
34373551
TTI.getArithmeticInstrCost(Op1->getOpcode(), Op1SmallVT, CostKind);
3552+
UniqueShuffles.clear();
34383553
CostAfter += std::accumulate(ReconstructMasks.begin(), ReconstructMasks.end(),
3439-
InstructionCost(0), AddShuffleMaskCost);
3554+
InstructionCost(0), AddShuffleMaskAdjustedCost);
34403555
std::set<SmallVector<int>> OutputShuffleMasks({V1A, V1B, V2A, V2B});
34413556
CostAfter +=
34423557
std::accumulate(OutputShuffleMasks.begin(), OutputShuffleMasks.end(),
@@ -3445,7 +3560,8 @@ bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) {
34453560
LLVM_DEBUG(dbgs() << "Found a binop select shuffle pattern: " << I << "\n");
34463561
LLVM_DEBUG(dbgs() << " CostBefore: " << CostBefore
34473562
<< " vs CostAfter: " << CostAfter << "\n");
3448-
if (CostBefore <= CostAfter)
3563+
if (CostBefore < CostAfter ||
3564+
(CostBefore == CostAfter && !feedsIntoVectorReduction(SVI)))
34493565
return false;
34503566

34513567
// The cost model has passed, create the new instructions.

llvm/test/Transforms/PhaseOrdering/AArch64/slpordering.ll

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -80,29 +80,33 @@ define i32 @slpordering(ptr noundef %p1, i32 noundef %ip1, ptr noundef %p2, i32
8080
; CHECK-NEXT: [[TMP47:%.*]] = shufflevector <16 x i32> [[TMP43]], <16 x i32> poison, <16 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
8181
; CHECK-NEXT: [[TMP48:%.*]] = add nsw <16 x i32> [[TMP45]], [[TMP47]]
8282
; CHECK-NEXT: [[TMP49:%.*]] = sub nsw <16 x i32> [[TMP44]], [[TMP46]]
83-
; CHECK-NEXT: [[TMP50:%.*]] = shufflevector <16 x i32> [[TMP48]], <16 x i32> [[TMP49]], <16 x i32> <i32 16, i32 0, i32 17, i32 1, i32 18, i32 2, i32 19, i32 3, i32 20, i32 4, i32 21, i32 5, i32 22, i32 6, i32 23, i32 7>
84-
; CHECK-NEXT: [[TMP51:%.*]] = shufflevector <16 x i32> [[TMP48]], <16 x i32> [[TMP49]], <16 x i32> <i32 17, i32 1, i32 16, i32 0, i32 19, i32 3, i32 18, i32 2, i32 21, i32 5, i32 20, i32 4, i32 23, i32 7, i32 22, i32 6>
85-
; CHECK-NEXT: [[TMP52:%.*]] = add nsw <16 x i32> [[TMP50]], [[TMP51]]
86-
; CHECK-NEXT: [[TMP53:%.*]] = sub nsw <16 x i32> [[TMP50]], [[TMP51]]
87-
; CHECK-NEXT: [[TMP54:%.*]] = shufflevector <16 x i32> [[TMP53]], <16 x i32> [[TMP52]], <16 x i32> <i32 0, i32 1, i32 18, i32 19, i32 4, i32 5, i32 22, i32 23, i32 8, i32 9, i32 26, i32 27, i32 12, i32 13, i32 30, i32 31>
88-
; CHECK-NEXT: [[TMP55:%.*]] = shufflevector <16 x i32> [[TMP54]], <16 x i32> poison, <16 x i32> <i32 4, i32 5, i32 6, i32 7, i32 0, i32 1, i32 2, i32 3, i32 12, i32 13, i32 14, i32 15, i32 8, i32 9, i32 10, i32 11>
89-
; CHECK-NEXT: [[TMP56:%.*]] = sub nsw <16 x i32> [[TMP54]], [[TMP55]]
90-
; CHECK-NEXT: [[TMP57:%.*]] = add nsw <16 x i32> [[TMP54]], [[TMP55]]
91-
; CHECK-NEXT: [[TMP58:%.*]] = shufflevector <16 x i32> [[TMP56]], <16 x i32> [[TMP57]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 20, i32 21, i32 22, i32 23, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
92-
; CHECK-NEXT: [[TMP59:%.*]] = shufflevector <16 x i32> [[TMP56]], <16 x i32> [[TMP57]], <16 x i32> <i32 8, i32 9, i32 10, i32 11, i32 28, i32 29, i32 30, i32 31, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
93-
; CHECK-NEXT: [[TMP60:%.*]] = shufflevector <16 x i32> [[TMP56]], <16 x i32> [[TMP57]], <16 x i32> <i32 8, i32 9, i32 10, i32 11, i32 28, i32 29, i32 30, i32 31, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
94-
; CHECK-NEXT: [[TMP61:%.*]] = shufflevector <16 x i32> [[TMP56]], <16 x i32> [[TMP57]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 20, i32 21, i32 22, i32 23, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
95-
; CHECK-NEXT: [[TMP62:%.*]] = add nsw <16 x i32> [[TMP59]], [[TMP61]]
96-
; CHECK-NEXT: [[TMP63:%.*]] = sub nsw <16 x i32> [[TMP58]], [[TMP60]]
97-
; CHECK-NEXT: [[TMP64:%.*]] = shufflevector <16 x i32> [[TMP62]], <16 x i32> [[TMP63]], <16 x i32> <i32 16, i32 17, i32 18, i32 19, i32 20, i32 21, i32 22, i32 23, i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
98-
; CHECK-NEXT: [[TMP65:%.*]] = lshr <16 x i32> [[TMP64]], splat (i32 15)
99-
; CHECK-NEXT: [[TMP66:%.*]] = and <16 x i32> [[TMP65]], splat (i32 65537)
100-
; CHECK-NEXT: [[TMP67:%.*]] = mul nuw <16 x i32> [[TMP66]], splat (i32 65535)
101-
; CHECK-NEXT: [[TMP68:%.*]] = add <16 x i32> [[TMP67]], [[TMP64]]
102-
; CHECK-NEXT: [[TMP69:%.*]] = xor <16 x i32> [[TMP68]], [[TMP67]]
103-
; CHECK-NEXT: [[TMP70:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP69]])
104-
; CHECK-NEXT: [[CONV118:%.*]] = and i32 [[TMP70]], 65535
105-
; CHECK-NEXT: [[SHR:%.*]] = lshr i32 [[TMP70]], 16
83+
; CHECK-NEXT: [[TMP50:%.*]] = shufflevector <16 x i32> [[TMP48]], <16 x i32> [[TMP49]], <16 x i32> <i32 0, i32 2, i32 4, i32 6, i32 16, i32 18, i32 20, i32 22, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
84+
; CHECK-NEXT: [[TMP51:%.*]] = shufflevector <16 x i32> [[TMP48]], <16 x i32> [[TMP49]], <16 x i32> <i32 1, i32 3, i32 5, i32 7, i32 17, i32 19, i32 21, i32 23, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
85+
; CHECK-NEXT: [[TMP52:%.*]] = shufflevector <16 x i32> [[TMP48]], <16 x i32> [[TMP49]], <16 x i32> <i32 1, i32 3, i32 5, i32 7, i32 17, i32 19, i32 21, i32 23, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
86+
; CHECK-NEXT: [[TMP53:%.*]] = shufflevector <16 x i32> [[TMP48]], <16 x i32> [[TMP49]], <16 x i32> <i32 0, i32 2, i32 4, i32 6, i32 16, i32 18, i32 20, i32 22, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
87+
; CHECK-NEXT: [[TMP54:%.*]] = add nsw <16 x i32> [[TMP51]], [[TMP53]]
88+
; CHECK-NEXT: [[TMP55:%.*]] = sub nsw <16 x i32> [[TMP50]], [[TMP52]]
89+
; CHECK-NEXT: [[TMP56:%.*]] = shufflevector <16 x i32> [[TMP54]], <16 x i32> [[TMP55]], <16 x i32> <i32 1, i32 3, i32 5, i32 7, i32 17, i32 19, i32 21, i32 23, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
90+
; CHECK-NEXT: [[TMP57:%.*]] = shufflevector <16 x i32> [[TMP54]], <16 x i32> [[TMP55]], <16 x i32> <i32 0, i32 2, i32 4, i32 6, i32 16, i32 18, i32 20, i32 22, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
91+
; CHECK-NEXT: [[TMP58:%.*]] = shufflevector <16 x i32> [[TMP54]], <16 x i32> [[TMP55]], <16 x i32> <i32 0, i32 2, i32 4, i32 6, i32 16, i32 18, i32 20, i32 22, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
92+
; CHECK-NEXT: [[TMP59:%.*]] = shufflevector <16 x i32> [[TMP54]], <16 x i32> [[TMP55]], <16 x i32> <i32 1, i32 3, i32 5, i32 7, i32 17, i32 19, i32 21, i32 23, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
93+
; CHECK-NEXT: [[TMP60:%.*]] = sub nsw <16 x i32> [[TMP57]], [[TMP59]]
94+
; CHECK-NEXT: [[TMP61:%.*]] = add nsw <16 x i32> [[TMP56]], [[TMP58]]
95+
; CHECK-NEXT: [[TMP62:%.*]] = shufflevector <16 x i32> [[TMP60]], <16 x i32> [[TMP61]], <16 x i32> <i32 0, i32 2, i32 4, i32 6, i32 16, i32 18, i32 20, i32 22, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
96+
; CHECK-NEXT: [[TMP63:%.*]] = shufflevector <16 x i32> [[TMP60]], <16 x i32> [[TMP61]], <16 x i32> <i32 1, i32 3, i32 5, i32 7, i32 17, i32 19, i32 21, i32 23, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
97+
; CHECK-NEXT: [[TMP64:%.*]] = shufflevector <16 x i32> [[TMP60]], <16 x i32> [[TMP61]], <16 x i32> <i32 1, i32 3, i32 5, i32 7, i32 17, i32 19, i32 21, i32 23, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
98+
; CHECK-NEXT: [[TMP65:%.*]] = shufflevector <16 x i32> [[TMP60]], <16 x i32> [[TMP61]], <16 x i32> <i32 0, i32 2, i32 4, i32 6, i32 16, i32 18, i32 20, i32 22, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
99+
; CHECK-NEXT: [[TMP66:%.*]] = add nsw <16 x i32> [[TMP63]], [[TMP65]]
100+
; CHECK-NEXT: [[TMP67:%.*]] = sub nsw <16 x i32> [[TMP62]], [[TMP64]]
101+
; CHECK-NEXT: [[TMP68:%.*]] = shufflevector <16 x i32> [[TMP66]], <16 x i32> [[TMP67]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 16, i32 17, i32 18, i32 19, i32 20, i32 21, i32 22, i32 23>
102+
; CHECK-NEXT: [[TMP69:%.*]] = lshr <16 x i32> [[TMP68]], splat (i32 15)
103+
; CHECK-NEXT: [[TMP70:%.*]] = and <16 x i32> [[TMP69]], splat (i32 65537)
104+
; CHECK-NEXT: [[TMP71:%.*]] = mul nuw <16 x i32> [[TMP70]], splat (i32 65535)
105+
; CHECK-NEXT: [[TMP72:%.*]] = add <16 x i32> [[TMP71]], [[TMP68]]
106+
; CHECK-NEXT: [[TMP73:%.*]] = xor <16 x i32> [[TMP72]], [[TMP71]]
107+
; CHECK-NEXT: [[TMP74:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP73]])
108+
; CHECK-NEXT: [[CONV118:%.*]] = and i32 [[TMP74]], 65535
109+
; CHECK-NEXT: [[SHR:%.*]] = lshr i32 [[TMP74]], 16
106110
; CHECK-NEXT: [[RDD119:%.*]] = add nuw nsw i32 [[CONV118]], [[SHR]]
107111
; CHECK-NEXT: [[SHR120:%.*]] = lshr i32 [[RDD119]], 1
108112
; CHECK-NEXT: ret i32 [[SHR120]]

0 commit comments

Comments
 (0)