@@ -297,8 +297,8 @@ RISCVTTIImpl::getPopcntSupport(unsigned TyWidth) const {
297
297
InstructionCost RISCVTTIImpl::getPartialReductionCost (
298
298
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
299
299
ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
300
- TTI::PartialReductionExtendKind OpBExtend,
301
- std::optional< unsigned > BinOp ) const {
300
+ TTI::PartialReductionExtendKind OpBExtend, std::optional< unsigned > BinOp,
301
+ TTI::TargetCostKind CostKind ) const {
302
302
303
303
// zve32x is broken for partial_reduce_umla, but let's make sure we
304
304
// don't generate them.
@@ -311,9 +311,8 @@ InstructionCost RISCVTTIImpl::getPartialReductionCost(
311
311
Type *Tp = VectorType::get (AccumType, VF.divideCoefficientBy (4 ));
312
312
std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost (Tp);
313
313
// Note: Assuming all vqdot* variants are equal cost
314
- // TODO: Thread CostKind through this API
315
- return LT.first * getRISCVInstructionCost (RISCV::VQDOT_VV, LT.second ,
316
- TTI::TCK_RecipThroughput);
314
+ return LT.first *
315
+ getRISCVInstructionCost (RISCV::VQDOT_VV, LT.second , CostKind);
317
316
}
318
317
319
318
bool RISCVTTIImpl::shouldExpandReduction (const IntrinsicInst *II) const {
0 commit comments