From da5310f46ff103077678313475ffdfe1ecdbfa6e Mon Sep 17 00:00:00 2001 From: Paul Walker Date: Mon, 3 Mar 2025 15:13:00 +0000 Subject: [PATCH] [LLVM][SVE] Lower bfloat extends the same as other types. When I originally wrote the code I went to some effort to ensure we emitted an unpredicated instruction. I now realise there was a simpler way to achieve the same result. --- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 15 +++------------ llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td | 6 +++--- 2 files changed, 6 insertions(+), 15 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 7a471662ea075..9cf361493fddf 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -4503,18 +4503,9 @@ SDValue AArch64TargetLowering::LowerFP_EXTEND(SDValue Op, if (VT.isScalableVector()) { SDValue SrcVal = Op.getOperand(0); - if (SrcVal.getValueType().getScalarType() == MVT::bf16) { - // bf16 and f32 share the same exponent range so the conversion requires - // them to be aligned with the new mantissa bits zero'd. This is just a - // left shift that is best to isel directly. - if (VT == MVT::nxv2f32 || VT == MVT::nxv4f32) - return Op; - - if (VT != MVT::nxv2f64) - return SDValue(); - - // Break other conversions in two with the first part converting to f32 - // and the second using native f32->VT instructions. + if (VT == MVT::nxv2f64 && SrcVal.getValueType() == MVT::nxv2bf16) { + // Break conversion in two with the first part converting to f32 and the + // second using native f32->VT instructions. 
SDLoc DL(Op); return DAG.getNode(ISD::FP_EXTEND, DL, VT, DAG.getNode(ISD::FP_EXTEND, DL, MVT::nxv2f32, SrcVal)); diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td index 4365e573d8b16..ccfbd91735d84 100644 --- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td @@ -345,7 +345,7 @@ def AArch64fclamp : PatFrags<(ops node:$Zd, node:$Zn, node:$Zm), def SDT_AArch64FCVT : SDTypeProfile<1, 3, [ SDTCisVec<0>, SDTCisVec<1>, SDTCisVec<2>, SDTCisVec<3>, - SDTCVecEltisVT<1,i1> + SDTCVecEltisVT<1,i1>, SDTCisSameNumEltsAs<0,1>, SDTCisSameAs<0,3> ]>; def SDT_AArch64FCVTR : SDTypeProfile<1, 4, [ @@ -2370,9 +2370,9 @@ let Predicates = [HasSVE_or_SME] in { def : Pat<(nxv2f16 (AArch64fcvtr_mt (nxv2i1 (SVEAllActive:$Pg)), nxv2f32:$Zs, (i64 timm0_1), nxv2f16:$Zd)), (FCVT_ZPmZ_StoH_UNDEF ZPR:$Zd, PPR:$Pg, ZPR:$Zs)>; - def : Pat<(nxv4f32 (fpextend nxv4bf16:$op)), + def : Pat<(nxv4f32 (AArch64fcvte_mt (SVEAnyPredicate), nxv4bf16:$op, undef)), (LSL_ZZI_S $op, (i32 16))>; - def : Pat<(nxv2f32 (fpextend nxv2bf16:$op)), + def : Pat<(nxv2f32 (AArch64fcvte_mt (SVEAnyPredicate), nxv2bf16:$op, undef)), (LSL_ZZI_S $op, (i32 16))>; // Signed integer -> Floating-point