@@ -2968,13 +2968,6 @@ InstructionCost VPReplicateRecipe::computeCost(ElementCount VF,
29682968 // instruction cost.
29692969 return 0 ;
29702970 case Instruction::Call: {
2971- if (!isSingleScalar ()) {
2972- // TODO: Handle remaining call costs here as well.
2973- if (VF.isScalable ())
2974- return InstructionCost::getInvalid ();
2975- break ;
2976- }
2977-
29782971 auto *CalledFn =
29792972 cast<Function>(getOperand (getNumOperands () - 1 )->getLiveInIRValue ());
29802973 if (CalledFn->isIntrinsic ())
@@ -2984,7 +2977,42 @@ InstructionCost VPReplicateRecipe::computeCost(ElementCount VF,
29842977 for (VPValue *ArgOp : drop_end (operands ()))
29852978 Tys.push_back (Ctx.Types .inferScalarType (ArgOp));
29862979 Type *ResultTy = Ctx.Types .inferScalarType (this );
2987- return Ctx.TTI .getCallInstrCost (CalledFn, ResultTy, Tys, Ctx.CostKind );
2980+ InstructionCost ScalarCallCost =
2981+ Ctx.TTI .getCallInstrCost (CalledFn, ResultTy, Tys, Ctx.CostKind );
2982+ if (isSingleScalar ())
2983+ return ScalarCallCost;
2984+
2985+ if (VF.isScalable ())
2986+ return InstructionCost::getInvalid ();
2987+
2988+ // Compute the cost of scalarizing the result and operands if needed.
2989+ InstructionCost ScalarizationCost = 0 ;
2990+ if (VF.isVector ()) {
2991+ if (!ResultTy->isVoidTy ()) {
2992+ for (Type *VectorTy :
2993+ to_vector (getContainedTypes (toVectorizedTy (ResultTy, VF)))) {
2994+ ScalarizationCost += Ctx.TTI .getScalarizationOverhead (
2995+ cast<VectorType>(VectorTy), APInt::getAllOnes (VF.getFixedValue ()),
2996+ /* Insert=*/ true ,
2997+ /* Extract=*/ false , Ctx.CostKind );
2998+ }
2999+ }
3000+ // Skip operands that do not require extraction/scalarization and do not
3001+ // incur any overhead.
3002+ SmallPtrSet<const VPValue *, 4 > UniqueOperands;
3003+ Tys.clear ();
3004+ for (auto *Op : drop_end (operands ())) {
3005+ if (Op->isLiveIn () || isa<VPReplicateRecipe, VPPredInstPHIRecipe>(Op) ||
3006+ !UniqueOperands.insert (Op).second )
3007+ continue ;
3008+ Tys.push_back (toVectorizedTy (Ctx.Types .inferScalarType (Op), VF));
3009+ }
3010+ ScalarizationCost +=
3011+ Ctx.TTI .getOperandsScalarizationOverhead (Tys, Ctx.CostKind );
3012+ }
3013+
3014+ return ScalarCallCost * (isSingleScalar () ? 1 : VF.getFixedValue ()) +
3015+ ScalarizationCost;
29883016 }
29893017 case Instruction::Add:
29903018 case Instruction::Sub:
0 commit comments