@@ -1935,6 +1935,18 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
19351935                         Custom);
19361936      setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::nxv2i64,
19371937                         Custom);
1938+ 
1939+       if (EnablePartialReduceNodes) {
1940+         static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
1941+                                           ISD::PARTIAL_REDUCE_UMLA};
1942+         // Must be lowered to SVE instructions.
1943+         setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v4i32, Custom);
1944+         setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v8i16, Custom);
1945+         setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v16i8, Custom);
1946+         setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v8i16, Custom);
1947+         setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v16i8, Custom);
1948+         setPartialReduceMLAAction(MLAOps, MVT::v8i16, MVT::v16i8, Custom);
1949+       }
19381950    }
19391951  }
19401952
@@ -2230,6 +2242,28 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
22302242  bool PreferNEON = VT.is64BitVector() || VT.is128BitVector();
22312243  bool PreferSVE = !PreferNEON && Subtarget->isSVEAvailable();
22322244
2245+   if (EnablePartialReduceNodes) {
2246+     static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
2247+                                       ISD::PARTIAL_REDUCE_UMLA};
2248+     unsigned NumElts = VT.getVectorNumElements();
2249+     if (VT.getVectorElementType() == MVT::i64) {
2250+       setPartialReduceMLAAction(MLAOps, VT,
2251+                                 MVT::getVectorVT(MVT::i8, NumElts * 8), Custom);
2252+       setPartialReduceMLAAction(
2253+           MLAOps, VT, MVT::getVectorVT(MVT::i16, NumElts * 4), Custom);
2254+       setPartialReduceMLAAction(
2255+           MLAOps, VT, MVT::getVectorVT(MVT::i32, NumElts * 2), Custom);
2256+     } else if (VT.getVectorElementType() == MVT::i32) {
2257+       setPartialReduceMLAAction(MLAOps, VT,
2258+                                 MVT::getVectorVT(MVT::i8, NumElts * 4), Custom);
2259+       setPartialReduceMLAAction(
2260+           MLAOps, VT, MVT::getVectorVT(MVT::i16, NumElts * 2), Custom);
2261+     } else if (VT.getVectorElementType() == MVT::i16) {
2262+       setPartialReduceMLAAction(MLAOps, VT,
2263+                                 MVT::getVectorVT(MVT::i8, NumElts * 2), Custom);
2264+     }
2265+   }
2266+ 
22332267  // Lower fixed length vector operations to scalable equivalents.
22342268  setOperationAction(ISD::ABDS, VT, Default);
22352269  setOperationAction(ISD::ABDU, VT, Default);
@@ -29251,50 +29285,61 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
// Lower ISD::PARTIAL_REDUCE_SMLA / ISD::PARTIAL_REDUCE_UMLA.
//
// Operands: (0) accumulator, (1)/(2) the two multiplicand vectors.
// Fixed-length inputs that should be handled with SVE are first wrapped in
// scalable container types; the result is converted back to the original
// fixed-length type before returning. Only the i8 -> i64 ("8-way") case needs
// explicit expansion here: it is emitted as an i32 dot product followed by
// widening adds (SVE2 UADDWB/UADDWT, or split + extend + ADD otherwise).
SDValue
AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
                                               SelectionDAG &DAG) const {
  SDLoc DL(Op);

  SDValue Acc = Op.getOperand(0);
  SDValue LHS = Op.getOperand(1);
  SDValue RHS = Op.getOperand(2);
  EVT ResultVT = Op.getValueType();
  // Remember the caller-visible type so we can convert back at the end if we
  // switch to scalable container types below.
  EVT OrigResultVT = ResultVT;
  EVT OpVT = LHS.getValueType();

  // Fixed-length vectors that the subtarget prefers to handle via SVE are
  // rewrapped in their scalable containers and the node is rebuilt on the
  // container types.
  bool ConvertToScalable =
      ResultVT.isFixedLengthVector() &&
      useSVEForFixedLengthVectorVT(ResultVT, /*OverrideNEON=*/true);

  if (ConvertToScalable) {
    ResultVT = getContainerForFixedLengthVector(DAG, ResultVT);
    OpVT = getContainerForFixedLengthVector(DAG, LHS.getValueType());
    Acc = convertToScalableVector(DAG, ResultVT, Acc);
    LHS = convertToScalableVector(DAG, OpVT, LHS);
    RHS = convertToScalableVector(DAG, OpVT, RHS);
    Op = DAG.getNode(Op.getOpcode(), DL, ResultVT, {Acc, LHS, RHS});
  }

  // Two-way and four-way partial reductions are supported by patterns.
  // We only need to handle the 8-way partial reduction.
  if (ResultVT.getScalarType() != MVT::i64 || OpVT.getScalarType() != MVT::i8)
    return ConvertToScalable ? convertFromScalableVector(DAG, OrigResultVT, Op)
                             : Op;

  // Emit the i8 multiply as a dot product accumulating into zero-initialized
  // i32 lanes; the i32 -> i64 accumulation is handled below.
  EVT DotVT = ResultVT.isScalableVector() ? MVT::nxv4i32 : MVT::v4i32;
  SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, DotVT,
                                DAG.getConstant(0, DL, DotVT), LHS, RHS);

  SDValue Res;
  bool IsUnsigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_UMLA;
  if (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable()) {
    // SVE2: widen-and-add the even (bottom) then odd (top) i32 lanes into the
    // i64 accumulator with a single pair of UADDWB/UADDWT (or signed forms).
    unsigned LoOpcode = IsUnsigned ? AArch64ISD::UADDWB : AArch64ISD::SADDWB;
    unsigned HiOpcode = IsUnsigned ? AArch64ISD::UADDWT : AArch64ISD::SADDWT;
    SDValue Lo = DAG.getNode(LoOpcode, DL, ResultVT, Acc, DotNode);
    Res = DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
  } else {
    // Fold (nx)v4i32 into (nx)v2i64
    // No widening-add instructions available: split the i32 dot result in
    // half, extend each half to i64 (sign/zero per opcode), and accumulate
    // with plain vector ADDs.
    auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
    if (IsUnsigned) {
      DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, ResultVT);
      DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, ResultVT);
    } else {
      DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, ResultVT);
      DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, ResultVT);
    }
    auto Lo = DAG.getNode(ISD::ADD, DL, ResultVT, Acc, DotNodeLo);
    Res = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, DotNodeHi);
  }

  // Unwrap back to the original fixed-length type if we converted above.
  return ConvertToScalable ? convertFromScalableVector(DAG, OrigResultVT, Res)
                           : Res;
}
2929929344
2930029345SDValue
0 commit comments