@@ -1523,6 +1523,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
15231523
15241524 for (auto VT : {MVT::v16i8, MVT::v8i8, MVT::v4i16, MVT::v2i32})
15251525 setOperationAction(ISD::GET_ACTIVE_LANE_MASK, VT, Custom);
1526+
1527+ for (auto VT : {MVT::v8f16, MVT::v4f32, MVT::v2f64})
1528+ setOperationAction(ISD::FMA, VT, Custom);
15261529 }
15271530
15281531 if (Subtarget->isSVEorStreamingSVEAvailable()) {
@@ -7781,6 +7784,46 @@ SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
77817784 return FCVTNT(VT, BottomBF16, Pg, TopF32);
77827785}
77837786
7787+ SDValue AArch64TargetLowering::LowerFMA(SDValue Op, SelectionDAG &DAG) const {
7788+ SDValue OpA = Op->getOperand(0);
7789+ SDValue OpB = Op->getOperand(1);
7790+ SDValue OpC = Op->getOperand(2);
7791+ EVT VT = Op.getValueType();
7792+ SDLoc DL(Op);
7793+
7794+ assert(VT.isVector() && "Scalar fma lowering should be handled by patterns");
7795+
7796+ // Bail early if we're definitely not looking to merge FNEGs into the FMA.
7797+ if (VT != MVT::v8f16 && VT != MVT::v4f32 && VT != MVT::v2f64)
7798+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED);
7799+
7800+ if (OpC.getOpcode() != ISD::FNEG)
7801+ return useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable())
7802+ ? LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED)
7803+ : Op; // Fallback to NEON lowering.
7804+
7805+ // Convert FMA/FNEG nodes to SVE to enable the following patterns:
7806+ // fma(a, b, neg(c)) -> fnmls(a, b, c)
7807+ // fma(neg(a), b, neg(c)) -> fnmla(a, b, c)
7808+ // fma(a, neg(b), neg(c)) -> fnmla(a, b, c)
7809+ SDValue Pg = getPredicateForVector(DAG, DL, VT);
7810+ EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
7811+
7812+ auto ConvertToScalableFnegMt = [&](SDValue Op) {
7813+ if (Op.getOpcode() == ISD::FNEG)
7814+ Op = LowerToPredicatedOp(Op, DAG, AArch64ISD::FNEG_MERGE_PASSTHRU);
7815+ return convertToScalableVector(DAG, ContainerVT, Op);
7816+ };
7817+
7818+ OpA = ConvertToScalableFnegMt(OpA);
7819+ OpB = ConvertToScalableFnegMt(OpB);
7820+ OpC = ConvertToScalableFnegMt(OpC);
7821+
7822+ SDValue ScalableRes =
7823+ DAG.getNode(AArch64ISD::FMA_PRED, DL, ContainerVT, Pg, OpA, OpB, OpC);
7824+ return convertFromScalableVector(DAG, VT, ScalableRes);
7825+ }
7826+
77847827SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
77857828 SelectionDAG &DAG) const {
77867829 LLVM_DEBUG(dbgs() << "Custom lowering: ");
@@ -7857,7 +7900,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
78577900 case ISD::FMUL:
78587901 return LowerFMUL(Op, DAG);
78597902 case ISD::FMA:
7860- return LowerToPredicatedOp (Op, DAG, AArch64ISD::FMA_PRED );
7903+ return LowerFMA (Op, DAG);
78617904 case ISD::FDIV:
78627905 return LowerToPredicatedOp(Op, DAG, AArch64ISD::FDIV_PRED);
78637906 case ISD::FNEG:
0 commit comments