Skip to content

Commit 45473a9

Browse files
MatzeBronlieb
authored and committed
X86: Improve cost model of fp16 conversion (llvm#113195)
Improve cost modeling for x86 `__fp16` conversions so that the SLPVectorizer transforms these patterns:
- Override `X86TTIImpl::getStoreMinimumVF` to report a minimum VF of 4 (an SSE register can hold 4 x float, which is converted and stored as 4 x f16). This is necessary because fp16 stores are not modeled as trunc-stores, and we cannot mark direct N x fp16 stores as legal since we generally expand fp16 operations.
- Add missing cost entries to `X86TTIImpl::getCastInstrCost` for conversions from/to fp16. Note that conversion from f64 to f16 is not supported by an X86 instruction.

Change-Id: I84a8a44795fc5d76cc573884c8c76bd04dfbb24b
1 parent a3deb8a commit 45473a9

File tree

3 files changed

+650
-0
lines changed

3 files changed

+650
-0
lines changed

llvm/lib/Target/X86/X86TargetTransformInfo.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2296,7 +2296,10 @@ InstructionCost X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
22962296
{ ISD::FP_EXTEND, MVT::v8f64, MVT::v8f32, { 1, 1, 1, 1 } },
22972297
{ ISD::FP_EXTEND, MVT::v8f64, MVT::v16f32, { 3, 1, 1, 1 } },
22982298
{ ISD::FP_EXTEND, MVT::v16f64, MVT::v16f32, { 4, 1, 1, 1 } }, // 2*vcvtps2pd+vextractf64x4
2299+
{ ISD::FP_EXTEND, MVT::v16f32, MVT::v16f16, { 1, 1, 1, 1 } }, // vcvtph2ps
2300+
{ ISD::FP_EXTEND, MVT::v8f64, MVT::v8f16, { 2, 1, 1, 1 } }, // vcvtph2ps+vcvtps2pd
22992301
{ ISD::FP_ROUND, MVT::v8f32, MVT::v8f64, { 1, 1, 1, 1 } },
2302+
{ ISD::FP_ROUND, MVT::v16f16, MVT::v16f32, { 1, 1, 1, 1 } }, // vcvtps2ph
23002303

23012304
{ ISD::TRUNCATE, MVT::v2i1, MVT::v2i8, { 3, 1, 1, 1 } }, // sext+vpslld+vptestmd
23022305
{ ISD::TRUNCATE, MVT::v4i1, MVT::v4i8, { 3, 1, 1, 1 } }, // sext+vpslld+vptestmd
@@ -2973,6 +2976,17 @@ InstructionCost X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
29732976
{ ISD::TRUNCATE, MVT::v4i32, MVT::v2i64, { 1, 1, 1, 1 } }, // PSHUFD
29742977
};
29752978

2979+
// Cost table for conversions between f16 and f32/f64 (scalar and vector).
// Both lookups of this table in this function are guarded by ST->hasF16C().
// Entry layout, as consumed by ConvertCostTableLookup(Tbl, ISD, Dst, Src):
//   { ISD opcode, destination MVT, source MVT, { costs indexed by CostKind } }
static const TypeConversionCostKindTblEntry F16ConversionTbl[] = {
2980+
// FP_ROUND: only f32 -> f16 is listed. The commit description notes that
// f64 -> f16 is not supported by an X86 instruction, so it has no entry here.
{ ISD::FP_ROUND, MVT::f16, MVT::f32, { 1, 1, 1, 1 } },
2981+
{ ISD::FP_ROUND, MVT::v8f16, MVT::v8f32, { 1, 1, 1, 1 } },
2982+
{ ISD::FP_ROUND, MVT::v4f16, MVT::v4f32, { 1, 1, 1, 1 } },
2983+
// FP_EXTEND: f16 -> f32 and f16 -> f64. The f64 destinations carry a 2 in
// the first cost slot, matching the two-instruction sequence named in their
// trailing comments.
{ ISD::FP_EXTEND, MVT::f32, MVT::f16, { 1, 1, 1, 1 } },
2984+
{ ISD::FP_EXTEND, MVT::f64, MVT::f16, { 2, 1, 1, 1 } }, // vcvtph2ps+vcvtps2pd
2985+
{ ISD::FP_EXTEND, MVT::v8f32, MVT::v8f16, { 1, 1, 1, 1 } },
2986+
{ ISD::FP_EXTEND, MVT::v4f32, MVT::v4f16, { 1, 1, 1, 1 } },
2987+
{ ISD::FP_EXTEND, MVT::v4f64, MVT::v4f16, { 2, 1, 1, 1 } }, // vcvtph2ps+vcvtps2pd
2988+
};
2989+
29762990
// Attempt to map directly to (simple) MVT types to let us match custom entries.
29772991
EVT SrcTy = TLI->getValueType(DL, Src);
29782992
EVT DstTy = TLI->getValueType(DL, Dst);
@@ -3034,6 +3048,13 @@ InstructionCost X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
30343048
return *KindCost;
30353049
}
30363050

