@@ -1797,17 +1797,20 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
17971797
17981798 if (!Subtarget->hasSVEB16B16() ||
17991799 !Subtarget->isNonStreamingSVEorSME2Available()) {
1800- for (auto Opcode : {ISD::FADD, ISD::FMA, ISD::FMAXIMUM, ISD::FMAXNUM,
1801- ISD::FMINIMUM, ISD::FMINNUM, ISD::FMUL, ISD::FSUB}) {
1802- setOperationPromotedToType(Opcode, MVT::nxv2bf16, MVT::nxv2f32);
1803- setOperationPromotedToType(Opcode, MVT::nxv4bf16, MVT::nxv4f32);
1804- setOperationPromotedToType(Opcode, MVT::nxv8bf16, MVT::nxv8f32);
1805- }
1806-
1807- if (Subtarget->hasBF16() &&
1808- (Subtarget->hasSVE() || Subtarget->hasSME())) {
1809- for (auto VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16})
1800+ for (MVT VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16}) {
1801+ MVT PromotedVT = VT.changeVectorElementType(MVT::f32);
1802+ setOperationPromotedToType(ISD::FADD, VT, PromotedVT);
1803+ setOperationPromotedToType(ISD::FMA, VT, PromotedVT);
1804+ setOperationPromotedToType(ISD::FMAXIMUM, VT, PromotedVT);
1805+ setOperationPromotedToType(ISD::FMAXNUM, VT, PromotedVT);
1806+ setOperationPromotedToType(ISD::FMINIMUM, VT, PromotedVT);
1807+ setOperationPromotedToType(ISD::FMINNUM, VT, PromotedVT);
1808+ setOperationPromotedToType(ISD::FSUB, VT, PromotedVT);
1809+
1810+ if (Subtarget->hasBF16())
18101811 setOperationAction(ISD::FMUL, VT, Custom);
1812+ else
1813+ setOperationPromotedToType(ISD::FMUL, VT, PromotedVT);
18111814 }
18121815 }
18131816
@@ -7545,40 +7548,46 @@ SDValue AArch64TargetLowering::LowerINIT_TRAMPOLINE(SDValue Op,
75457548}
75467549
75477550SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
7551+ SDLoc DL(Op);
75487552 EVT VT = Op.getValueType();
75497553 auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
75507554 if (VT.getScalarType() != MVT::bf16 ||
7551- !(Subtarget.hasBF16() && (Subtarget.hasSVE() || Subtarget.hasSME())))
7555+ (Subtarget.hasSVEB16B16() &&
7556+ Subtarget.isNonStreamingSVEorSME2Available()))
75527557 return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED);
75537558
7554- SDLoc DL(Op);
7559+ assert(Subtarget.hasBF16() && "Expected +bf16 for custom FMUL lowering");
7560+
7561+ auto MakeGetIntrinsic = [&](Intrinsic::ID IID) {
7562+ return [&, IID](EVT VT, auto... Ops) {
7563+ return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
7564+ DAG.getConstant(IID, DL, MVT::i32), Ops...);
7565+ };
7566+ };
7567+
7568+ // Create helpers for building intrinsic calls.
7569+ auto BFMLALB = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalb);
7570+ auto BFMLALT = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalt);
7571+ auto FCVT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvt_bf16f32_v2);
7572+ auto FCVNT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2);
7573+
75557574 SDValue Zero = DAG.getConstantFP(0.0, DL, MVT::nxv4f32);
75567575 SDValue LHS = Op.getOperand(0);
75577576 SDValue RHS = Op.getOperand(1);
75587577
7559- auto GetIntrinsic = [&](Intrinsic::ID IID, EVT VT, auto... Ops) {
7560- return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
7561- DAG.getConstant(IID, DL, MVT::i32), Ops...);
7562- };
7563-
75647578 SDValue Pg =
7565- getPTrue( DAG, DL, VT == MVT::nxv2bf16 ? MVT::nxv2i1 : MVT::nxv4i1,
7566- AArch64SVEPredPattern::all);
7579+ DAG.getConstant(1 , DL, VT == MVT::nxv2bf16 ? MVT::nxv2i1 : MVT::nxv4i1);
7580+
75677581 // Lower bf16 FMUL as a pair (VT == nxv8bf16) of BFMLAL top/bottom
75687582 // instructions. These result in two f32 vectors, which can be converted back
75697583 // to bf16 with FCVT and FCVNT.
7570- SDValue BottomF32 = GetIntrinsic(Intrinsic::aarch64_sve_bfmlalb, MVT::nxv4f32,
7571- Zero, LHS, RHS);
7572- SDValue BottomBF16 = GetIntrinsic(Intrinsic::aarch64_sve_fcvt_bf16f32_v2, VT,
7573- DAG.getPOISON(VT), Pg, BottomF32);
7574- if (VT == MVT::nxv8bf16) {
7575- // Note: nxv2bf16 and nxv4bf16 only use even lanes.
7576- SDValue TopF32 = GetIntrinsic(Intrinsic::aarch64_sve_bfmlalt, MVT::nxv4f32,
7577- Zero, LHS, RHS);
7578- return GetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2, VT,
7579- BottomBF16, Pg, TopF32);
7580- }
7581- return BottomBF16;
7584+ SDValue BottomF32 = BFMLALB(MVT::nxv4f32, Zero, LHS, RHS);
7585+ SDValue BottomBF16 = FCVT(VT, DAG.getPOISON(VT), Pg, BottomF32);
7586+ // Note: nxv2bf16 and nxv4bf16 only use even lanes.
7587+ if (VT != MVT::nxv8bf16)
7588+ return BottomBF16;
7589+ SDValue TopF32 = BFMLALT(MVT::nxv4f32, Zero, LHS, RHS);
7590+ return FCVNT(VT, BottomBF16, Pg, TopF32);
75827591}
75837592
75847593SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
0 commit comments