Skip to content

Commit c3470d1

Browse files
authored
[VPlan] Compute cost of replicating calls in VPlan. (NFCI) (llvm#154291)
Implement computing the scalarization overhead for replicating calls in VPlan, matching the legacy cost model. Depends on llvm#154126. PR: llvm#154291
1 parent 769d5c2 commit c3470d1

File tree

1 file changed

+36
-8
lines changed

1 file changed

+36
-8
lines changed

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3047,13 +3047,6 @@ InstructionCost VPReplicateRecipe::computeCost(ElementCount VF,
30473047
// instruction cost.
30483048
return 0;
30493049
case Instruction::Call: {
3050-
if (!isSingleScalar()) {
3051-
// TODO: Handle remaining call costs here as well.
3052-
if (VF.isScalable())
3053-
return InstructionCost::getInvalid();
3054-
break;
3055-
}
3056-
30573050
auto *CalledFn =
30583051
cast<Function>(getOperand(getNumOperands() - 1)->getLiveInIRValue());
30593052
if (CalledFn->isIntrinsic())
@@ -3063,7 +3056,42 @@ InstructionCost VPReplicateRecipe::computeCost(ElementCount VF,
30633056
for (VPValue *ArgOp : drop_end(operands()))
30643057
Tys.push_back(Ctx.Types.inferScalarType(ArgOp));
30653058
Type *ResultTy = Ctx.Types.inferScalarType(this);
3066-
return Ctx.TTI.getCallInstrCost(CalledFn, ResultTy, Tys, Ctx.CostKind);
3059+
InstructionCost ScalarCallCost =
3060+
Ctx.TTI.getCallInstrCost(CalledFn, ResultTy, Tys, Ctx.CostKind);
3061+
if (isSingleScalar())
3062+
return ScalarCallCost;
3063+
3064+
if (VF.isScalable())
3065+
return InstructionCost::getInvalid();
3066+
3067+
// Compute the cost of scalarizing the result and operands if needed.
3068+
InstructionCost ScalarizationCost = 0;
3069+
if (VF.isVector()) {
3070+
if (!ResultTy->isVoidTy()) {
3071+
for (Type *VectorTy :
3072+
to_vector(getContainedTypes(toVectorizedTy(ResultTy, VF)))) {
3073+
ScalarizationCost += Ctx.TTI.getScalarizationOverhead(
3074+
cast<VectorType>(VectorTy), APInt::getAllOnes(VF.getFixedValue()),
3075+
/*Insert=*/true,
3076+
/*Extract=*/false, Ctx.CostKind);
3077+
}
3078+
}
3079+
// Skip operands that do not require extraction/scalarization and do not
3080+
// incur any overhead.
3081+
SmallPtrSet<const VPValue *, 4> UniqueOperands;
3082+
Tys.clear();
3083+
for (auto *Op : drop_end(operands())) {
3084+
if (Op->isLiveIn() || isa<VPReplicateRecipe, VPPredInstPHIRecipe>(Op) ||
3085+
!UniqueOperands.insert(Op).second)
3086+
continue;
3087+
Tys.push_back(toVectorizedTy(Ctx.Types.inferScalarType(Op), VF));
3088+
}
3089+
ScalarizationCost +=
3090+
Ctx.TTI.getOperandsScalarizationOverhead(Tys, Ctx.CostKind);
3091+
}
3092+
3093+
return ScalarCallCost * (isSingleScalar() ? 1 : VF.getFixedValue()) +
3094+
ScalarizationCost;
30673095
}
30683096
case Instruction::Add:
30693097
case Instruction::Sub:

0 commit comments

Comments
 (0)