@@ -7741,12 +7741,13 @@ SDValue AArch64TargetLowering::LowerFMA(SDValue Op, SelectionDAG &DAG) const {
77417741 SDLoc DL(Op);
77427742
77437743 // Bail early if we're definitely not looking to merge FNEGs into the FMA.
7744- if (!VT.isFixedLengthVector() || OpC.getOpcode() != ISD::FNEG) {
7745- if (VT.isScalableVector() || VT.getScalarType() == MVT::bf16 ||
7746- useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable()))
7747- return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED);
7748- return Op; // Fallback to NEON lowering.
7749- }
7744+ if (VT != MVT::v8f16 && VT != MVT::v4f32 && VT != MVT::v2f64)
7745+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED);
7746+
7747+ if (OpC.getOpcode() != ISD::FNEG)
7748+ return useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable())
7749+ ? LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED)
7750+ : Op; // Fallback to NEON lowering.
77507751
77517752 // Convert FMA/FNEG nodes to SVE to enable the following patterns:
77527753 // fma(a, b, neg(c)) -> fnmls(a, b, c)
@@ -7755,17 +7756,15 @@ SDValue AArch64TargetLowering::LowerFMA(SDValue Op, SelectionDAG &DAG) const {
77557756 SDValue Pg = getPredicateForVector(DAG, DL, VT);
77567757 EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
77577758
7758- // Reuse `LowerToPredicatedOp` but drop the subsequent `extract_subvector`
7759- OpA = OpA.getOpcode() == ISD::FNEG
7760- ? LowerToPredicatedOp(OpA, DAG, AArch64ISD::FNEG_MERGE_PASSTHRU)
7761- ->getOperand(0)
7762- : convertToScalableVector(DAG, ContainerVT, OpA);
7763- OpB = OpB.getOpcode() == ISD::FNEG
7764- ? LowerToPredicatedOp(OpB, DAG, AArch64ISD::FNEG_MERGE_PASSTHRU)
7765- ->getOperand(0)
7766- : convertToScalableVector(DAG, ContainerVT, OpB);
7767- OpC = LowerToPredicatedOp(OpC, DAG, AArch64ISD::FNEG_MERGE_PASSTHRU)
7768- ->getOperand(0);
7759+ auto ConvertToScalableFnegMt = [&](SDValue Op) {
7760+ if (Op.getOpcode() == ISD::FNEG)
7761+ Op = LowerToPredicatedOp(Op, DAG, AArch64ISD::FNEG_MERGE_PASSTHRU);
7762+ return convertToScalableVector(DAG, ContainerVT, Op);
7763+ };
7764+
7765+ OpA = ConvertToScalableFnegMt(OpA);
7766+ OpB = ConvertToScalableFnegMt(OpB);
7767+ OpC = ConvertToScalableFnegMt(OpC);
77697768
77707769 SDValue ScalableRes =
77717770 DAG.getNode(AArch64ISD::FMA_PRED, DL, ContainerVT, Pg, OpA, OpB, OpC);
0 commit comments