3051+
if (ST->hasF16C()) {
3052+
if (const auto *Entry = ConvertCostTableLookup(F16ConversionTbl, ISD,
3053+
SimpleDstTy, SimpleSrcTy))
3054+
if (auto KindCost = Entry->Cost[CostKind])
3055+
return *KindCost;
3056+
}
3057+
30373058
if (ST->hasSSE41()) {
30383059
if (const auto *Entry = ConvertCostTableLookup(SSE41ConversionTbl, ISD,
30393060
SimpleDstTy, SimpleSrcTy))
@@ -3114,6 +3135,13 @@ InstructionCost X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
31143135
if (auto KindCost = Entry->Cost[CostKind])
31153136
return std::max(LTSrc.first, LTDest.first) * *KindCost;
31163137

3138+
if (ST->hasF16C()) {
3139+
if (const auto *Entry = ConvertCostTableLookup(F16ConversionTbl, ISD,
3140+
LTDest.second, LTSrc.second))
3141+
if (auto KindCost = Entry->Cost[CostKind])
3142+
return std::max(LTSrc.first, LTDest.first) * *KindCost;
3143+
}
3144+
31173145
if (ST->hasSSE41())
31183146
if (const auto *Entry = ConvertCostTableLookup(SSE41ConversionTbl, ISD,
31193147
LTDest.second, LTSrc.second))
@@ -3153,6 +3181,11 @@ InstructionCost X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
31533181
TTI::CastContextHint::None, CostKind);
31543182
}
31553183

3184+
if (ISD == ISD::FP_ROUND && LTDest.second.getScalarType() == MVT::f16) {
3185+
// Conversion requires a libcall.
3186+
return InstructionCost::getInvalid();
3187+
}
3188+
31563189
// TODO: Allow non-throughput costs that aren't binary.
31573190
auto AdjustCost = [&CostKind](InstructionCost Cost,
31583191
InstructionCost N = 1) -> InstructionCost {
@@ -6930,6 +6963,14 @@ bool X86TTIImpl::isVectorShiftByScalarCheap(Type *Ty) const {
69306963
return true;
69316964
}
69326965

6966+
/// Reports the minimum vectorization factor for a store.
///
/// \param VF          Forwarded unchanged to the base implementation.
/// \param ScalarMemTy Scalar type as stored in memory; only checked for half.
/// \param ScalarValTy Forwarded unchanged to the base implementation.
/// \returns 4 when the subtarget has F16C and \p ScalarMemTy is half;
///          otherwise whatever BaseT::getStoreMinimumVF returns.
unsigned X86TTIImpl::getStoreMinimumVF(unsigned VF, Type *ScalarMemTy,
6967+
Type *ScalarValTy) const {
6968+
// Rationale from the commit description: an SSE register can hold 4 x float
// that is converted and stored as 4 x f16. fp16 stores are not modeled as
// trunc-stores and fp16 vector stores cannot be marked legal, so the
// minimum VF is reported directly here.
if (ST->hasF16C() && ScalarMemTy->isHalfTy()) {
6969+
return 4;
6970+
}
6971+
// Any other store: defer to the generic implementation.
return BaseT::getStoreMinimumVF(VF, ScalarMemTy, ScalarValTy);
6972+
}
6973+
69336974
bool X86TTIImpl::isProfitableToSinkOperands(Instruction *I,
69346975
SmallVectorImpl<Use *> &Ops) const {
69356976
using namespace llvm::PatternMatch;

llvm/lib/Target/X86/X86TargetTransformInfo.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,9 @@ class X86TTIImpl : public BasicTTIImplBase<X86TTIImpl> {
302302

303303
bool isVectorShiftByScalarCheap(Type *Ty) const;
304304

305+
/// Minimum store vectorization factor. The X86 definition returns 4 when the
/// subtarget has F16C and \p ScalarMemTy is half, and otherwise forwards all
/// three arguments to BaseT::getStoreMinimumVF.
unsigned getStoreMinimumVF(unsigned VF, Type *ScalarMemTy,
306+
Type *ScalarValTy) const;
307+
305308
private:
306309
bool supportsGather() const;
307310
InstructionCost getGSVectorCost(unsigned Opcode, TTI::TargetCostKind CostKind,

0 commit comments

Comments
 (0)