Skip to content

Commit 0b99b56

Browse files
committed
[AArch64] Generalize costing for FP16 instructions
This extracts the code for modelling a fp16 operation as fptrunc(fpop(fpext,fpext) into a new function named getFP16BF16PromoteCost so that it can be reused by the arithmetic instructions. The function takes a lambda to calculate the cost of the operation with the promoted type.
1 parent bd74197 commit 0b99b56

File tree

6 files changed

+185
-128
lines changed

6 files changed

+185
-128
lines changed

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 47 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3975,6 +3975,26 @@ InstructionCost AArch64TTIImpl::getScalarizationOverhead(
39753975
return DemandedElts.popcount() * (Insert + Extract) * VecInstCost;
39763976
}
39773977

3978+
std::optional<InstructionCost> AArch64TTIImpl::getFP16BF16PromoteCost(
3979+
Type *Ty, TTI::TargetCostKind CostKind, TTI::OperandValueInfo Op1Info,
3980+
TTI::OperandValueInfo Op2Info, bool IncludeTrunc,
3981+
std::function<InstructionCost(Type *)> InstCost) const {
3982+
if ((ST->hasFullFP16() || !Ty->getScalarType()->isHalfTy()) &&
3983+
!Ty->getScalarType()->isBFloatTy())
3984+
return std::nullopt;
3985+
3986+
Type *PromotedTy = Ty->getWithNewType(Type::getFloatTy(Ty->getContext()));
3987+
InstructionCost Cost = getCastInstrCost(Instruction::FPExt, PromotedTy, Ty,
3988+
TTI::CastContextHint::None, CostKind);
3989+
if (!Op1Info.isConstant() && !Op2Info.isConstant())
3990+
Cost *= 2;
3991+
Cost += InstCost(PromotedTy);
3992+
if (IncludeTrunc)
3993+
Cost += getCastInstrCost(Instruction::FPTrunc, Ty, PromotedTy,
3994+
TTI::CastContextHint::None, CostKind);
3995+
return Cost;
3996+
}
3997+
39783998
InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
39793999
unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
39804000
TTI::OperandValueInfo Op1Info, TTI::OperandValueInfo Op2Info,
@@ -3997,6 +4017,18 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
39974017
std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Ty);
39984018
int ISD = TLI->InstructionOpcodeToISD(Opcode);
39994019

4020+
// Increase the cost for half and bfloat types if not architecturally
4021+
// supported.
4022+
if (ISD == ISD::FADD || ISD == ISD::FSUB || ISD == ISD::FMUL ||
4023+
ISD == ISD::FDIV || ISD == ISD::FREM)
4024+
if (auto PromotedCost = getFP16BF16PromoteCost(
4025+
Ty, CostKind, Op1Info, Op2Info, /*IncludeTrunc=*/true,
4026+
[&](Type *PromotedTy) {
4027+
return getArithmeticInstrCost(Opcode, PromotedTy, CostKind,
4028+
Op1Info, Op2Info);
4029+
}))
4030+
return *PromotedCost;
4031+
40004032
switch (ISD) {
40014033
default:
40024034
return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
@@ -4265,11 +4297,6 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
42654297
[[fallthrough]];
42664298
case ISD::FADD:
42674299
case ISD::FSUB:
4268-
// Increase the cost for half and bfloat types if not architecturally
4269-
// supported.
4270-
if ((Ty->getScalarType()->isHalfTy() && !ST->hasFullFP16()) ||
4271-
(Ty->getScalarType()->isBFloatTy() && !ST->hasBF16()))
4272-
return 2 * LT.first;
42734300
if (!Ty->getScalarType()->isFP128Ty())
42744301
return LT.first;
42754302
[[fallthrough]];
@@ -4371,25 +4398,21 @@ InstructionCost AArch64TTIImpl::getCmpSelInstrCost(
43714398
}
43724399

43734400
if (Opcode == Instruction::FCmp) {
4374-
// Without dedicated instructions we promote f16 + bf16 compares to f32.
4375-
if ((!ST->hasFullFP16() && ValTy->getScalarType()->isHalfTy()) ||
4376-
ValTy->getScalarType()->isBFloatTy()) {
4377-
Type *PromotedTy =
4378-
ValTy->getWithNewType(Type::getFloatTy(ValTy->getContext()));
4379-
InstructionCost Cost =
4380-
getCastInstrCost(Instruction::FPExt, PromotedTy, ValTy,
4381-
TTI::CastContextHint::None, CostKind);
4382-
if (!Op1Info.isConstant() && !Op2Info.isConstant())
4383-
Cost *= 2;
4384-
Cost += getCmpSelInstrCost(Opcode, PromotedTy, CondTy, VecPred, CostKind,
4385-
Op1Info, Op2Info);
4386-
if (ValTy->isVectorTy())
4387-
Cost += getCastInstrCost(
4388-
Instruction::Trunc, VectorType::getInteger(cast<VectorType>(ValTy)),
4389-
VectorType::getInteger(cast<VectorType>(PromotedTy)),
4390-
TTI::CastContextHint::None, CostKind);
4391-
return Cost;
4392-
}
4401+
if (auto Cost = getFP16BF16PromoteCost(
4402+
ValTy, CostKind, Op1Info, Op2Info, /*IncludeTrunc=*/false,
4403+
[&](Type *PromotedTy) {
4404+
InstructionCost Cost =
4405+
getCmpSelInstrCost(Opcode, PromotedTy, CondTy, VecPred,
4406+
CostKind, Op1Info, Op2Info);
4407+
if (isa<VectorType>(PromotedTy))
4408+
Cost += getCastInstrCost(
4409+
Instruction::Trunc,
4410+
VectorType::getInteger(cast<VectorType>(ValTy)),
4411+
VectorType::getInteger(cast<VectorType>(PromotedTy)),
4412+
TTI::CastContextHint::None, CostKind);
4413+
return Cost;
4414+
}))
4415+
return *Cost;
43934416

43944417
auto LT = getTypeLegalizationCost(ValTy);
43954418
// Model unknown fp compares as a libcall.

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,14 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
435435

436436
bool preferPredicatedReductionSelect() const override { return ST->hasSVE(); }
437437

438+
/// FP16 and BF16 operations are lowered to fptrunc(op(fpext, fpext) if the
439+
/// architecture features are not present.
440+
std::optional<InstructionCost>
441+
getFP16BF16PromoteCost(Type *Ty, TTI::TargetCostKind CostKind,
442+
TTI::OperandValueInfo Op1Info,
443+
TTI::OperandValueInfo Op2Info, bool IncludeTrunc,
444+
std::function<InstructionCost(Type *)> InstCost) const;
445+
438446
InstructionCost
439447
getArithmeticReductionCost(unsigned Opcode, VectorType *Ty,
440448
std::optional<FastMathFlags> FMF,

0 commit comments

Comments
 (0)