@@ -3990,6 +3990,27 @@ InstructionCost AArch64TTIImpl::getScalarizationOverhead(
39903990 return DemandedElts.popcount () * (Insert + Extract) * VecInstCost;
39913991}
39923992
3993+ std::optional<InstructionCost> AArch64TTIImpl::getFP16BF16PromoteCost (
3994+ Type *Ty, TTI::TargetCostKind CostKind, TTI::OperandValueInfo Op1Info,
3995+ TTI::OperandValueInfo Op2Info, bool IncludeTrunc,
3996+ std::function<InstructionCost(Type *)> InstCost) const {
3997+ if (!Ty->getScalarType ()->isHalfTy () && !Ty->getScalarType ()->isBFloatTy ())
3998+ return std::nullopt ;
3999+ if (Ty->getScalarType ()->isHalfTy () && ST->hasFullFP16 ())
4000+ return std::nullopt ;
4001+
4002+ Type *PromotedTy = Ty->getWithNewType (Type::getFloatTy (Ty->getContext ()));
4003+ InstructionCost Cost = getCastInstrCost (Instruction::FPExt, PromotedTy, Ty,
4004+ TTI::CastContextHint::None, CostKind);
4005+ if (!Op1Info.isConstant () && !Op2Info.isConstant ())
4006+ Cost *= 2 ;
4007+ Cost += InstCost (PromotedTy);
4008+ if (IncludeTrunc)
4009+ Cost += getCastInstrCost (Instruction::FPTrunc, Ty, PromotedTy,
4010+ TTI::CastContextHint::None, CostKind);
4011+ return Cost;
4012+ }
4013+
39934014InstructionCost AArch64TTIImpl::getArithmeticInstrCost (
39944015 unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
39954016 TTI::OperandValueInfo Op1Info, TTI::OperandValueInfo Op2Info,
@@ -4012,6 +4033,18 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
40124033 std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost (Ty);
40134034 int ISD = TLI->InstructionOpcodeToISD (Opcode);
40144035
4036+ // Increase the cost for half and bfloat types if not architecturally
4037+ // supported.
4038+ if (ISD == ISD::FADD || ISD == ISD::FSUB || ISD == ISD::FMUL ||
4039+ ISD == ISD::FDIV || ISD == ISD::FREM)
4040+ if (auto PromotedCost = getFP16BF16PromoteCost (
4041+ Ty, CostKind, Op1Info, Op2Info, /* IncludeTrunc=*/ true ,
4042+ [&](Type *PromotedTy) {
4043+ return getArithmeticInstrCost (Opcode, PromotedTy, CostKind,
4044+ Op1Info, Op2Info);
4045+ }))
4046+ return *PromotedCost;
4047+
40154048 switch (ISD) {
40164049 default :
40174050 return BaseT::getArithmeticInstrCost (Opcode, Ty, CostKind, Op1Info,
@@ -4280,11 +4313,6 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
42804313 [[fallthrough]];
42814314 case ISD::FADD:
42824315 case ISD::FSUB:
4283- // Increase the cost for half and bfloat types if not architecturally
4284- // supported.
4285- if ((Ty->getScalarType ()->isHalfTy () && !ST->hasFullFP16 ()) ||
4286- (Ty->getScalarType ()->isBFloatTy () && !ST->hasBF16 ()))
4287- return 2 * LT.first ;
42884316 if (!Ty->getScalarType ()->isFP128Ty ())
42894317 return LT.first ;
42904318 [[fallthrough]];
@@ -4386,25 +4414,21 @@ InstructionCost AArch64TTIImpl::getCmpSelInstrCost(
43864414 }
43874415
43884416 if (Opcode == Instruction::FCmp) {
4389- // Without dedicated instructions we promote f16 + bf16 compares to f32.
4390- if ((!ST->hasFullFP16 () && ValTy->getScalarType ()->isHalfTy ()) ||
4391- ValTy->getScalarType ()->isBFloatTy ()) {
4392- Type *PromotedTy =
4393- ValTy->getWithNewType (Type::getFloatTy (ValTy->getContext ()));
4394- InstructionCost Cost =
4395- getCastInstrCost (Instruction::FPExt, PromotedTy, ValTy,
4396- TTI::CastContextHint::None, CostKind);
4397- if (!Op1Info.isConstant () && !Op2Info.isConstant ())
4398- Cost *= 2 ;
4399- Cost += getCmpSelInstrCost (Opcode, PromotedTy, CondTy, VecPred, CostKind,
4400- Op1Info, Op2Info);
4401- if (ValTy->isVectorTy ())
4402- Cost += getCastInstrCost (
4403- Instruction::Trunc, VectorType::getInteger (cast<VectorType>(ValTy)),
4404- VectorType::getInteger (cast<VectorType>(PromotedTy)),
4405- TTI::CastContextHint::None, CostKind);
4406- return Cost;
4407- }
4417+ if (auto Cost = getFP16BF16PromoteCost (
4418+ ValTy, CostKind, Op1Info, Op2Info, /* IncludeTrunc=*/ false ,
4419+ [&](Type *PromotedTy) {
4420+ InstructionCost Cost =
4421+ getCmpSelInstrCost (Opcode, PromotedTy, CondTy, VecPred,
4422+ CostKind, Op1Info, Op2Info);
4423+ if (isa<VectorType>(PromotedTy))
4424+ Cost += getCastInstrCost (
4425+ Instruction::Trunc,
4426+ VectorType::getInteger (cast<VectorType>(ValTy)),
4427+ VectorType::getInteger (cast<VectorType>(PromotedTy)),
4428+ TTI::CastContextHint::None, CostKind);
4429+ return Cost;
4430+ }))
4431+ return *Cost;
44084432
44094433 auto LT = getTypeLegalizationCost (ValTy);
44104434 // Model unknown fp compares as a libcall.
0 commit comments