@@ -1908,6 +1908,29 @@ InstructionCost RISCVTTIImpl::getArithmeticInstrCost(
19081908 return BaseT::getArithmeticInstrCost (Opcode, Ty, CostKind, Op1Info, Op2Info,
19091909 Args, CxtI);
19101910
1911+ // f16 with zvfhmin and bf16 will be promoted to f32.
1912+ // FIXME: nxv32[b]f16 will be custom lowered and split.
1913+ unsigned ISDOpcode = TLI->InstructionOpcodeToISD (Opcode);
1914+ InstructionCost CastCost = 0 ;
1915+ if ((LT.second .getVectorElementType () == MVT::f16 ||
1916+ LT.second .getVectorElementType () == MVT::bf16 ) &&
1917+ TLI->getOperationAction (ISDOpcode, LT.second ) ==
1918+ TargetLoweringBase::LegalizeAction::Promote) {
1919+ MVT PromotedVT = TLI->getTypeToPromoteTo (ISDOpcode, LT.second );
1920+ Type *PromotedTy = EVT (PromotedVT).getTypeForEVT (Ty->getContext ());
1921+ Type *LegalTy = EVT (LT.second ).getTypeForEVT (Ty->getContext ());
1922+ // Add cost of extending arguments
1923+ CastCost += LT.first * Args.size () *
1924+ getCastInstrCost (Instruction::FPExt, PromotedTy, LegalTy,
1925+ TTI::CastContextHint::None, CostKind);
1926+ // Add cost of truncating result
1927+ CastCost +=
1928+ LT.first * getCastInstrCost (Instruction::FPTrunc, LegalTy, PromotedTy,
1929+ TTI::CastContextHint::None, CostKind);
1930+ // Compute cost of op in promoted type
1931+ LT.second = PromotedVT;
1932+ }
1933+
19111934 auto getConstantMatCost =
19121935 [&](unsigned Operand, TTI::OperandValueInfo OpInfo) -> InstructionCost {
19131936 if (OpInfo.isUniform () && TLI->canSplatOperand (Opcode, Operand))
@@ -1929,7 +1952,7 @@ InstructionCost RISCVTTIImpl::getArithmeticInstrCost(
19291952 ConstantMatCost += getConstantMatCost (1 , Op2Info);
19301953
19311954 unsigned Op;
1932- switch (TLI-> InstructionOpcodeToISD (Opcode) ) {
1955+ switch (ISDOpcode ) {
19331956 case ISD::ADD:
19341957 case ISD::SUB:
19351958 Op = RISCV::VADD_VV;
@@ -1959,11 +1982,9 @@ InstructionCost RISCVTTIImpl::getArithmeticInstrCost(
19591982 break ;
19601983 case ISD::FADD:
19611984 case ISD::FSUB:
1962- // TODO: Address FP16 with VFHMIN
19631985 Op = RISCV::VFADD_VV;
19641986 break ;
19651987 case ISD::FMUL:
1966- // TODO: Address FP16 with VFHMIN
19671988 Op = RISCV::VFMUL_VV;
19681989 break ;
19691990 case ISD::FDIV:
@@ -1975,9 +1996,9 @@ InstructionCost RISCVTTIImpl::getArithmeticInstrCost(
19751996 default :
19761997 // Assuming all other instructions have the same cost until a need arises to
19771998 // differentiate them.
1978- return ConstantMatCost + BaseT::getArithmeticInstrCost (Opcode, Ty, CostKind,
1979- Op1Info, Op2Info,
1980- Args, CxtI);
1999+ return CastCost + ConstantMatCost +
2000+ BaseT::getArithmeticInstrCost (Opcode, Ty, CostKind, Op1Info, Op2Info,
2001+ Args, CxtI);
19812002 }
19822003
19832004 InstructionCost InstrCost = getRISCVInstructionCost (Op, LT.second , CostKind);
@@ -1986,7 +2007,7 @@ InstructionCost RISCVTTIImpl::getArithmeticInstrCost(
19862007 // scalar floating point ops aren't cheaper than their vector equivalents.
19872008 if (Ty->isFPOrFPVectorTy ())
19882009 InstrCost *= 2 ;
1989- return ConstantMatCost + LT.first * InstrCost;
2010+ return CastCost + ConstantMatCost + LT.first * InstrCost;
19902011}
19912012
19922013// TODO: Deduplicate from TargetTransformInfoImplCRTPBase.
0 commit comments