@@ -1935,6 +1935,18 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
19351935 Custom);
19361936 setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::nxv2i64,
19371937 Custom);
1938+
1939+ if (EnablePartialReduceNodes) {
1940+ static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
1941+ ISD::PARTIAL_REDUCE_UMLA};
1942+ // Must be lowered to SVE instructions.
1943+ setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v4i32, Custom);
1944+ setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v8i16, Custom);
1945+ setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v16i8, Custom);
1946+ setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v8i16, Custom);
1947+ setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v16i8, Custom);
1948+ setPartialReduceMLAAction(MLAOps, MVT::v8i16, MVT::v16i8, Custom);
1949+ }
19381950 }
19391951 }
19401952
@@ -2230,6 +2242,28 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
22302242 bool PreferNEON = VT.is64BitVector() || VT.is128BitVector();
22312243 bool PreferSVE = !PreferNEON && Subtarget->isSVEAvailable();
22322244
2245+ if (EnablePartialReduceNodes) {
2246+ static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
2247+ ISD::PARTIAL_REDUCE_UMLA};
2248+ unsigned NumElts = VT.getVectorNumElements();
2249+ if (VT.getVectorElementType() == MVT::i64) {
2250+ setPartialReduceMLAAction(MLAOps, VT,
2251+ MVT::getVectorVT(MVT::i8, NumElts * 8), Custom);
2252+ setPartialReduceMLAAction(
2253+ MLAOps, VT, MVT::getVectorVT(MVT::i16, NumElts * 4), Custom);
2254+ setPartialReduceMLAAction(
2255+ MLAOps, VT, MVT::getVectorVT(MVT::i32, NumElts * 2), Custom);
2256+ } else if (VT.getVectorElementType() == MVT::i32) {
2257+ setPartialReduceMLAAction(MLAOps, VT,
2258+ MVT::getVectorVT(MVT::i8, NumElts * 4), Custom);
2259+ setPartialReduceMLAAction(
2260+ MLAOps, VT, MVT::getVectorVT(MVT::i16, NumElts * 2), Custom);
2261+ } else if (VT.getVectorElementType() == MVT::i16) {
2262+ setPartialReduceMLAAction(MLAOps, VT,
2263+ MVT::getVectorVT(MVT::i8, NumElts * 2), Custom);
2264+ }
2265+ }
2266+
22332267 // Lower fixed length vector operations to scalable equivalents.
22342268 setOperationAction(ISD::ABDS, VT, Default);
22352269 setOperationAction(ISD::ABDU, VT, Default);
@@ -29229,50 +29263,61 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
2922929263SDValue
2923029264AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
2923129265 SelectionDAG &DAG) const {
29232- bool Scalable = Op.getValueType().isScalableVector();
29233-
29234- assert((!Scalable || Subtarget->isSVEorStreamingSVEAvailable()) &&
29235- "SVE or StreamingSVE must be available when using scalable vectors.");
29236- assert((Scalable || Subtarget->hasDotProd()) &&
29237- "Dotprod must be available when targeting NEON dot product "
29238- "instructions.");
29239-
2924029266 SDLoc DL(Op);
2924129267
2924229268 SDValue Acc = Op.getOperand(0);
2924329269 SDValue LHS = Op.getOperand(1);
2924429270 SDValue RHS = Op.getOperand(2);
2924529271 EVT ResultVT = Op.getValueType();
29272+ EVT OrigResultVT = ResultVT;
29273+ EVT OpVT = LHS.getValueType();
2924629274
29247- assert((Scalable && ResultVT == MVT::nxv2i64 &&
29248- LHS.getValueType() == MVT::nxv16i8) ||
29249- (!Scalable && ResultVT == MVT::v2i64 &&
29250- LHS.getValueType() == MVT::v16i8));
29275+ bool ConvertToScalable =
29276+ ResultVT.isFixedLengthVector() &&
29277+ useSVEForFixedLengthVectorVT(ResultVT, /*OverrideNEON=*/true);
2925129278
29252- EVT DotVT = Scalable ? MVT::nxv4i32 : MVT::v4i32;
29279+ if (ConvertToScalable) {
29280+ ResultVT = getContainerForFixedLengthVector(DAG, ResultVT);
29281+ OpVT = getContainerForFixedLengthVector(DAG, LHS.getValueType());
29282+ Acc = convertToScalableVector(DAG, ResultVT, Acc);
29283+ LHS = convertToScalableVector(DAG, OpVT, LHS);
29284+ RHS = convertToScalableVector(DAG, OpVT, RHS);
29285+ Op = DAG.getNode(Op.getOpcode(), DL, ResultVT, {Acc, LHS, RHS});
29286+ }
29287+
29288+ // Two-way and four-way partial reductions are supported by patterns.
29289+ // We only need to handle the 8-way partial reduction.
29290+ if (ResultVT.getScalarType() != MVT::i64 || OpVT.getScalarType() != MVT::i8)
29291+ return ConvertToScalable ? convertFromScalableVector(DAG, OrigResultVT, Op)
29292+ : Op;
29293+
29294+ EVT DotVT = ResultVT.isScalableVector() ? MVT::nxv4i32 : MVT::v4i32;
2925329295 SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, DotVT,
2925429296 DAG.getConstant(0, DL, DotVT), LHS, RHS);
2925529297
29298+ SDValue Res;
2925629299 bool IsUnsigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_UMLA;
29257- if (Scalable &&
29258- (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable())) {
29300+ if (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable()) {
2925929301 unsigned LoOpcode = IsUnsigned ? AArch64ISD::UADDWB : AArch64ISD::SADDWB;
2926029302 unsigned HiOpcode = IsUnsigned ? AArch64ISD::UADDWT : AArch64ISD::SADDWT;
2926129303 SDValue Lo = DAG.getNode(LoOpcode, DL, ResultVT, Acc, DotNode);
29262- return DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
29263- }
29264-
29265- // Fold (nx)v4i32 into (nx)v2i64
29266- auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
29267- if (IsUnsigned) {
29268- DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, ResultVT);
29269- DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, ResultVT);
29304+ Res = DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
2927029305 } else {
29271- DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, ResultVT);
29272- DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, ResultVT);
29306+ // Fold (nx)v4i32 into (nx)v2i64
29307+ auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
29308+ if (IsUnsigned) {
29309+ DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, ResultVT);
29310+ DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, ResultVT);
29311+ } else {
29312+ DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, ResultVT);
29313+ DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, ResultVT);
29314+ }
29315+ auto Lo = DAG.getNode(ISD::ADD, DL, ResultVT, Acc, DotNodeLo);
29316+ Res = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, DotNodeHi);
2927329317 }
29274- auto Lo = DAG.getNode(ISD::ADD, DL, ResultVT, Acc, DotNodeLo);
29275- return DAG.getNode(ISD::ADD, DL, ResultVT, Lo, DotNodeHi);
29318+
29319+ return ConvertToScalable ? convertFromScalableVector(DAG, OrigResultVT, Res)
29320+ : Res;
2927629321}
2927729322
2927829323SDValue
0 commit comments