@@ -1456,6 +1456,13 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
14561456 // FADDP custom lowering
14571457 for (MVT VT : { MVT::v16f16, MVT::v8f32, MVT::v4f64 })
14581458 setOperationAction(ISD::FADD, VT, Custom);
1459+
1460+ if (EnablePartialReduceNodes && Subtarget->hasDotProd()) {
1461+ setPartialReduceMLAAction(MVT::v4i32, MVT::v16i8, Legal);
1462+ setPartialReduceMLAAction(MVT::v2i32, MVT::v8i8, Legal);
1463+ setPartialReduceMLAAction(MVT::v2i64, MVT::v16i8, Custom);
1464+ }
1465+
14591466 } else /* !isNeonAvailable */ {
14601467 for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
14611468 for (unsigned Op = 0; Op < ISD::BUILTIN_OP_END; ++Op)
@@ -29528,37 +29535,60 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
2952829535}
2952929536
2953029537/// If a PARTIAL_REDUCE_MLA node comes in with an accumulator-input type pairing
29531- /// of nxv2i64/nxv16i8 , we cannot directly lower it to a (u|s)dot. We can
29538+ /// of (nx)v2i64/(nx)v16i8 , we cannot directly lower it to a (u|s)dot. We can
2953229539/// however still make use of the dot product instruction by instead
29533- /// accumulating over two steps: nxv16i8 -> nxv4i32 -> nxv2i64.
29540+ /// accumulating over two steps: (nx)v16i8 -> (nx)v4i32 -> (nx)v2i64.
29541+ /// For scalable vectors with SVE2 or streaming SVE available, make use of the (U|S)ADDW(B|T) instructions, otherwise
29542+ /// the following pattern is emitted:
29543+ /// add(add(Acc, ext(EXTRACT_SUBVECTOR(N, 0))), ext(EXTRACT_SUBVECTOR(N,
29544+ /// NTy/2)))
2953429545SDValue
2953529546AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
2953629547 SelectionDAG &DAG) const {
29548+ bool Scalable = Op.getValueType().isScalableVector();
29549+
29550+ assert((!Scalable || Subtarget->isSVEorStreamingSVEAvailable()) &&
29551+ "SVE or StreamingSVE must be available when using scalable vectors.");
29552+ assert((Scalable || Subtarget->hasDotProd()) &&
29553+ "Dotprod must be available when targeting NEON dot product "
29554+ "instructions.");
29555+
2953729556 SDLoc DL(Op);
2953829557
2953929558 SDValue Acc = Op.getOperand(0);
2954029559 SDValue LHS = Op.getOperand(1);
2954129560 SDValue RHS = Op.getOperand(2);
2954229561 EVT ResultVT = Op.getValueType();
29543- assert(ResultVT == MVT::nxv2i64 && LHS.getValueType() == MVT::nxv16i8);
2954429562
29545- SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, MVT::nxv4i32,
29546- DAG.getConstant(0, DL, MVT::nxv4i32), LHS, RHS);
29563+ assert((Scalable && ResultVT == MVT::nxv2i64 &&
29564+ LHS.getValueType() == MVT::nxv16i8) ||
29565+ (!Scalable && ResultVT == MVT::v2i64 &&
29566+ LHS.getValueType() == MVT::v16i8));
29567+
29568+ EVT DotVT = Scalable ? MVT::nxv4i32 : MVT::v4i32;
29569+ SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, DotVT,
29570+ DAG.getConstant(0, DL, DotVT), LHS, RHS);
2954729571
2954829572 bool IsUnsigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_UMLA;
29549- if (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable()) {
29573+ if (Scalable &&
29574+ (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable())) {
2955029575 unsigned LoOpcode = IsUnsigned ? AArch64ISD::UADDWB : AArch64ISD::SADDWB;
2955129576 unsigned HiOpcode = IsUnsigned ? AArch64ISD::UADDWT : AArch64ISD::SADDWT;
2955229577 SDValue Lo = DAG.getNode(LoOpcode, DL, ResultVT, Acc, DotNode);
2955329578 return DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
2955429579 }
2955529580
29556- unsigned LoOpcode = IsUnsigned ? AArch64ISD::UUNPKLO : AArch64ISD::SUNPKLO;
29557- unsigned HiOpcode = IsUnsigned ? AArch64ISD::UUNPKHI : AArch64ISD::SUNPKHI;
29558- auto Lo = DAG.getNode(LoOpcode, DL, ResultVT, DotNode);
29559- auto Hi = DAG.getNode(HiOpcode, DL, ResultVT, DotNode);
29560- auto Extended = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, Hi);
29561- return DAG.getNode(ISD::ADD, DL, ResultVT, Acc, Extended);
29581+ // Fold (nx)v4i32 into (nx)v2i64: split the dot result into two halves, extend each half to ResultVT (zero-extend for PARTIAL_REDUCE_UMLA, sign-extend otherwise), then add the low half and the high half to Acc in turn.
29582+ auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
29583+ if (IsUnsigned) {
29584+ DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, ResultVT);
29585+ DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, ResultVT);
29586+ } else {
29587+ DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, ResultVT);
29588+ DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, ResultVT);
29589+ }
29590+ auto Lo = DAG.getNode(ISD::ADD, DL, ResultVT, Acc, DotNodeLo);
29591+ return DAG.getNode(ISD::ADD, DL, ResultVT, Lo, DotNodeHi);
2956229592}
2956329593
2956429594SDValue
0 commit comments