@@ -3047,13 +3047,6 @@ InstructionCost VPReplicateRecipe::computeCost(ElementCount VF,
3047
3047
// instruction cost.
3048
3048
return 0 ;
3049
3049
case Instruction::Call: {
3050
- if (!isSingleScalar ()) {
3051
- // TODO: Handle remaining call costs here as well.
3052
- if (VF.isScalable ())
3053
- return InstructionCost::getInvalid ();
3054
- break ;
3055
- }
3056
-
3057
3050
auto *CalledFn =
3058
3051
cast<Function>(getOperand (getNumOperands () - 1 )->getLiveInIRValue ());
3059
3052
if (CalledFn->isIntrinsic ())
@@ -3063,7 +3056,42 @@ InstructionCost VPReplicateRecipe::computeCost(ElementCount VF,
3063
3056
for (VPValue *ArgOp : drop_end (operands ()))
3064
3057
Tys.push_back (Ctx.Types .inferScalarType (ArgOp));
3065
3058
Type *ResultTy = Ctx.Types .inferScalarType (this );
3066
- return Ctx.TTI .getCallInstrCost (CalledFn, ResultTy, Tys, Ctx.CostKind );
3059
+ InstructionCost ScalarCallCost =
3060
+ Ctx.TTI .getCallInstrCost (CalledFn, ResultTy, Tys, Ctx.CostKind );
3061
+ if (isSingleScalar ())
3062
+ return ScalarCallCost;
3063
+
3064
+ if (VF.isScalable ())
3065
+ return InstructionCost::getInvalid ();
3066
+
3067
+ // Compute the cost of scalarizing the result and operands if needed.
3068
+ InstructionCost ScalarizationCost = 0 ;
3069
+ if (VF.isVector ()) {
3070
+ if (!ResultTy->isVoidTy ()) {
3071
+ for (Type *VectorTy :
3072
+ to_vector (getContainedTypes (toVectorizedTy (ResultTy, VF)))) {
3073
+ ScalarizationCost += Ctx.TTI .getScalarizationOverhead (
3074
+ cast<VectorType>(VectorTy), APInt::getAllOnes (VF.getFixedValue ()),
3075
+ /* Insert=*/ true ,
3076
+ /* Extract=*/ false , Ctx.CostKind );
3077
+ }
3078
+ }
3079
+ // Skip operands that do not require extraction/scalarization and do not
3080
+ // incur any overhead.
3081
+ SmallPtrSet<const VPValue *, 4 > UniqueOperands;
3082
+ Tys.clear ();
3083
+ for (auto *Op : drop_end (operands ())) {
3084
+ if (Op->isLiveIn () || isa<VPReplicateRecipe, VPPredInstPHIRecipe>(Op) ||
3085
+ !UniqueOperands.insert (Op).second )
3086
+ continue ;
3087
+ Tys.push_back (toVectorizedTy (Ctx.Types .inferScalarType (Op), VF));
3088
+ }
3089
+ ScalarizationCost +=
3090
+ Ctx.TTI .getOperandsScalarizationOverhead (Tys, Ctx.CostKind );
3091
+ }
3092
+
3093
+ return ScalarCallCost * (isSingleScalar () ? 1 : VF.getFixedValue ()) +
3094
+ ScalarizationCost;
3067
3095
}
3068
3096
case Instruction::Add:
3069
3097
case Instruction::Sub:
0 commit comments