Skip to content

Commit 30e9cba

Browse files
committed
[VPlan] Move logic to compute scalarization overhead to cost helper (NFC)
Extract the logic to compute the scalarization overhead to a helper for easy re-use in the future.
1 parent 4b82db9 commit 30e9cba

File tree

3 files changed

+39
-27
lines changed

3 files changed

+39
-27
lines changed

llvm/lib/Transforms/Vectorize/VPlan.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1747,3 +1747,33 @@ VPCostContext::getOperandInfo(VPValue *V) const {
17471747

17481748
return TTI::getOperandInfo(V->getLiveInIRValue());
17491749
}
1750+
1751+
/// Estimate the overhead of scalarizing a recipe that produces \p ResultTy
/// from \p Operands at vectorization factor \p VF.
///
/// Returns 0 for a scalar \p VF. Otherwise the returned cost is the sum of:
///  - for a non-void \p ResultTy, TTI.getScalarizationOverhead (insert only,
///    all lanes) for each type contained in the vectorized result type, and
///  - TTI.getOperandsScalarizationOverhead over the vectorized types of the
///    distinct operands that are not live-ins and for which
///    isa<VPReplicateRecipe, VPPredInstPHIRecipe> is false.
InstructionCost VPCostContext::getScalarizationOverhead(
1752+
Type *ResultTy, ArrayRef<const VPValue *> Operands, ElementCount VF) {
1753+
// A scalar VF contributes no scalarization overhead.
if (VF.isScalar())
1754+
return 0;
1755+
1756+
InstructionCost ScalarizationCost = 0;
1757+
// Compute the cost of scalarizing the result if needed.
1758+
if (!ResultTy->isVoidTy()) {
1759+
for (Type *VectorTy :
1760+
to_vector(getContainedTypes(toVectorizedTy(ResultTy, VF)))) {
1761+
// Insert-only cost with every lane demanded (all-ones mask with
// VF.getFixedValue() bits).
// NOTE(review): getFixedValue() presumably requires a fixed-width VF; the
// caller visible in VPReplicateRecipe::computeCost bails out on scalable
// VFs before calling this helper -- confirm any other callers do the same.
ScalarizationCost += TTI.getScalarizationOverhead(
1762+
cast<VectorType>(VectorTy), APInt::getAllOnes(VF.getFixedValue()),
1763+
/*Insert=*/true,
1764+
/*Extract=*/false, CostKind);
1765+
}
1766+
}
1767+
// Compute the cost of scalarizing the operands, skipping ones that do not
1768+
// require extraction/scalarization and do not incur any overhead.
1769+
SmallPtrSet<const VPValue *, 4> UniqueOperands;
1770+
SmallVector<Type *> Tys;
1771+
for (auto *Op : Operands) {
1772+
// UniqueOperands.insert(Op).second is false when Op was already inserted,
// so a repeated operand contributes its type only once.
if (Op->isLiveIn() || isa<VPReplicateRecipe, VPPredInstPHIRecipe>(Op) ||
1773+
!UniqueOperands.insert(Op).second)
1774+
continue;
1775+
Tys.push_back(toVectorizedTy(Types.inferScalarType(Op), VF));
1776+
}
1777+
return ScalarizationCost +
1778+
TTI.getOperandsScalarizationOverhead(Tys, CostKind);
1779+
}

llvm/lib/Transforms/Vectorize/VPlanHelpers.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,13 @@ struct VPCostContext {
371371
/// legacy cost model for \p VF. Only used to check for additional VPlan
372372
/// simplifications.
373373
bool isLegacyUniformAfterVectorization(Instruction *I, ElementCount VF) const;
374+
375+
/// Estimate the overhead of scalarizing a recipe with result type \p ResultTy
376+
/// and \p Operands with \p VF. This is a convenience wrapper for the
377+
/// type-based getScalarizationOverhead API.
378+
InstructionCost getScalarizationOverhead(Type *ResultTy,
379+
ArrayRef<const VPValue *> Operands,
380+
ElementCount VF);
374381
};
375382

376383
/// This class can be used to assign names to VPValues. For VPValues without

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 2 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3132,33 +3132,8 @@ InstructionCost VPReplicateRecipe::computeCost(ElementCount VF,
31323132
if (VF.isScalable())
31333133
return InstructionCost::getInvalid();
31343134

3135-
// Compute the cost of scalarizing the result and operands if needed.
3136-
InstructionCost ScalarizationCost = 0;
3137-
if (VF.isVector()) {
3138-
if (!ResultTy->isVoidTy()) {
3139-
for (Type *VectorTy :
3140-
to_vector(getContainedTypes(toVectorizedTy(ResultTy, VF)))) {
3141-
ScalarizationCost += Ctx.TTI.getScalarizationOverhead(
3142-
cast<VectorType>(VectorTy), APInt::getAllOnes(VF.getFixedValue()),
3143-
/*Insert=*/true,
3144-
/*Extract=*/false, Ctx.CostKind);
3145-
}
3146-
}
3147-
// Skip operands that do not require extraction/scalarization and do not
3148-
// incur any overhead.
3149-
SmallPtrSet<const VPValue *, 4> UniqueOperands;
3150-
Tys.clear();
3151-
for (auto *Op : ArgOps) {
3152-
if (Op->isLiveIn() || isa<VPReplicateRecipe, VPPredInstPHIRecipe>(Op) ||
3153-
!UniqueOperands.insert(Op).second)
3154-
continue;
3155-
Tys.push_back(toVectorizedTy(Ctx.Types.inferScalarType(Op), VF));
3156-
}
3157-
ScalarizationCost +=
3158-
Ctx.TTI.getOperandsScalarizationOverhead(Tys, Ctx.CostKind);
3159-
}
3160-
3161-
return ScalarCallCost * VF.getFixedValue() + ScalarizationCost;
3135+
return ScalarCallCost * VF.getFixedValue() +
3136+
Ctx.getScalarizationOverhead(ResultTy, ArgOps, VF);
31623137
}
31633138
case Instruction::Add:
31643139
case Instruction::Sub:

0 commit comments

Comments
 (0)