Skip to content

Commit 7f1638e

Browse files
authored
[AArch64] Generalize costing for FP16 instructions (#150033)
This extracts the code for modelling an fp16 operation as `fptrunc(fpop(fpext,fpext))` into a new function named `getFP16BF16PromoteCost` so that it can be reused by the arithmetic instructions. The function takes a lambda to calculate the cost of the operation with the promoted type.
1 parent 83c308f commit 7f1638e

File tree

6 files changed

+186
-128
lines changed

6 files changed

+186
-128
lines changed

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 48 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3990,6 +3990,27 @@ InstructionCost AArch64TTIImpl::getScalarizationOverhead(
39903990
return DemandedElts.popcount() * (Insert + Extract) * VecInstCost;
39913991
}
39923992

3993+
// Cost an f16/bf16 operation as if it were performed on f32:
// fpext of the operands, the operation on the promoted type, and optionally
// an fptrunc of the result.
//
// Returns std::nullopt when no promotion is modelled, i.e. when the scalar
// type of Ty is neither half nor bfloat, or when it is half and the subtarget
// reports full FP16 support. Otherwise returns the summed cost.
//
// Ty           - type of the operation being costed (scalar or vector).
// CostKind     - cost kind forwarded to the cast cost queries.
// Op1Info/Op2Info - operand info; only isConstant() is consulted here.
// IncludeTrunc - whether to add the cost of truncating the result back to Ty.
// InstCost     - callback returning the cost of the operation on the
//                promoted type it is handed.
std::optional<InstructionCost> AArch64TTIImpl::getFP16BF16PromoteCost(
3994+
Type *Ty, TTI::TargetCostKind CostKind, TTI::OperandValueInfo Op1Info,
3995+
TTI::OperandValueInfo Op2Info, bool IncludeTrunc,
3996+
std::function<InstructionCost(Type *)> InstCost) const {
3997+
// Only half and bfloat element types are candidates for promotion.
if (!Ty->getScalarType()->isHalfTy() && !Ty->getScalarType()->isBFloatTy())
3998+
return std::nullopt;
3999+
// Half needs no promotion when the subtarget has full FP16. Note that there
// is no equivalent early-out for bfloat: it is always promoted here.
if (Ty->getScalarType()->isHalfTy() && ST->hasFullFP16())
4000+
return std::nullopt;
4001+
4002+
// Same shape as Ty (scalar or vector) but with float elements.
Type *PromotedTy = Ty->getWithNewType(Type::getFloatTy(Ty->getContext()));
4003+
// Cost of extending one operand from Ty to PromotedTy.
InstructionCost Cost = getCastInstrCost(Instruction::FPExt, PromotedTy, Ty,
4004+
TTI::CastContextHint::None, CostKind);
4005+
// Charge the extension twice only when neither operand is a constant;
// presumably a constant operand does not need a runtime fpext.
if (!Op1Info.isConstant() && !Op2Info.isConstant())
4006+
Cost *= 2;
4007+
// Add the caller-supplied cost of the operation on the promoted type.
Cost += InstCost(PromotedTy);
4008+
// Optionally add the cost of narrowing the result back to the original type.
if (IncludeTrunc)
4009+
Cost += getCastInstrCost(Instruction::FPTrunc, Ty, PromotedTy,
4010+
TTI::CastContextHint::None, CostKind);
4011+
return Cost;
4012+
}
4013+
39934014
InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
39944015
unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
39954016
TTI::OperandValueInfo Op1Info, TTI::OperandValueInfo Op2Info,
@@ -4012,6 +4033,18 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
40124033
std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Ty);
40134034
int ISD = TLI->InstructionOpcodeToISD(Opcode);
40144035

