Skip to content

Commit 913f79c

Browse files
committed
[VPlan] Compute cost of replicating calls in VPlan. (NFCI) (llvm#154291)
Implement computing the scalarization overhead for replicating calls in VPlan, matching the legacy cost model. Depends on llvm#154126. PR: llvm#154291 (cherry picked from commit c3470d1)
1 parent a90cbe9 commit 913f79c

File tree

1 file changed

+36
-8
lines changed

1 file changed

+36
-8
lines changed

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 36 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -2968,13 +2968,6 @@ InstructionCost VPReplicateRecipe::computeCost(ElementCount VF,
29682968
// instruction cost.
29692969
return 0;
29702970
case Instruction::Call: {
2971-
if (!isSingleScalar()) {
2972-
// TODO: Handle remaining call costs here as well.
2973-
if (VF.isScalable())
2974-
return InstructionCost::getInvalid();
2975-
break;
2976-
}
2977-
29782971
auto *CalledFn =
29792972
cast<Function>(getOperand(getNumOperands() - 1)->getLiveInIRValue());
29802973
if (CalledFn->isIntrinsic())
@@ -2984,7 +2977,42 @@ InstructionCost VPReplicateRecipe::computeCost(ElementCount VF,
29842977
for (VPValue *ArgOp : drop_end(operands()))
29852978
Tys.push_back(Ctx.Types.inferScalarType(ArgOp));
29862979
Type *ResultTy = Ctx.Types.inferScalarType(this);
2987-
return Ctx.TTI.getCallInstrCost(CalledFn, ResultTy, Tys, Ctx.CostKind);
2980+
InstructionCost ScalarCallCost =
2981+
Ctx.TTI.getCallInstrCost(CalledFn, ResultTy, Tys, Ctx.CostKind);
2982+
if (isSingleScalar())
2983+
return ScalarCallCost;
2984+
2985+
if (VF.isScalable())
2986+
return InstructionCost::getInvalid();
2987+
2988+
// Compute the cost of scalarizing the result and operands if needed.
2989+
InstructionCost ScalarizationCost = 0;
2990+
if (VF.isVector()) {
2991+
if (!ResultTy->isVoidTy()) {
2992+
for (Type *VectorTy :
2993+
to_vector(getContainedTypes(toVectorizedTy(ResultTy, VF)))) {
2994+
ScalarizationCost += Ctx.TTI.getScalarizationOverhead(
2995+
cast<VectorType>(VectorTy), APInt::getAllOnes(VF.getFixedValue()),
2996+
/*Insert=*/true,
2997+
/*Extract=*/false, Ctx.CostKind);
2998+
}
2999+
}
3000+
// Skip operands that do not require extraction/scalarization and do not
3001+
// incur any overhead.
3002+
SmallPtrSet<const VPValue *, 4> UniqueOperands;
3003+
Tys.clear();
3004+
for (auto *Op : drop_end(operands())) {
3005+
if (Op->isLiveIn() || isa<VPReplicateRecipe, VPPredInstPHIRecipe>(Op) ||
3006+
!UniqueOperands.insert(Op).second)
3007+
continue;
3008+
Tys.push_back(toVectorizedTy(Ctx.Types.inferScalarType(Op), VF));
3009+
}
3010+
ScalarizationCost +=
3011+
Ctx.TTI.getOperandsScalarizationOverhead(Tys, Ctx.CostKind);
3012+
}
3013+
3014+
return ScalarCallCost * (isSingleScalar() ? 1 : VF.getFixedValue()) +
3015+
ScalarizationCost;
29883016
}
29893017
case Instruction::Add:
29903018
case Instruction::Sub:

0 commit comments

Comments
 (0)