diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 9d1c3d4eddc88..ff577f238d183 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -2768,6 +2768,10 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const { MAKE_CASE(AArch64ISD::UADDV) MAKE_CASE(AArch64ISD::UADDLV) MAKE_CASE(AArch64ISD::SADDLV) + MAKE_CASE(AArch64ISD::SADDWT) + MAKE_CASE(AArch64ISD::SADDWB) + MAKE_CASE(AArch64ISD::UADDWT) + MAKE_CASE(AArch64ISD::UADDWB) MAKE_CASE(AArch64ISD::SDOT) MAKE_CASE(AArch64ISD::UDOT) MAKE_CASE(AArch64ISD::USDOT) @@ -21907,17 +21911,10 @@ SDValue tryLowerPartialReductionToWideAdd(SDNode *N, return SDValue(); bool InputIsSigned = ExtInputOpcode == ISD::SIGN_EXTEND; - auto BottomIntrinsic = InputIsSigned ? Intrinsic::aarch64_sve_saddwb - : Intrinsic::aarch64_sve_uaddwb; - auto TopIntrinsic = InputIsSigned ? Intrinsic::aarch64_sve_saddwt - : Intrinsic::aarch64_sve_uaddwt; - - auto BottomID = DAG.getTargetConstant(BottomIntrinsic, DL, AccElemVT); - auto BottomNode = - DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, AccVT, BottomID, Acc, Input); - auto TopID = DAG.getTargetConstant(TopIntrinsic, DL, AccElemVT); - return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, AccVT, TopID, BottomNode, - Input); + auto BottomOpcode = InputIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB; + auto TopOpcode = InputIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT; + auto BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, Input); + return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, Input); } static SDValue performIntrinsicCombine(SDNode *N, @@ -22097,6 +22094,18 @@ static SDValue performIntrinsicCombine(SDNode *N, case Intrinsic::aarch64_sve_bic_u: return DAG.getNode(AArch64ISD::BIC, SDLoc(N), N->getValueType(0), N->getOperand(2), N->getOperand(3)); + case Intrinsic::aarch64_sve_saddwb: + return DAG.getNode(AArch64ISD::SADDWB, SDLoc(N), N->getValueType(0), + N->getOperand(1), N->getOperand(2)); + case Intrinsic::aarch64_sve_saddwt: + return DAG.getNode(AArch64ISD::SADDWT, SDLoc(N), N->getValueType(0), + N->getOperand(1), N->getOperand(2)); + case Intrinsic::aarch64_sve_uaddwb: + return DAG.getNode(AArch64ISD::UADDWB, SDLoc(N), N->getValueType(0), + N->getOperand(1), N->getOperand(2)); + case Intrinsic::aarch64_sve_uaddwt: + return DAG.getNode(AArch64ISD::UADDWT, SDLoc(N), N->getValueType(0), + N->getOperand(1), N->getOperand(2)); case Intrinsic::aarch64_sve_eor_u: return DAG.getNode(ISD::XOR, SDLoc(N), N->getValueType(0), N->getOperand(2), N->getOperand(3)); @@ -29702,6 +29711,27 @@ void AArch64TargetLowering::verifyTargetSDNode(const SDNode *N) const { switch (N->getOpcode()) { default: break; + case AArch64ISD::SADDWT: + case AArch64ISD::SADDWB: + case AArch64ISD::UADDWT: + case AArch64ISD::UADDWB: { + assert(N->getNumValues() == 1 && "Expected one result!"); + assert(N->getNumOperands() == 2 && "Expected two operands!"); + EVT VT = N->getValueType(0); + EVT Op0VT = N->getOperand(0).getValueType(); + EVT Op1VT = N->getOperand(1).getValueType(); + assert(VT.isVector() && Op0VT.isVector() && Op1VT.isVector() && + VT.isInteger() && Op0VT.isInteger() && Op1VT.isInteger() && + "Expected integer vectors!"); + assert(VT == Op0VT && + "Expected result and first input to have the same type!"); + assert(Op0VT.getSizeInBits() == Op1VT.getSizeInBits() && + "Expected vectors of equal size!"); + assert(Op0VT.getVectorElementCount() * 2 == Op1VT.getVectorElementCount() && + "Expected result vector and first input vector to have half the " + "lanes of the second input vector!"); + break; + } case AArch64ISD::SUNPKLO: case AArch64ISD::SUNPKHI: case AArch64ISD::UUNPKLO: diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h index d11da64d3f84e..176ad57a6ed72 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h @@ -273,6 +273,12 @@ enum NodeType : unsigned { UADDLV, SADDLV, + // Wide adds + SADDWT, + SADDWB, + UADDWT, + UADDWB, + // Add Pairwise of two vectors ADDP, // Add Long Pairwise diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td index 4f146b3ee59e9..659d5e8b414ce 100644 --- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td @@ -430,6 +430,13 @@ def SDT_AArch64Arith_Unpred : SDTypeProfile<1, 2, [ def AArch64bic_node : SDNode<"AArch64ISD::BIC", SDT_AArch64Arith_Unpred>; +def SDT_AArch64addw : SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisVec<1>]>; + +def AArch64saddwt : SDNode<"AArch64ISD::SADDWT", SDT_AArch64addw>; +def AArch64saddwb : SDNode<"AArch64ISD::SADDWB", SDT_AArch64addw>; +def AArch64uaddwt : SDNode<"AArch64ISD::UADDWT", SDT_AArch64addw>; +def AArch64uaddwb : SDNode<"AArch64ISD::UADDWB", SDT_AArch64addw>; + def AArch64bic : PatFrags<(ops node:$op1, node:$op2), [(and node:$op1, (xor node:$op2, (splat_vector (i32 -1)))), (and node:$op1, (xor node:$op2, (splat_vector (i64 -1)))), @@ -3674,10 +3681,10 @@ let Predicates = [HasSVE2orSME] in { defm UABDLT_ZZZ : sve2_wide_int_arith_long<0b01111, "uabdlt", int_aarch64_sve_uabdlt>; // SVE2 integer add/subtract wide - defm SADDWB_ZZZ : sve2_wide_int_arith_wide<0b000, "saddwb", int_aarch64_sve_saddwb>; - defm SADDWT_ZZZ : sve2_wide_int_arith_wide<0b001, "saddwt", int_aarch64_sve_saddwt>; - defm UADDWB_ZZZ : sve2_wide_int_arith_wide<0b010, "uaddwb", int_aarch64_sve_uaddwb>; - defm UADDWT_ZZZ : sve2_wide_int_arith_wide<0b011, "uaddwt", int_aarch64_sve_uaddwt>; + defm SADDWB_ZZZ : sve2_wide_int_arith_wide<0b000, "saddwb", AArch64saddwb>; + defm SADDWT_ZZZ : sve2_wide_int_arith_wide<0b001, "saddwt", AArch64saddwt>; + defm UADDWB_ZZZ : sve2_wide_int_arith_wide<0b010, "uaddwb", AArch64uaddwb>; + defm UADDWT_ZZZ : sve2_wide_int_arith_wide<0b011, "uaddwt", AArch64uaddwt>; defm SSUBWB_ZZZ : sve2_wide_int_arith_wide<0b100, "ssubwb", int_aarch64_sve_ssubwb>; defm SSUBWT_ZZZ : sve2_wide_int_arith_wide<0b101, "ssubwt", int_aarch64_sve_ssubwt>; defm USUBWB_ZZZ : sve2_wide_int_arith_wide<0b110, "usubwb", int_aarch64_sve_usubwb>;