@@ -1809,17 +1809,20 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
18091809
18101810 if (!Subtarget->hasSVEB16B16() ||
18111811 !Subtarget->isNonStreamingSVEorSME2Available()) {
1812- for (auto Opcode : {ISD::FADD, ISD::FMA, ISD::FMAXIMUM, ISD::FMAXNUM,
1813- ISD::FMINIMUM, ISD::FMINNUM, ISD::FMUL, ISD::FSUB}) {
1814- setOperationPromotedToType(Opcode, MVT::nxv2bf16, MVT::nxv2f32);
1815- setOperationPromotedToType(Opcode, MVT::nxv4bf16, MVT::nxv4f32);
1816- setOperationPromotedToType(Opcode, MVT::nxv8bf16, MVT::nxv8f32);
1817- }
1818-
1819- if (Subtarget->hasBF16() &&
1820- (Subtarget->hasSVE() || Subtarget->hasSME())) {
1821- for (auto VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16})
1812+ for (MVT VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16}) {
1813+ MVT PromotedVT = VT.changeVectorElementType(MVT::f32);
1814+ setOperationPromotedToType(ISD::FADD, VT, PromotedVT);
1815+ setOperationPromotedToType(ISD::FMA, VT, PromotedVT);
1816+ setOperationPromotedToType(ISD::FMAXIMUM, VT, PromotedVT);
1817+ setOperationPromotedToType(ISD::FMAXNUM, VT, PromotedVT);
1818+ setOperationPromotedToType(ISD::FMINIMUM, VT, PromotedVT);
1819+ setOperationPromotedToType(ISD::FMINNUM, VT, PromotedVT);
1820+ setOperationPromotedToType(ISD::FSUB, VT, PromotedVT);
1821+
1822+ if (Subtarget->hasBF16())
18221823 setOperationAction(ISD::FMUL, VT, Custom);
1824+ else
1825+ setOperationPromotedToType(ISD::FMUL, VT, PromotedVT);
18231826 }
18241827 }
18251828
@@ -7648,40 +7651,46 @@ SDValue AArch64TargetLowering::LowerINIT_TRAMPOLINE(SDValue Op,
76487651}
76497652
76507653SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
7654+ SDLoc DL(Op);
76517655 EVT VT = Op.getValueType();
76527656 auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
76537657 if (VT.getScalarType() != MVT::bf16 ||
7654- !(Subtarget.hasBF16() && (Subtarget.hasSVE() || Subtarget.hasSME())))
7658+ (Subtarget.hasSVEB16B16() &&
7659+ Subtarget.isNonStreamingSVEorSME2Available()))
76557660 return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED);
76567661
7657- SDLoc DL(Op);
7662+ assert(Subtarget.hasBF16() && "Expected +bf16 for custom FMUL lowering");
7663+
7664+ auto MakeGetIntrinsic = [&](Intrinsic::ID IID) {
7665+ return [&, IID](EVT VT, auto... Ops) {
7666+ return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
7667+ DAG.getConstant(IID, DL, MVT::i32), Ops...);
7668+ };
7669+ };
7670+
7671+ // Create helpers for building intrinsic calls.
7672+ auto BFMLALB = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalb);
7673+ auto BFMLALT = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalt);
7674+ auto FCVT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvt_bf16f32_v2);
7675+ auto FCVNT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2);
7676+
76587677 SDValue Zero = DAG.getConstantFP(0.0, DL, MVT::nxv4f32);
76597678 SDValue LHS = Op.getOperand(0);
76607679 SDValue RHS = Op.getOperand(1);
76617680
7662- auto GetIntrinsic = [&](Intrinsic::ID IID, EVT VT, auto... Ops) {
7663- return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
7664- DAG.getConstant(IID, DL, MVT::i32), Ops...);
7665- };
7666-
76677681 SDValue Pg =
7668- getPTrue( DAG, DL, VT == MVT::nxv2bf16 ? MVT::nxv2i1 : MVT::nxv4i1,
7669- AArch64SVEPredPattern::all);
7682+ DAG.getConstant(1 , DL, VT == MVT::nxv2bf16 ? MVT::nxv2i1 : MVT::nxv4i1);
7683+
76707684 // Lower bf16 FMUL as a pair (VT == nxv8bf16) of BFMLAL top/bottom
76717685 // instructions. These result in two f32 vectors, which can be converted back
76727686 // to bf16 with FCVT and FCVNT.
7673- SDValue BottomF32 = GetIntrinsic(Intrinsic::aarch64_sve_bfmlalb, MVT::nxv4f32,
7674- Zero, LHS, RHS);
7675- SDValue BottomBF16 = GetIntrinsic(Intrinsic::aarch64_sve_fcvt_bf16f32_v2, VT,
7676- DAG.getPOISON(VT), Pg, BottomF32);
7677- if (VT == MVT::nxv8bf16) {
7678- // Note: nxv2bf16 and nxv4bf16 only use even lanes.
7679- SDValue TopF32 = GetIntrinsic(Intrinsic::aarch64_sve_bfmlalt, MVT::nxv4f32,
7680- Zero, LHS, RHS);
7681- return GetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2, VT,
7682- BottomBF16, Pg, TopF32);
7683- }
7684- return BottomBF16;
7687+ SDValue BottomF32 = BFMLALB(MVT::nxv4f32, Zero, LHS, RHS);
7688+ SDValue BottomBF16 = FCVT(VT, DAG.getPOISON(VT), Pg, BottomF32);
7689+ // Note: nxv2bf16 and nxv4bf16 only use even lanes.
7690+ if (VT != MVT::nxv8bf16)
7691+ return BottomBF16;
7692+ SDValue TopF32 = BFMLALT(MVT::nxv4f32, Zero, LHS, RHS);
7693+ return FCVNT(VT, BottomBF16, Pg, TopF32);
76857694}
76867695
76877696SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
0 commit comments