@@ -3975,6 +3975,26 @@ InstructionCost AArch64TTIImpl::getScalarizationOverhead(
39753975 return DemandedElts.popcount () * (Insert + Extract) * VecInstCost;
39763976}
39773977
3978+ std::optional<InstructionCost> AArch64TTIImpl::getFP16BF16PromoteCost (
3979+ Type *Ty, TTI::TargetCostKind CostKind, TTI::OperandValueInfo Op1Info,
3980+ TTI::OperandValueInfo Op2Info, bool IncludeTrunc,
3981+ std::function<InstructionCost(Type *)> InstCost) const {
3982+ if ((ST->hasFullFP16 () || !Ty->getScalarType ()->isHalfTy ()) &&
3983+ !Ty->getScalarType ()->isBFloatTy ())
3984+ return std::nullopt ;
3985+
3986+ Type *PromotedTy = Ty->getWithNewType (Type::getFloatTy (Ty->getContext ()));
3987+ InstructionCost Cost = getCastInstrCost (Instruction::FPExt, PromotedTy, Ty,
3988+ TTI::CastContextHint::None, CostKind);
3989+ if (!Op1Info.isConstant () && !Op2Info.isConstant ())
3990+ Cost *= 2 ;
3991+ Cost += InstCost (PromotedTy);
3992+ if (IncludeTrunc)
3993+ Cost += getCastInstrCost (Instruction::FPTrunc, Ty, PromotedTy,
3994+ TTI::CastContextHint::None, CostKind);
3995+ return Cost;
3996+ }
3997+
39783998InstructionCost AArch64TTIImpl::getArithmeticInstrCost (
39793999 unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
39804000 TTI::OperandValueInfo Op1Info, TTI::OperandValueInfo Op2Info,
@@ -3997,6 +4017,18 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
39974017 std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost (Ty);
39984018 int ISD = TLI->InstructionOpcodeToISD (Opcode);
39994019
4020+ // Increase the cost for half and bfloat types if not architecturally
4021+ // supported.
4022+ if (ISD == ISD::FADD || ISD == ISD::FSUB || ISD == ISD::FMUL ||
4023+ ISD == ISD::FDIV || ISD == ISD::FREM)
4024+ if (auto PromotedCost = getFP16BF16PromoteCost (
4025+ Ty, CostKind, Op1Info, Op2Info, /* IncludeTrunc=*/ true ,
4026+ [&](Type *PromotedTy) {
4027+ return getArithmeticInstrCost (Opcode, PromotedTy, CostKind,
4028+ Op1Info, Op2Info);
4029+ }))
4030+ return *PromotedCost;
4031+
40004032 switch (ISD) {
40014033 default :
40024034 return BaseT::getArithmeticInstrCost (Opcode, Ty, CostKind, Op1Info,
@@ -4265,11 +4297,6 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
42654297 [[fallthrough]];
42664298 case ISD::FADD:
42674299 case ISD::FSUB:
4268- // Increase the cost for half and bfloat types if not architecturally
4269- // supported.
4270- if ((Ty->getScalarType ()->isHalfTy () && !ST->hasFullFP16 ()) ||
4271- (Ty->getScalarType ()->isBFloatTy () && !ST->hasBF16 ()))
4272- return 2 * LT.first ;
42734300 if (!Ty->getScalarType ()->isFP128Ty ())
42744301 return LT.first ;
42754302 [[fallthrough]];
@@ -4371,25 +4398,21 @@ InstructionCost AArch64TTIImpl::getCmpSelInstrCost(
43714398 }
43724399
43734400 if (Opcode == Instruction::FCmp) {
4374- // Without dedicated instructions we promote f16 + bf16 compares to f32.
4375- if ((!ST->hasFullFP16 () && ValTy->getScalarType ()->isHalfTy ()) ||
4376- ValTy->getScalarType ()->isBFloatTy ()) {
4377- Type *PromotedTy =
4378- ValTy->getWithNewType (Type::getFloatTy (ValTy->getContext ()));
4379- InstructionCost Cost =
4380- getCastInstrCost (Instruction::FPExt, PromotedTy, ValTy,
4381- TTI::CastContextHint::None, CostKind);
4382- if (!Op1Info.isConstant () && !Op2Info.isConstant ())
4383- Cost *= 2 ;
4384- Cost += getCmpSelInstrCost (Opcode, PromotedTy, CondTy, VecPred, CostKind,
4385- Op1Info, Op2Info);
4386- if (ValTy->isVectorTy ())
4387- Cost += getCastInstrCost (
4388- Instruction::Trunc, VectorType::getInteger (cast<VectorType>(ValTy)),
4389- VectorType::getInteger (cast<VectorType>(PromotedTy)),
4390- TTI::CastContextHint::None, CostKind);
4391- return Cost;
4392- }
4401+ if (auto Cost = getFP16BF16PromoteCost (
4402+ ValTy, CostKind, Op1Info, Op2Info, /* IncludeTrunc=*/ false ,
4403+ [&](Type *PromotedTy) {
4404+ InstructionCost Cost =
4405+ getCmpSelInstrCost (Opcode, PromotedTy, CondTy, VecPred,
4406+ CostKind, Op1Info, Op2Info);
4407+ if (isa<VectorType>(PromotedTy))
4408+ Cost += getCastInstrCost (
4409+ Instruction::Trunc,
4410+ VectorType::getInteger (cast<VectorType>(ValTy)),
4411+ VectorType::getInteger (cast<VectorType>(PromotedTy)),
4412+ TTI::CastContextHint::None, CostKind);
4413+ return Cost;
4414+ }))
4415+ return *Cost;
43934416
43944417 auto LT = getTypeLegalizationCost (ValTy);
43954418 // Model unknown fp compares as a libcall.
0 commit comments