@@ -1935,6 +1935,18 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
19351935 Custom);
19361936 setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::nxv2i64,
19371937 Custom);
1938+
1939+ if (EnablePartialReduceNodes) {
1940+ static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
1941+ ISD::PARTIAL_REDUCE_UMLA};
1942+ // Must be lowered to SVE instructions.
1943+ setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v4i32, Custom);
1944+ setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v8i16, Custom);
1945+ setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v16i8, Custom);
1946+ setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v8i16, Custom);
1947+ setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v16i8, Custom);
1948+ setPartialReduceMLAAction(MLAOps, MVT::v8i16, MVT::v16i8, Custom);
1949+ }
19381950 }
19391951 }
19401952
@@ -2230,6 +2242,28 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
22302242 bool PreferNEON = VT.is64BitVector() || VT.is128BitVector();
22312243 bool PreferSVE = !PreferNEON && Subtarget->isSVEAvailable();
22322244
2245+ if (EnablePartialReduceNodes) {
2246+ static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
2247+ ISD::PARTIAL_REDUCE_UMLA};
2248+ unsigned NumElts = VT.getVectorNumElements();
2249+ if (VT.getVectorElementType() == MVT::i64) {
2250+ setPartialReduceMLAAction(MLAOps, VT,
2251+ MVT::getVectorVT(MVT::i8, NumElts * 8), Custom);
2252+ setPartialReduceMLAAction(
2253+ MLAOps, VT, MVT::getVectorVT(MVT::i16, NumElts * 4), Custom);
2254+ setPartialReduceMLAAction(
2255+ MLAOps, VT, MVT::getVectorVT(MVT::i32, NumElts * 2), Custom);
2256+ } else if (VT.getVectorElementType() == MVT::i32) {
2257+ setPartialReduceMLAAction(MLAOps, VT,
2258+ MVT::getVectorVT(MVT::i8, NumElts * 4), Custom);
2259+ setPartialReduceMLAAction(
2260+ MLAOps, VT, MVT::getVectorVT(MVT::i16, NumElts * 2), Custom);
2261+ } else if (VT.getVectorElementType() == MVT::i16) {
2262+ setPartialReduceMLAAction(MLAOps, VT,
2263+ MVT::getVectorVT(MVT::i8, NumElts * 2), Custom);
2264+ }
2265+ }
2266+
22332267 // Lower fixed length vector operations to scalable equivalents.
22342268 setOperationAction(ISD::ABDS, VT, Default);
22352269 setOperationAction(ISD::ABDU, VT, Default);
@@ -29251,50 +29285,61 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
2925129285SDValue
2925229286AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
2925329287 SelectionDAG &DAG) const {
29254- bool Scalable = Op.getValueType().isScalableVector();
29255-
29256- assert((!Scalable || Subtarget->isSVEorStreamingSVEAvailable()) &&
29257- "SVE or StreamingSVE must be available when using scalable vectors.");
29258- assert((Scalable || Subtarget->hasDotProd()) &&
29259- "Dotprod must be available when targeting NEON dot product "
29260- "instructions.");
29261-
2926229288 SDLoc DL(Op);
2926329289
2926429290 SDValue Acc = Op.getOperand(0);
2926529291 SDValue LHS = Op.getOperand(1);
2926629292 SDValue RHS = Op.getOperand(2);
2926729293 EVT ResultVT = Op.getValueType();
29294+ EVT OrigResultVT = ResultVT;
29295+ EVT OpVT = LHS.getValueType();
2926829296
29269- assert((Scalable && ResultVT == MVT::nxv2i64 &&
29270- LHS.getValueType() == MVT::nxv16i8) ||
29271- (!Scalable && ResultVT == MVT::v2i64 &&
29272- LHS.getValueType() == MVT::v16i8));
29297+ bool ConvertToScalable =
29298+ ResultVT.isFixedLengthVector() &&
29299+ useSVEForFixedLengthVectorVT(ResultVT, /*OverrideNEON=*/true);
2927329300
29274- EVT DotVT = Scalable ? MVT::nxv4i32 : MVT::v4i32;
29301+ if (ConvertToScalable) {
29302+ ResultVT = getContainerForFixedLengthVector(DAG, ResultVT);
29303+ OpVT = getContainerForFixedLengthVector(DAG, LHS.getValueType());
29304+ Acc = convertToScalableVector(DAG, ResultVT, Acc);
29305+ LHS = convertToScalableVector(DAG, OpVT, LHS);
29306+ RHS = convertToScalableVector(DAG, OpVT, RHS);
29307+ Op = DAG.getNode(Op.getOpcode(), DL, ResultVT, {Acc, LHS, RHS});
29308+ }
29309+
29310+ // Two-way and four-way partial reductions are supported by patterns.
29311+ // We only need to handle the 8-way partial reduction.
29312+ if (ResultVT.getScalarType() != MVT::i64 || OpVT.getScalarType() != MVT::i8)
29313+ return ConvertToScalable ? convertFromScalableVector(DAG, OrigResultVT, Op)
29314+ : Op;
29315+
29316+ EVT DotVT = ResultVT.isScalableVector() ? MVT::nxv4i32 : MVT::v4i32;
2927529317 SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, DotVT,
2927629318 DAG.getConstant(0, DL, DotVT), LHS, RHS);
2927729319
29320+ SDValue Res;
2927829321 bool IsUnsigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_UMLA;
29279- if (Scalable &&
29280- (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable())) {
29322+ if (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable()) {
2928129323 unsigned LoOpcode = IsUnsigned ? AArch64ISD::UADDWB : AArch64ISD::SADDWB;
2928229324 unsigned HiOpcode = IsUnsigned ? AArch64ISD::UADDWT : AArch64ISD::SADDWT;
2928329325 SDValue Lo = DAG.getNode(LoOpcode, DL, ResultVT, Acc, DotNode);
29284- return DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
29285- }
29286-
29287- // Fold (nx)v4i32 into (nx)v2i64
29288- auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
29289- if (IsUnsigned) {
29290- DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, ResultVT);
29291- DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, ResultVT);
29326+ Res = DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
2929229327 } else {
29293- DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, ResultVT);
29294- DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, ResultVT);
29328+ // Fold (nx)v4i32 into (nx)v2i64
29329+ auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
29330+ if (IsUnsigned) {
29331+ DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, ResultVT);
29332+ DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, ResultVT);
29333+ } else {
29334+ DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, ResultVT);
29335+ DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, ResultVT);
29336+ }
29337+ auto Lo = DAG.getNode(ISD::ADD, DL, ResultVT, Acc, DotNodeLo);
29338+ Res = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, DotNodeHi);
2929529339 }
29296- auto Lo = DAG.getNode(ISD::ADD, DL, ResultVT, Acc, DotNodeLo);
29297- return DAG.getNode(ISD::ADD, DL, ResultVT, Lo, DotNodeHi);
29340+
29341+ return ConvertToScalable ? convertFromScalableVector(DAG, OrigResultVT, Res)
29342+ : Res;
2929829343}
2929929344
2930029345SDValue
0 commit comments