Skip to content

Commit 6907ab4

Browse files
committed
[AArch64] Extend costs for fptoi.sat intrinsics.
Most of these bring the costs in line with the code generation. The f16 costs without FullFP16 are usually converted to f32. Extended v2f32->v2f64 vectors similarly use fcvtl + fcvt. As a backup we use the costs similar to the target independent code, which should give a relatively high cost.
1 parent 594d359 commit 6907ab4

File tree

3 files changed

+152
-121
lines changed

3 files changed

+152
-121
lines changed

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 61 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -748,22 +748,44 @@ AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
748748
// output are the same, or we are using cvt f64->i32 or f32->i64.
749749
if ((LT.second == MVT::f32 || LT.second == MVT::f64 ||
750750
LT.second == MVT::v2f32 || LT.second == MVT::v4f32 ||
751-
LT.second == MVT::v2f64) &&
752-
(LT.second.getScalarSizeInBits() == MTy.getScalarSizeInBits() ||
753-
(LT.second == MVT::f64 && MTy == MVT::i32) ||
754-
(LT.second == MVT::f32 && MTy == MVT::i64)))
755-
return LT.first;
756-
// Similarly for fp16 sizes
757-
if (ST->hasFullFP16() &&
758-
((LT.second == MVT::f16 && MTy == MVT::i32) ||
759-
((LT.second == MVT::v4f16 || LT.second == MVT::v8f16) &&
760-
(LT.second.getScalarSizeInBits() == MTy.getScalarSizeInBits()))))
751+
LT.second == MVT::v2f64)) {
752+
if ((LT.second.getScalarSizeInBits() == MTy.getScalarSizeInBits() ||
753+
(LT.second == MVT::f64 && MTy == MVT::i32) ||
754+
(LT.second == MVT::f32 && MTy == MVT::i64)))
755+
return LT.first;
756+
// Extending vector types v2f32->v2i64, fcvtl*2 + fcvt*2
757+
if (LT.second.getScalarType() == MVT::f32 && MTy.isFixedLengthVector() &&
758+
MTy.getScalarSizeInBits() == 64)
759+
return LT.first * (MTy.getVectorNumElements() > 2 ? 4 : 2);
760+
}
761+
// Similarly for fp16 sizes. Without FullFP16 we generally need to fcvt to
762+
// f32.
763+
if (LT.second.getScalarType() == MVT::f16 && !ST->hasFullFP16())
764+
return LT.first + getIntrinsicInstrCost(
765+
{ICA.getID(),
766+
RetTy,
767+
{ICA.getArgTypes()[0]->getWithNewType(
768+
Type::getFloatTy(RetTy->getContext()))}},
769+
CostKind);
770+
if ((LT.second == MVT::f16 && MTy == MVT::i32) ||
771+
(LT.second == MVT::f16 && MTy == MVT::i64) ||
772+
((LT.second == MVT::v4f16 || LT.second == MVT::v8f16) &&
773+
(LT.second.getScalarSizeInBits() == MTy.getScalarSizeInBits())))
761774
return LT.first;
762-
763-
// Otherwise we use a legal convert followed by a min+max
775+
// Extending vector types v8f16->v8i32, fcvtl*2 + fcvt*2
776+
if (LT.second.getScalarType() == MVT::f16 && MTy.isFixedLengthVector() &&
777+
MTy.getScalarSizeInBits() == 32)
778+
return LT.first * (MTy.getVectorNumElements() > 4 ? 4 : 2);
779+
// Extending vector types v8f16->v8i32. These current scalarize but the
780+
// codegen could be better.
781+
if (LT.second.getScalarType() == MVT::f16 && MTy.isFixedLengthVector() &&
782+
MTy.getScalarSizeInBits() == 64)
783+
return MTy.getVectorNumElements() * 3;
784+
785+
// If we can we use a legal convert followed by a min+max
764786
if ((LT.second.getScalarType() == MVT::f32 ||
765787
LT.second.getScalarType() == MVT::f64 ||
766-
(ST->hasFullFP16() && LT.second.getScalarType() == MVT::f16)) &&
788+
LT.second.getScalarType() == MVT::f16) &&
767789
LT.second.getScalarSizeInBits() >= MTy.getScalarSizeInBits()) {
768790
Type *LegalTy =
769791
Type::getIntNTy(RetTy->getContext(), LT.second.getScalarSizeInBits());
@@ -776,9 +798,33 @@ AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
776798
IntrinsicCostAttributes Attrs2(IsSigned ? Intrinsic::smax : Intrinsic::umax,
777799
LegalTy, {LegalTy, LegalTy});
778800
Cost += getIntrinsicInstrCost(Attrs2, CostKind);
779-
return LT.first * Cost;
801+
return LT.first * Cost +
802+
((LT.second.getScalarType() != MVT::f16 || ST->hasFullFP16()) ? 0
803+
: 1);
780804
}
781-
break;
805+
// Otherwise we need to follow the default expansion that clamps the value
806+
// using a float min/max with a fcmp+sel for nan handling when signed.
807+
Type *FPTy = ICA.getArgTypes()[0]->getScalarType();
808+
RetTy = RetTy->getScalarType();
809+
if (LT.second.isVector()) {
810+
FPTy = VectorType::get(FPTy, LT.second.getVectorElementCount());
811+
RetTy = VectorType::get(RetTy, LT.second.getVectorElementCount());
812+
}
813+
IntrinsicCostAttributes Attrs1(Intrinsic::minnum, FPTy, {FPTy, FPTy});
814+
InstructionCost Cost = getIntrinsicInstrCost(Attrs1, CostKind);
815+
IntrinsicCostAttributes Attrs2(Intrinsic::maxnum, FPTy, {FPTy, FPTy});
816+
Cost += getIntrinsicInstrCost(Attrs2, CostKind);
817+
Cost +=
818+
getCastInstrCost(IsSigned ? Instruction::FPToSI : Instruction::FPToUI,
819+
RetTy, FPTy, TTI::CastContextHint::None, CostKind);
820+
if (IsSigned) {
821+
Type *CondTy = RetTy->getWithNewBitWidth(1);
822+
Cost += getCmpSelInstrCost(BinaryOperator::FCmp, FPTy, CondTy,
823+
CmpInst::FCMP_UNO, CostKind);
824+
Cost += getCmpSelInstrCost(BinaryOperator::Select, RetTy, CondTy,
825+
CmpInst::FCMP_UNO, CostKind);
826+
}
827+
return LT.first * Cost;
782828
}
783829
case Intrinsic::fshl:
784830
case Intrinsic::fshr: {

0 commit comments

Comments
 (0)