@@ -3990,6 +3990,27 @@ InstructionCost AArch64TTIImpl::getScalarizationOverhead(
3990
3990
return DemandedElts.popcount () * (Insert + Extract) * VecInstCost;
3991
3991
}
3992
3992
3993
/// Cost an operation on a half/bfloat type by modelling it as being carried
/// out on the corresponding float type.
///
/// \param Ty           Type of the operation being costed; only its scalar
///                     type is inspected to decide whether promotion applies.
/// \param CostKind     Forwarded unchanged to the cast-cost queries.
/// \param Op1Info      Operand info; only isConstant() is consulted.
/// \param Op2Info      Operand info; only isConstant() is consulted.
/// \param IncludeTrunc When true, the cost of an FPTrunc from the promoted
///                     type back to \p Ty is added to the result.
/// \param InstCost     Callback invoked once with the promoted type; its
///                     result is added as the cost of the operation itself.
/// \returns std::nullopt when the scalar type of \p Ty is neither half nor
///          bfloat, or when it is half and the subtarget reports
///          hasFullFP16(); otherwise the accumulated cost (extension(s) +
///          InstCost(promoted type) + optional truncation).
+ std::optional<InstructionCost> AArch64TTIImpl::getFP16BF16PromoteCost (
3994
+ Type *Ty, TTI::TargetCostKind CostKind, TTI::OperandValueInfo Op1Info,
3995
+ TTI::OperandValueInfo Op2Info, bool IncludeTrunc,
3996
+ std::function<InstructionCost(Type *)> InstCost) const {
3997
+ if (!Ty->getScalarType ()->isHalfTy () && !Ty->getScalarType ()->isBFloatTy ()) // Only half/bfloat are handled here.
3998
+ return std::nullopt;
3999
+ if (Ty->getScalarType ()->isHalfTy () && ST->hasFullFP16 ()) // Half with full FP16 support: no promotion cost.
4000
+ return std::nullopt;
4001
+
4002
+ Type *PromotedTy = Ty->getWithNewType (Type::getFloatTy (Ty->getContext ())); // Ty with float as the new type.
4003
+ InstructionCost Cost = getCastInstrCost (Instruction::FPExt, PromotedTy, Ty, // Cost of one FPExt: Ty -> PromotedTy.
4004
+ TTI::CastContextHint::None, CostKind);
4005
+ if (!Op1Info.isConstant () && !Op2Info.isConstant ()) // Neither operand is reported constant:
4006
+ Cost *= 2 ; // count the extension twice (presumably one per operand).
4007
+ Cost += InstCost (PromotedTy); // Cost of the operation itself on the promoted type.
4008
+ if (IncludeTrunc)
4009
+ Cost += getCastInstrCost (Instruction::FPTrunc, Ty, PromotedTy, // Cost of FPTrunc: PromotedTy -> Ty.
4010
+ TTI::CastContextHint::None, CostKind);
4011
+ return Cost;
4012
+ }
4013
+
3993
4014
InstructionCost AArch64TTIImpl::getArithmeticInstrCost (
3994
4015
unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
3995
4016
TTI::OperandValueInfo Op1Info, TTI::OperandValueInfo Op2Info,
@@ -4012,6 +4033,18 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
4012
4033
std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost (Ty);
4013
4034
int ISD = TLI->InstructionOpcodeToISD (Opcode);
4014
4035
4036
+ // Increase the cost for half and bfloat types if not architecturally
4037
+ // supported.
4038
+ if (ISD == ISD::FADD || ISD == ISD::FSUB || ISD == ISD::FMUL ||
4039
+ ISD == ISD::FDIV || ISD == ISD::FREM)
4040
+ if (auto PromotedCost = getFP16BF16PromoteCost (
4041
+ Ty, CostKind, Op1Info, Op2Info, /* IncludeTrunc=*/ true ,
4042
+ [&](Type *PromotedTy) {
4043
+ return getArithmeticInstrCost (Opcode, PromotedTy, CostKind,
4044
+ Op1Info, Op2Info);
4045
+ }))
4046
+ return *PromotedCost;
4047
+
4015
4048
switch (ISD) {
4016
4049
default :
4017
4050
return BaseT::getArithmeticInstrCost (Opcode, Ty, CostKind, Op1Info,
@@ -4280,11 +4313,6 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
4280
4313
[[fallthrough]];
4281
4314
case ISD::FADD:
4282
4315
case ISD::FSUB:
4283
- // Increase the cost for half and bfloat types if not architecturally
4284
- // supported.
4285
- if ((Ty->getScalarType ()->isHalfTy () && !ST->hasFullFP16 ()) ||
4286
- (Ty->getScalarType ()->isBFloatTy () && !ST->hasBF16 ()))
4287
- return 2 * LT.first ;
4288
4316
if (!Ty->getScalarType ()->isFP128Ty ())
4289
4317
return LT.first ;
4290
4318
[[fallthrough]];
@@ -4386,25 +4414,21 @@ InstructionCost AArch64TTIImpl::getCmpSelInstrCost(
4386
4414
}
4387
4415
4388
4416
if (Opcode == Instruction::FCmp) {
4389
- // Without dedicated instructions we promote f16 + bf16 compares to f32.
4390
- if ((!ST->hasFullFP16 () && ValTy->getScalarType ()->isHalfTy ()) ||
4391
- ValTy->getScalarType ()->isBFloatTy ()) {
4392
- Type *PromotedTy =
4393
- ValTy->getWithNewType (Type::getFloatTy (ValTy->getContext ()));
4394
- InstructionCost Cost =
4395
- getCastInstrCost (Instruction::FPExt, PromotedTy, ValTy,
4396
- TTI::CastContextHint::None, CostKind);
4397
- if (!Op1Info.isConstant () && !Op2Info.isConstant ())
4398
- Cost *= 2 ;
4399
- Cost += getCmpSelInstrCost (Opcode, PromotedTy, CondTy, VecPred, CostKind,
4400
- Op1Info, Op2Info);
4401
- if (ValTy->isVectorTy ())
4402
- Cost += getCastInstrCost (
4403
- Instruction::Trunc, VectorType::getInteger (cast<VectorType>(ValTy)),
4404
- VectorType::getInteger (cast<VectorType>(PromotedTy)),
4405
- TTI::CastContextHint::None, CostKind);
4406
- return Cost;
4407
- }
4417
+ if (auto Cost = getFP16BF16PromoteCost (
4418
+ ValTy, CostKind, Op1Info, Op2Info, /* IncludeTrunc=*/ false ,
4419
+ [&](Type *PromotedTy) {
4420
+ InstructionCost Cost =
4421
+ getCmpSelInstrCost (Opcode, PromotedTy, CondTy, VecPred,
4422
+ CostKind, Op1Info, Op2Info);
4423
+ if (isa<VectorType>(PromotedTy))
4424
+ Cost += getCastInstrCost (
4425
+ Instruction::Trunc,
4426
+ VectorType::getInteger (cast<VectorType>(ValTy)),
4427
+ VectorType::getInteger (cast<VectorType>(PromotedTy)),
4428
+ TTI::CastContextHint::None, CostKind);
4429
+ return Cost;
4430
+ }))
4431
+ return *Cost;
4408
4432
4409
4433
auto LT = getTypeLegalizationCost (ValTy);
4410
4434
// Model unknown fp compares as a libcall.
0 commit comments