@@ -1872,6 +1872,15 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
18721872 setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv16i8, Custom);
18731873 }
18741874
1875+ if (EnablePartialReduceNodes && Subtarget->hasNEON() &&
1876+ Subtarget->hasDotProd()) {
1877+ setPartialReduceMLAAction(MVT::v2i64, MVT::v8i16, Legal);
1878+ setPartialReduceMLAAction(MVT::v4i32, MVT::v16i8, Legal);
1879+ setPartialReduceMLAAction(MVT::v4i32, MVT::v16i8, Legal);
1880+ setPartialReduceMLAAction(MVT::v2i32, MVT::v8i8, Legal);
1881+ setPartialReduceMLAAction(MVT::v2i64, MVT::v16i8, Custom);
1882+ }
1883+
18751884 // Handle operations that are only available in non-streaming SVE mode.
18761885 if (Subtarget->isSVEAvailable()) {
18771886 for (auto VT : {MVT::nxv16i8, MVT::nxv8i16, MVT::nxv4i32, MVT::nxv2i64,
@@ -7743,8 +7752,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
77437752 case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
77447753 return LowerVECTOR_HISTOGRAM(Op, DAG);
77457754 case ISD::PARTIAL_REDUCE_SMLA:
7746- case ISD::PARTIAL_REDUCE_UMLA:
7747- case ISD::PARTIAL_REDUCE_SMLA: {
7755+ case ISD::PARTIAL_REDUCE_UMLA: {
77487756 if (SDValue Result = LowerPARTIAL_REDUCE_MLA(Op, DAG))
77497757 return Result;
77507758 return expandPartialReduceMLA(Op.getNode(), DAG);
@@ -27575,8 +27583,7 @@ void AArch64TargetLowering::ReplaceNodeResults(
2757527583 return;
2757627584 case ISD::PARTIAL_REDUCE_UMLA:
2757727585 case ISD::PARTIAL_REDUCE_SMLA: {
27578- SDValue Res;
27579- if (Res = LowerPARTIAL_REDUCE_MLA(SDValue(N, 0), DAG))
27586+ if (SDValue Res = LowerPARTIAL_REDUCE_MLA(SDValue(N, 0), DAG))
2758027587 Results.push_back(Res);
2758127588 else
2758227589 Results.push_back(expandPartialReduceMLA(N, DAG));
@@ -29531,9 +29538,9 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
2953129538}
2953229539
2953329540/// If a PARTIAL_REDUCE_MLA node comes in with an accumulator-input type pairing
29534- /// of nxv2i64/nxv16i8 , we cannot directly lower it to a (u|s)dot. We can
29541+ /// of (nx)v2i64/(nx)v16i8, we cannot directly lower it to a (u|s)dot. We can
2953529542/// however still make use of the dot product instruction by instead
29536- /// accumulating over two steps: nxv16i8 -> nxv4i32 -> nxv2i64 .
29543+ /// accumulating over two steps: (nx)v16i8 -> (nx)v4i32 -> (nx)v2i64.
2953729544SDValue
2953829545AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
2953929546 SelectionDAG &DAG) const {
@@ -29568,12 +29575,27 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
2956829575 return DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
2956929576 }
2957029577
29571- unsigned LoOpcode = IsUnsigned ? AArch64ISD::UUNPKLO : AArch64ISD::SUNPKLO;
29572- unsigned HiOpcode = IsUnsigned ? AArch64ISD::UUNPKHI : AArch64ISD::SUNPKHI;
29573- auto Lo = DAG.getNode(LoOpcode, DL, ResultVT, DotNode);
29574- auto Hi = DAG.getNode(HiOpcode, DL, ResultVT, DotNode);
29575- auto Extended = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, Hi);
29576- return DAG.getNode(ISD::ADD, DL, ResultVT, Acc, Extended);
29578+ if (Scalable) {
29579+ unsigned LoOpcode = IsUnsigned ? AArch64ISD::UUNPKLO : AArch64ISD::SUNPKLO;
29580+ unsigned HiOpcode = IsUnsigned ? AArch64ISD::UUNPKHI : AArch64ISD::SUNPKHI;
29581+ auto Lo = DAG.getNode(LoOpcode, DL, ResultVT, DotNode);
29582+ auto Hi = DAG.getNode(HiOpcode, DL, ResultVT, DotNode);
29583+ auto Extended = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, Hi);
29584+ return DAG.getNode(ISD::ADD, DL, ResultVT, Acc, Extended);
29585+ }
29586+
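29587+ // Fixed-width case: split the v4i32 dot result into its two halves, extend
29588+ // each half to v2i64, and add both halves into the accumulator.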
29589+ auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
29590+ if (IsUnsigned) {
29591+ DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, MVT::v2i64);
29592+ DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, MVT::v2i64);
29593+ } else {
29594+ DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, MVT::v2i64);
29595+ DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, MVT::v2i64);
29596+ }
29597+ auto Lo = DAG.getNode(ISD::ADD, DL, MVT::v2i64, Acc, DotNodeLo);
29598+ return DAG.getNode(ISD::ADD, DL, MVT::v2i64, Lo, DotNodeHi);
2957729599}
2957829600
2957929601SDValue
0 commit comments