4036+
// Increase the cost for half and bfloat types if not architecturally
4037+
// supported.
4038+
if (ISD == ISD::FADD || ISD == ISD::FSUB || ISD == ISD::FMUL ||
4039+
ISD == ISD::FDIV || ISD == ISD::FREM)
4040+
if (auto PromotedCost = getFP16BF16PromoteCost(
4041+
Ty, CostKind, Op1Info, Op2Info, /*IncludeTrunc=*/true,
4042+
[&](Type *PromotedTy) {
4043+
return getArithmeticInstrCost(Opcode, PromotedTy, CostKind,
4044+
Op1Info, Op2Info);
4045+
}))
4046+
return *PromotedCost;
4047+
40154048
switch (ISD) {
40164049
default:
40174050
return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
@@ -4280,11 +4313,6 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
42804313
[[fallthrough]];
42814314
case ISD::FADD:
42824315
case ISD::FSUB:
4283-
// Increase the cost for half and bfloat types if not architecturally
4284-
// supported.
4285-
if ((Ty->getScalarType()->isHalfTy() && !ST->hasFullFP16()) ||
4286-
(Ty->getScalarType()->isBFloatTy() && !ST->hasBF16()))
4287-
return 2 * LT.first;
42884316
if (!Ty->getScalarType()->isFP128Ty())
42894317
return LT.first;
42904318
[[fallthrough]];
@@ -4386,25 +4414,21 @@ InstructionCost AArch64TTIImpl::getCmpSelInstrCost(
43864414
}
43874415

43884416
if (Opcode == Instruction::FCmp) {
4389-
// Without dedicated instructions we promote f16 + bf16 compares to f32.
4390-
if ((!ST->hasFullFP16() && ValTy->getScalarType()->isHalfTy()) ||
4391-
ValTy->getScalarType()->isBFloatTy()) {
4392-
Type *PromotedTy =
4393-
ValTy->getWithNewType(Type::getFloatTy(ValTy->getContext()));
4394-
InstructionCost Cost =
4395-
getCastInstrCost(Instruction::FPExt, PromotedTy, ValTy,
4396-
TTI::CastContextHint::None, CostKind);
4397-
if (!Op1Info.isConstant() && !Op2Info.isConstant())
4398-
Cost *= 2;
4399-
Cost += getCmpSelInstrCost(Opcode, PromotedTy, CondTy, VecPred, CostKind,
4400-
Op1Info, Op2Info);
4401-
if (ValTy->isVectorTy())
4402-
Cost += getCastInstrCost(
4403-
Instruction::Trunc, VectorType::getInteger(cast<VectorType>(ValTy)),
4404-
VectorType::getInteger(cast<VectorType>(PromotedTy)),
4405-
TTI::CastContextHint::None, CostKind);
4406-
return Cost;
4407-
}
4417+
if (auto Cost = getFP16BF16PromoteCost(
4418+
ValTy, CostKind, Op1Info, Op2Info, /*IncludeTrunc=*/false,
4419+
[&](Type *PromotedTy) {
4420+
InstructionCost Cost =
4421+
getCmpSelInstrCost(Opcode, PromotedTy, CondTy, VecPred,
4422+
CostKind, Op1Info, Op2Info);
4423+
if (isa<VectorType>(PromotedTy))
4424+
Cost += getCastInstrCost(
4425+
Instruction::Trunc,
4426+
VectorType::getInteger(cast<VectorType>(ValTy)),
4427+
VectorType::getInteger(cast<VectorType>(PromotedTy)),
4428+
TTI::CastContextHint::None, CostKind);
4429+
return Cost;
4430+
}))
4431+
return *Cost;
44084432

44094433
auto LT = getTypeLegalizationCost(ValTy);
44104434
// Model unknown fp compares as a libcall.

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,14 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
435435

436436
bool preferPredicatedReductionSelect() const override { return ST->hasSVE(); }
437437

438+
/// FP16 and BF16 operations are lowered to fptrunc(op(fpext, fpext)) if the
439+
/// architecture features are not present. This helper computes the cost of
/// that promoted sequence.
///
/// Returns std::nullopt when \p Ty is not a half/bfloat (scalar or vector)
/// type, or when it is half and the subtarget has full FP16 support; bfloat
/// is always costed as promoted. Otherwise the result is the fpext cost
/// (counted twice when neither \p Op1Info nor \p Op2Info is a constant),
/// plus \p InstCost invoked on the float-promoted type, plus the fptrunc
/// cost when \p IncludeTrunc is true.
440+
std::optional<InstructionCost>
441+
getFP16BF16PromoteCost(Type *Ty, TTI::TargetCostKind CostKind,
442+
TTI::OperandValueInfo Op1Info,
443+
TTI::OperandValueInfo Op2Info, bool IncludeTrunc,
444+
std::function<InstructionCost(Type *)> InstCost) const;
445+
438446
InstructionCost
439447
getArithmeticReductionCost(unsigned Opcode, VectorType *Ty,
440448
std::optional<FastMathFlags> FMF,

0 commit comments

Comments
 (0)