Skip to content

Commit 4f2a529

Browse files
committed
X86: Improve cost model of fp16 conversion
Improve the cost modeling of x86 `__fp16` conversions so that the SLPVectorizer can transform these patterns:
- Set the `ISD::STORE` operation action (via `setOperationAction`) for v4f16, v8f16 and v16f16 to Custom so that `TargetTransformInfo::getStoreMinimumVF` reports them as acceptable.
- Add the missing cost entries for conversions from/to fp16 to `X86TTIImpl::getCastInstrCost`.

Note that no X86 instruction converts from f64 to f16, so no cost entry is added for that conversion.
1 parent 8ae39c8 commit 4f2a529

File tree

3 files changed

+632
-0
lines changed

3 files changed

+632
-0
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1714,6 +1714,9 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
17141714
setOperationPromotedToType(Opc, MVT::v8f16, MVT::v8f32);
17151715
setOperationPromotedToType(Opc, MVT::v16f16, MVT::v16f32);
17161716
}
1717+
// trunc+store via vcvtps2ph
1718+
setOperationAction(ISD::STORE, MVT::v4f16, Custom);
1719+
setOperationAction(ISD::STORE, MVT::v8f16, Custom);
17171720
}
17181721

17191722
// This block controls legalization of the mask vector sizes that are
@@ -1784,6 +1787,9 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
17841787

17851788
for (auto VT : { MVT::v1i1, MVT::v2i1, MVT::v4i1, MVT::v8i1 })
17861789
setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
1790+
1791+
// trunc+store via vcvtps2ph
1792+
setOperationAction(ISD::STORE, MVT::v16f16, Custom);
17871793
}
17881794
if (Subtarget.hasDQI() && Subtarget.hasVLX()) {
17891795
for (MVT VT : {MVT::v4f32, MVT::v8f32, MVT::v2f64, MVT::v4f64}) {

llvm/lib/Target/X86/X86TargetTransformInfo.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2296,7 +2296,10 @@ InstructionCost X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
22962296
{ ISD::FP_EXTEND, MVT::v8f64, MVT::v8f32, { 1, 1, 1, 1 } },
22972297
{ ISD::FP_EXTEND, MVT::v8f64, MVT::v16f32, { 3, 1, 1, 1 } },
22982298
{ ISD::FP_EXTEND, MVT::v16f64, MVT::v16f32, { 4, 1, 1, 1 } }, // 2*vcvtps2pd+vextractf64x4
2299+
{ ISD::FP_EXTEND, MVT::v16f32, MVT::v16f16, { 1, 1, 1, 1 } }, // vcvtph2ps
2300+
{ ISD::FP_EXTEND, MVT::v8f64, MVT::v8f16, { 2, 1, 1, 1 } }, // vcvtph2ps+vcvtps2pd
22992301
{ ISD::FP_ROUND, MVT::v8f32, MVT::v8f64, { 1, 1, 1, 1 } },
2302+
{ ISD::FP_ROUND, MVT::v16f16, MVT::v16f32, { 1, 1, 1, 1 } }, // vcvtps2ph
23002303

23012304
{ ISD::TRUNCATE, MVT::v2i1, MVT::v2i8, { 3, 1, 1, 1 } }, // sext+vpslld+vptestmd
23022305
{ ISD::TRUNCATE, MVT::v4i1, MVT::v4i8, { 3, 1, 1, 1 } }, // sext+vpslld+vptestmd
@@ -2973,6 +2976,14 @@ InstructionCost X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
29732976
{ ISD::TRUNCATE, MVT::v4i32, MVT::v2i64, { 1, 1, 1, 1 } }, // PSHUFD
29742977
};
29752978

2979+
static const TypeConversionCostKindTblEntry F16ConversionTbl[] = {
2980+
{ ISD::FP_ROUND, MVT::v8f16, MVT::v8f32, { 1, 1, 1, 1 } }, // vcvtps2ph
2981+
{ ISD::FP_ROUND, MVT::v4f16, MVT::v4f32, { 1, 1, 1, 1 } }, // vcvtps2ph
2982+
{ ISD::FP_EXTEND, MVT::v8f32, MVT::v8f16, { 1, 1, 1, 1 } }, // vcvtph2ps
2983+
{ ISD::FP_EXTEND, MVT::v4f32, MVT::v4f16, { 1, 1, 1, 1 } }, // vcvtph2ps
2984+
{ ISD::FP_EXTEND, MVT::v4f64, MVT::v4f16, { 2, 1, 1, 1 } }, // vcvtph2ps+vcvtps2pd
2985+
};
2986+
29762987
// Attempt to map directly to (simple) MVT types to let us match custom entries.
29772988
EVT SrcTy = TLI->getValueType(DL, Src);
29782989
EVT DstTy = TLI->getValueType(DL, Dst);
@@ -3034,6 +3045,13 @@ InstructionCost X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
30343045
return *KindCost;
30353046
}
30363047

3048+
if (ST->hasF16C()) {
3049+
if (const auto *Entry = ConvertCostTableLookup(F16ConversionTbl, ISD,
3050+
SimpleDstTy, SimpleSrcTy))
3051+
if (auto KindCost = Entry->Cost[CostKind])
3052+
return *KindCost;
3053+
}
3054+
30373055
if (ST->hasSSE41()) {
30383056
if (const auto *Entry = ConvertCostTableLookup(SSE41ConversionTbl, ISD,
30393057
SimpleDstTy, SimpleSrcTy))
@@ -3107,6 +3125,13 @@ InstructionCost X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
31073125
if (auto KindCost = Entry->Cost[CostKind])
31083126
return std::max(LTSrc.first, LTDest.first) * *KindCost;
31093127

3128+
if (ST->hasF16C()) {
3129+
if (const auto *Entry = ConvertCostTableLookup(F16ConversionTbl, ISD,
3130+
LTDest.second, LTSrc.second))
3131+
if (auto KindCost = Entry->Cost[CostKind])
3132+
return std::max(LTSrc.first, LTDest.first) * *KindCost;
3133+
}
3134+
31103135
if (ST->hasSSE41())
31113136
if (const auto *Entry = ConvertCostTableLookup(SSE41ConversionTbl, ISD,
31123137
LTDest.second, LTSrc.second))

0 commit comments

Comments
 (0)