@@ -2761,6 +2761,21 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
27612761 return AdjustCost (
27622762 BaseT::getCastInstrCost (Opcode, Dst, Src, CCH, CostKind, I));
27632763
2764+ static const TypeConversionCostTblEntry BF16Tbl[] = {
2765+ {ISD::FP_ROUND, MVT::bf16 , MVT::f32 , 1 }, // bfcvt
2766+ {ISD::FP_ROUND, MVT::bf16 , MVT::f64 , 1 }, // bfcvt
2767+ {ISD::FP_ROUND, MVT::v4bf16, MVT::v4f32, 1 }, // bfcvtn
2768+ {ISD::FP_ROUND, MVT::v8bf16, MVT::v8f32, 2 }, // bfcvtn+bfcvtn2
2769+ {ISD::FP_ROUND, MVT::v2bf16, MVT::v2f64, 2 }, // bfcvtn+fcvtn
2770+ {ISD::FP_ROUND, MVT::v4bf16, MVT::v4f64, 3 }, // fcvtn+fcvtl2+bfcvtn
2771+ {ISD::FP_ROUND, MVT::v8bf16, MVT::v8f64, 6 }, // 2 * fcvtn+fcvtn2+bfcvtn
2772+ };
2773+
2774+ if (ST->hasBF16 ())
2775+ if (const auto *Entry = ConvertCostTableLookup (
2776+ BF16Tbl, ISD, DstTy.getSimpleVT (), SrcTy.getSimpleVT ()))
2777+ return AdjustCost (Entry->Cost );
2778+
27642779 static const TypeConversionCostTblEntry ConversionTbl[] = {
27652780 {ISD::TRUNCATE, MVT::v2i8, MVT::v2i64, 1 }, // xtn
27662781 {ISD::TRUNCATE, MVT::v2i16, MVT::v2i64, 1 }, // xtn
@@ -2848,6 +2863,14 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
28482863 {ISD::FP_EXTEND, MVT::v2f64, MVT::v2f16, 2 }, // fcvtl+fcvtl
28492864 {ISD::FP_EXTEND, MVT::v4f64, MVT::v4f16, 3 }, // fcvtl+fcvtl2+fcvtl
28502865 {ISD::FP_EXTEND, MVT::v8f64, MVT::v8f16, 6 }, // 2 * fcvtl+fcvtl2+fcvtl
2866+ // BF16 (uses shift)
2867+ {ISD::FP_EXTEND, MVT::f32 , MVT::bf16 , 1 }, // shl
2868+ {ISD::FP_EXTEND, MVT::f64 , MVT::bf16 , 2 }, // shl+fcvt
2869+ {ISD::FP_EXTEND, MVT::v4f32, MVT::v4bf16, 1 }, // shll
2870+ {ISD::FP_EXTEND, MVT::v8f32, MVT::v8bf16, 2 }, // shll+shll2
2871+ {ISD::FP_EXTEND, MVT::v2f64, MVT::v2bf16, 2 }, // shll+fcvtl
2872+ {ISD::FP_EXTEND, MVT::v4f64, MVT::v4bf16, 3 }, // shll+fcvtl+fcvtl2
2873+ {ISD::FP_EXTEND, MVT::v8f64, MVT::v8bf16, 6 }, // 2 * shll+fcvtl+fcvtl2
28512874 // FP Ext and trunc
28522875 {ISD::FP_ROUND, MVT::f32 , MVT::f64 , 1 }, // fcvt
28532876 {ISD::FP_ROUND, MVT::v2f32, MVT::v2f64, 1 }, // fcvtn
@@ -2860,6 +2883,15 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
28602883 {ISD::FP_ROUND, MVT::v2f16, MVT::v2f64, 2 }, // fcvtn+fcvtn
28612884 {ISD::FP_ROUND, MVT::v4f16, MVT::v4f64, 3 }, // fcvtn+fcvtn2+fcvtn
28622885 {ISD::FP_ROUND, MVT::v8f16, MVT::v8f64, 6 }, // 2 * fcvtn+fcvtn2+fcvtn
2886+ // BF16 (more complex, with +bf16 is handled above)
2887+ {ISD::FP_ROUND, MVT::bf16 , MVT::f32 , 8 }, // Expansion is ~8 insns
2888+ {ISD::FP_ROUND, MVT::bf16 , MVT::f64 , 9 }, // fcvtn + above
2889+ {ISD::FP_ROUND, MVT::v2bf16, MVT::v2f32, 8 },
2890+ {ISD::FP_ROUND, MVT::v4bf16, MVT::v4f32, 8 },
2891+ {ISD::FP_ROUND, MVT::v8bf16, MVT::v8f32, 15 },
2892+ {ISD::FP_ROUND, MVT::v2bf16, MVT::v2f64, 9 },
2893+ {ISD::FP_ROUND, MVT::v4bf16, MVT::v4f64, 10 },
2894+ {ISD::FP_ROUND, MVT::v8bf16, MVT::v8f64, 19 },
28632895
28642896 // LowerVectorINT_TO_FP:
28652897 {ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i32, 1 },
0 commit comments