Skip to content

Commit 5267d0f

Browse files
committed
[VPlan] Move logic to compute scalarization overhead to cost helper (NFC)
Extract the logic to compute the scalarization overhead to a helper for easy re-use in the future. (cherry-picked from commit 30e9cba)
1 parent 81ca03e commit 5267d0f

File tree

3 files changed

+39
-27
lines changed

3 files changed

+39
-27
lines changed

llvm/lib/Transforms/Vectorize/VPlan.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1689,3 +1689,33 @@ VPCostContext::getOperandInfo(VPValue *V) const {
16891689

16901690
return TTI::getOperandInfo(V->getLiveInIRValue());
16911691
}
1692+
1693+
/// Estimate the overhead of scalarizing a recipe that produces \p ResultTy from
/// \p Operands when vectorizing with \p VF.
///
/// The returned cost is the sum of:
///  * the insert-only scalarization overhead for every vector type contained
///    in the vectorized form of \p ResultTy (skipped for void results), and
///  * the operand scalarization overhead, queried once for all operands that
///    need to be costed.
///
/// Returns 0 for a scalar \p VF and an invalid cost for a scalable \p VF.
InstructionCost VPCostContext::getScalarizationOverhead(
    Type *ResultTy, ArrayRef<const VPValue *> Operands, ElementCount VF) {
  if (VF.isScalar())
    return 0;

  // The demanded-elements mask built below needs a fixed number of lanes.
  // Report an invalid cost for scalable VFs, matching the guard used by
  // VPReplicateRecipe::computeCost, instead of relying on every caller to
  // filter them out before calling this helper.
  if (VF.isScalable())
    return InstructionCost::getInvalid();

  InstructionCost ScalarizationCost = 0;
  // Compute the cost of scalarizing the result if needed.
  if (!ResultTy->isVoidTy()) {
    // The vectorized result type may contain several vector types; cost each
    // contained vector separately, demanding all VF lanes.
    for (Type *VectorTy :
         to_vector(getContainedTypes(toVectorizedTy(ResultTy, VF)))) {
      ScalarizationCost += TTI.getScalarizationOverhead(
          cast<VectorType>(VectorTy), APInt::getAllOnes(VF.getFixedValue()),
          /*Insert=*/true,
          /*Extract=*/false, CostKind);
    }
  }

  // Compute the cost of scalarizing the operands, skipping ones that do not
  // require extraction/scalarization and do not incur any overhead.
  SmallPtrSet<const VPValue *, 4> UniqueOperands;
  SmallVector<Type *> Tys;
  for (auto *Op : Operands) {
    // Each distinct operand is costed at most once; live-ins and values
    // produced by replicate / predicated-PHI recipes are not costed at all.
    if (Op->isLiveIn() || isa<VPReplicateRecipe, VPPredInstPHIRecipe>(Op) ||
        !UniqueOperands.insert(Op).second)
      continue;
    Tys.push_back(toVectorizedTy(Types.inferScalarType(Op), VF));
  }
  return ScalarizationCost +
         TTI.getOperandsScalarizationOverhead(Tys, CostKind);
}

llvm/lib/Transforms/Vectorize/VPlanHelpers.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,13 @@ struct VPCostContext {
374374
/// legacy cost model for \p VF. Only used to check for additional VPlan
375375
/// simplifications.
376376
bool isLegacyUniformAfterVectorization(Instruction *I, ElementCount VF) const;
377+
378+
/// Estimate the overhead of scalarizing a recipe with result type \p ResultTy
379+
/// and \p Operands with \p VF. This is a convenience wrapper for the
380+
/// type-based getScalarizationOverhead API.
///
/// Returns 0 when \p VF is scalar. The result part of the cost is only
/// included when \p ResultTy is not void. Live-in operands, operands that are
/// VPReplicateRecipe or VPPredInstPHIRecipe values, and repeated occurrences
/// of the same operand are not costed.
381+
InstructionCost getScalarizationOverhead(Type *ResultTy,
382+
ArrayRef<const VPValue *> Operands,
383+
ElementCount VF);
377384
};
378385

379386
/// This class can be used to assign names to VPValues. For VPValues without

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 2 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3024,33 +3024,8 @@ InstructionCost VPReplicateRecipe::computeCost(ElementCount VF,
30243024
if (VF.isScalable())
30253025
return InstructionCost::getInvalid();
30263026

3027-
// Compute the cost of scalarizing the result and operands if needed.
3028-
InstructionCost ScalarizationCost = 0;
3029-
if (VF.isVector()) {
3030-
if (!ResultTy->isVoidTy()) {
3031-
for (Type *VectorTy :
3032-
to_vector(getContainedTypes(toVectorizedTy(ResultTy, VF)))) {
3033-
ScalarizationCost += Ctx.TTI.getScalarizationOverhead(
3034-
cast<VectorType>(VectorTy), APInt::getAllOnes(VF.getFixedValue()),
3035-
/*Insert=*/true,
3036-
/*Extract=*/false, Ctx.CostKind);
3037-
}
3038-
}
3039-
// Skip operands that do not require extraction/scalarization and do not
3040-
// incur any overhead.
3041-
SmallPtrSet<const VPValue *, 4> UniqueOperands;
3042-
Tys.clear();
3043-
for (auto *Op : ArgOps) {
3044-
if (Op->isLiveIn() || isa<VPReplicateRecipe, VPPredInstPHIRecipe>(Op) ||
3045-
!UniqueOperands.insert(Op).second)
3046-
continue;
3047-
Tys.push_back(toVectorizedTy(Ctx.Types.inferScalarType(Op), VF));
3048-
}
3049-
ScalarizationCost +=
3050-
Ctx.TTI.getOperandsScalarizationOverhead(Tys, Ctx.CostKind);
3051-
}
3052-
3053-
return ScalarCallCost * VF.getFixedValue() + ScalarizationCost;
3027+
return ScalarCallCost * VF.getFixedValue() +
3028+
Ctx.getScalarizationOverhead(ResultTy, ArgOps, VF);
30543029
}
30553030
case Instruction::Add:
30563031
case Instruction::Sub:

0 commit comments

Comments
 (0)