@@ -297,8 +297,8 @@ RISCVTTIImpl::getPopcntSupport(unsigned TyWidth) const {
297297InstructionCost RISCVTTIImpl::getPartialReductionCost (
298298 unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
299299 ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
300- TTI::PartialReductionExtendKind OpBExtend,
301- std::optional< unsigned > BinOp ) const {
300+ TTI::PartialReductionExtendKind OpBExtend, std::optional< unsigned > BinOp,
301+ TTI::TargetCostKind CostKind ) const {
302302
303303 // zve32x is broken for partial_reduce_umla, but let's make sure we
304304 // don't generate them.
@@ -311,9 +311,8 @@ InstructionCost RISCVTTIImpl::getPartialReductionCost(
311311 Type *Tp = VectorType::get (AccumType, VF.divideCoefficientBy (4 ));
312312 std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost (Tp);
313313 // Note: Assuming all vqdot* variants are equal cost
314- // TODO: Thread CostKind through this API
315- return LT.first * getRISCVInstructionCost (RISCV::VQDOT_VV, LT.second ,
316- TTI::TCK_RecipThroughput);
314+ return LT.first *
315+ getRISCVInstructionCost (RISCV::VQDOT_VV, LT.second , CostKind);
317316}
318317
319318bool RISCVTTIImpl::shouldExpandReduction (const IntrinsicInst *II) const {
0 commit comments