diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index a4ed62bb5715c..7a615e3a80622 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -115,6 +115,9 @@ def SDTPtrAddOp : SDTypeProfile<1, 2, [     // ptradd
 def SDTIntBinOp : SDTypeProfile<1, 2, [     // add, and, or, xor, udiv, etc.
   SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisInt<0>
 ]>;
+def SDTIntBinVPOp : SDTypeProfile<1, 4, [
+  SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisInt<0>, SDTCisVec<0>, SDTCVecEltisVT<3, i1>, SDTCisSameNumEltsAs<0, 3>, SDTCisInt<4>
+]>;
 def SDTIntShiftOp : SDTypeProfile<1, 2, [   // shl, sra, srl
   SDTCisSameAs<0, 1>, SDTCisInt<0>, SDTCisInt<2>
 ]>;
@@ -423,6 +426,8 @@ def smullohi   : SDNode<"ISD::SMUL_LOHI" , SDTIntBinHiLoOp, [SDNPCommutative]>;
 def umullohi   : SDNode<"ISD::UMUL_LOHI" , SDTIntBinHiLoOp, [SDNPCommutative]>;
 def sdiv       : SDNode<"ISD::SDIV"      , SDTIntBinOp>;
 def udiv       : SDNode<"ISD::UDIV"      , SDTIntBinOp>;
+def vp_sdiv    : SDNode<"ISD::VP_SDIV"   , SDTIntBinVPOp>;
+def vp_udiv    : SDNode<"ISD::VP_UDIV"   , SDTIntBinVPOp>;
 def srem       : SDNode<"ISD::SREM"      , SDTIntBinOp>;
 def urem       : SDNode<"ISD::UREM"      , SDTIntBinOp>;
 def sdivrem    : SDNode<"ISD::SDIVREM"   , SDTIntBinHiLoOp>;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index aefbbe2534be2..1edbab78952ae 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1594,6 +1594,17 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::OR, VT, Custom);
     }
 
+    for (auto VT : {MVT::nxv4i32, MVT::nxv2i64}) {
+      setOperationAction(ISD::VP_SDIV, VT, Legal);
+      setOperationAction(ISD::VP_UDIV, VT, Legal);
+    }
+    // SVE doesn't have i8 and i16 DIV operations, so custom lower them to
+    // 32-bit operations.
+    for (auto VT : {MVT::nxv16i8, MVT::nxv8i16}) {
+      setOperationAction(ISD::VP_SDIV, VT, Custom);
+      setOperationAction(ISD::VP_UDIV, VT, Custom);
+    }
+
     // Illegal unpacked integer vector types.
     for (auto VT : {MVT::nxv8i8, MVT::nxv4i16, MVT::nxv2i32}) {
       setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
@@ -7462,6 +7473,9 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
   case ISD::SDIV:
   case ISD::UDIV:
     return LowerDIV(Op, DAG);
+  case ISD::VP_SDIV:
+  case ISD::VP_UDIV:
+    return LowerVP_DIV(Op, DAG);
   case ISD::SMIN:
   case ISD::UMIN:
   case ISD::SMAX:
@@ -15870,6 +15884,39 @@ SDValue AArch64TargetLowering::LowerDIV(SDValue Op, SelectionDAG &DAG) const {
   return DAG.getNode(AArch64ISD::UZP1, DL, VT, ResultLoCast, ResultHiCast);
 }
 
+SDValue AArch64TargetLowering::LowerVP_DIV(SDValue Op,
+                                           SelectionDAG &DAG) const {
+  EVT VT = Op.getValueType();
+  SDLoc DL(Op);
+  bool Signed = Op.getOpcode() == ISD::VP_SDIV;
+
+  // SVE doesn't have i8 and i16 DIV operations; widen them to 32-bit
+  // operations, and truncate the result.
+  EVT WidenedVT;
+  if (VT == MVT::nxv16i8)
+    WidenedVT = MVT::nxv8i16;
+  else if (VT == MVT::nxv8i16)
+    WidenedVT = MVT::nxv4i32;
+  else
+    llvm_unreachable("Unexpected Custom DIV operation");
+
+  auto [MaskLo, MaskHi] = DAG.SplitVector(Op.getOperand(2), DL);
+  auto [EVLLo, EVLHi] = DAG.SplitEVL(Op.getOperand(3), WidenedVT, DL);
+  unsigned UnpkLo = Signed ? AArch64ISD::SUNPKLO : AArch64ISD::UUNPKLO;
+  unsigned UnpkHi = Signed ? AArch64ISD::SUNPKHI : AArch64ISD::UUNPKHI;
+  SDValue Op0Lo = DAG.getNode(UnpkLo, DL, WidenedVT, Op.getOperand(0));
+  SDValue Op1Lo = DAG.getNode(UnpkLo, DL, WidenedVT, Op.getOperand(1));
+  SDValue Op0Hi = DAG.getNode(UnpkHi, DL, WidenedVT, Op.getOperand(0));
+  SDValue Op1Hi = DAG.getNode(UnpkHi, DL, WidenedVT, Op.getOperand(1));
+  SDValue ResultLo =
+      DAG.getNode(Op.getOpcode(), DL, WidenedVT, Op0Lo, Op1Lo, MaskLo, EVLLo);
+  SDValue ResultHi =
+      DAG.getNode(Op.getOpcode(), DL, WidenedVT, Op0Hi, Op1Hi, MaskHi, EVLHi);
+  SDValue ResultLoCast = DAG.getNode(AArch64ISD::NVCAST, DL, VT, ResultLo);
+  SDValue ResultHiCast = DAG.getNode(AArch64ISD::NVCAST, DL, VT, ResultHi);
+  return DAG.getNode(AArch64ISD::UZP1, DL, VT, ResultLoCast, ResultHiCast);
+}
+
 bool AArch64TargetLowering::shouldExpandBuildVectorWithShuffles(
     EVT VT, unsigned DefinedValues) const {
   if (!Subtarget->isNeonAvailable())
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 78d6a507b80d3..0251ab594e488 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -706,6 +706,7 @@ class AArch64TargetLowering : public TargetLowering {
   SDValue LowerPARTIAL_REDUCE_MLA(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerGET_ACTIVE_LANE_MASK(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerDIV(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerVP_DIV(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerMUL(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerVectorSRA_SRL_SHL(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerShiftParts(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 509dd8b73a017..ee17fef3a3b3a 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -700,8 +700,8 @@ let Predicates = [HasSVE_or_SME] in {
   defm SDIVR_ZPmZ : sve_int_bin_pred_arit_2_div<0b110, "sdivr", "SDIVR_ZPZZ", int_aarch64_sve_sdivr, DestructiveBinaryCommWithRev, "SDIV_ZPmZ", /*isReverseInstr*/ 1>;
   defm UDIVR_ZPmZ : sve_int_bin_pred_arit_2_div<0b111, "udivr", "UDIVR_ZPZZ", int_aarch64_sve_udivr, DestructiveBinaryCommWithRev, "UDIV_ZPmZ", /*isReverseInstr*/ 1>;
 
-  defm SDIV_ZPZZ  : sve_int_bin_pred_sd<AArch64sdiv_p>;
-  defm UDIV_ZPZZ  : sve_int_bin_pred_sd<AArch64udiv_p>;
+  defm SDIV_ZPZZ  : sve_int_bin_pred_sd<AArch64sdiv_p, vp_sdiv>;
+  defm UDIV_ZPZZ  : sve_int_bin_pred_sd<AArch64udiv_p, vp_udiv>;
 
   defm SDOT_ZZZ : sve_intx_dot<0b0, "sdot", AArch64sdot>;
   defm UDOT_ZZZ : sve_intx_dot<0b1, "udot", AArch64udot>;
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 9c96fdd427814..ea6cf1a7e21d1 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -156,6 +156,17 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
 
   bool isVScaleKnownToBeAPowerOfTwo() const override { return true; }
 
+  TargetTransformInfo::VPLegalization
+  getVPLegalizationStrategy(const VPIntrinsic &PI) const override {
+    using VPLegalization = TargetTransformInfo::VPLegalization;
+    switch (PI.getIntrinsicID()) {
+    case Intrinsic::vp_sdiv:
+    case Intrinsic::vp_udiv:
+      return VPLegalization(VPLegalization::Discard, VPLegalization::Legal);
+    }
+    return BaseT::getVPLegalizationStrategy(PI);
+  }
+
   bool shouldMaximizeVectorBandwidth(
       TargetTransformInfo::RegisterKind K) const override;
 
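For context (not generated by this patch): returning VPLegalization(Discard, Legal) from the hook above tells the generic ExpandVectorPredication pass that vp.sdiv/vp.udiv themselves stay legal for SVE and that the explicit vector length does not need to survive as a separate operand. Because integer division is not speculatable, the pass is expected to fold the EVL into the mask rather than drop it outright, which is where the whilelo/and sequences in the new test below come from. A minimal hand-written sketch of the IR shape after that fold, assuming the usual get.active.lane.mask expansion (function and value names are illustrative only):

define <vscale x 4 x i32> @vp_sdiv_after_evl_fold(<vscale x 4 x i32> %x, <vscale x 4 x i32> %y,
                                                  <vscale x 4 x i1> %mask, i32 %evl) {
  ; The EVL is turned into an active-lane mask and combined with the
  ; original mask (whilelo + and in the generated SVE code)...
  %lanemask = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i32(i32 0, i32 %evl)
  %m = and <vscale x 4 x i1> %lanemask, %mask
  ; ...and the surviving vp.sdiv uses the full vector length as its EVL,
  ; so the new SVE patterns can ignore that operand.
  %vscale = call i32 @llvm.vscale.i32()
  %vl = mul i32 %vscale, 4
  %r = call <vscale x 4 x i32> @llvm.vp.sdiv.nxv4i32(<vscale x 4 x i32> %x, <vscale x 4 x i32> %y, <vscale x 4 x i1> %m, i32 %vl)
  ret <vscale x 4 x i32> %r
}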
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
index a3a7d0f74e1bc..ada6a47590ed2 100644
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -9788,12 +9788,17 @@ multiclass sve_int_bin_pred_bhsd<SDPatternOperator op> {
 }
 
 // As sve_int_bin_pred but when only i32 and i64 vector types are required.
-multiclass sve_int_bin_pred_sd<SDPatternOperator op> {
+multiclass sve_int_bin_pred_sd<SDPatternOperator op, SDPatternOperator vp_op> {
   def _S_UNDEF : PredTwoOpPseudo<NAME # _S, ZPR32, FalseLanesUndef>;
   def _D_UNDEF : PredTwoOpPseudo<NAME # _D, ZPR64, FalseLanesUndef>;
 
   def : SVE_3_Op_Pat<nxv4i32, op, nxv4i1, nxv4i32, nxv4i32, !cast<Pseudo>(NAME # _S_UNDEF)>;
   def : SVE_3_Op_Pat<nxv2i64, op, nxv2i1, nxv2i64, nxv2i64, !cast<Pseudo>(NAME # _D_UNDEF)>;
+
+  def : Pat<(nxv4i32 (vp_op nxv4i32:$lhs, nxv4i32:$rhs, nxv4i1:$pred, (i32 srcvalue))),
+            (!cast<Pseudo>(NAME # _S_UNDEF) $pred, $lhs, $rhs)>;
+  def : Pat<(nxv2i64 (vp_op nxv2i64:$lhs, nxv2i64:$rhs, nxv2i1:$pred, (i32 srcvalue))),
+            (!cast<Pseudo>(NAME # _D_UNDEF) $pred, $lhs, $rhs)>;
 }
 
 // Predicated pseudo integer two operand instructions. Second operand is an
diff --git a/llvm/test/CodeGen/AArch64/sve-vp-div.ll b/llvm/test/CodeGen/AArch64/sve-vp-div.ll
new file mode 100644
index 0000000000000..920b4308ecfdd
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-vp-div.ll
@@ -0,0 +1,144 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64 -mattr=+sve < %s | FileCheck %s
+
+define <vscale x 2 x i64> @sdiv_evl_max(<vscale x 2 x i64> %x, <vscale x 2 x i64> %y, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: sdiv_evl_max:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sdiv z0.d, p0/m, z0.d, z1.d
+; CHECK-NEXT:    ret
+  %vscale = call i32 @llvm.vscale()
+  %evl = mul i32 %vscale, 2
+  %z = call <vscale x 2 x i64> @llvm.vp.sdiv(<vscale x 2 x i64> %x, <vscale x 2 x i64> %y, <vscale x 2 x i1> %mask, i32 %evl)
+  ret <vscale x 2 x i64> %z
+}
+
+define <vscale x 2 x i64> @sdiv_nxv2i64(<vscale x 2 x i64> %x, <vscale x 2 x i64> %y, <vscale x 2 x i1> %mask, i32 %evl) {
+; CHECK-LABEL: sdiv_nxv2i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    whilelo p1.d, wzr, w0
+; CHECK-NEXT:    and p0.b, p1/z, p1.b, p0.b
+; CHECK-NEXT:    sdiv z0.d, p0/m, z0.d, z1.d
+; CHECK-NEXT:    ret
+  %z = call <vscale x 2 x i64> @llvm.vp.sdiv(<vscale x 2 x i64> %x, <vscale x 2 x i64> %y, <vscale x 2 x i1> %mask, i32 %evl)
+  ret <vscale x 2 x i64> %z
+}
+
+define <vscale x 4 x i32> @sdiv_nxv4i32(<vscale x 4 x i32> %x, <vscale x 4 x i32> %y, <vscale x 4 x i1> %mask, i32 %evl) {
+; CHECK-LABEL: sdiv_nxv4i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    whilelo p1.s, wzr, w0
+; CHECK-NEXT:    and p0.b, p1/z, p1.b, p0.b
+; CHECK-NEXT:    sdiv z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    ret
+  %z = call <vscale x 4 x i32> @llvm.vp.sdiv(<vscale x 4 x i32> %x, <vscale x 4 x i32> %y, <vscale x 4 x i1> %mask, i32 %evl)
+  ret <vscale x 4 x i32> %z
+}
+
+define <vscale x 8 x i16> @sdiv_nxv8i16(<vscale x 8 x i16> %x, <vscale x 8 x i16> %y, <vscale x 8 x i1> %mask, i32 %evl) {
+; CHECK-LABEL: sdiv_nxv8i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    whilelo p1.h, wzr, w0
+; CHECK-NEXT:    sunpkhi z2.s, z1.h
+; CHECK-NEXT:    sunpkhi z3.s, z0.h
+; CHECK-NEXT:    sunpklo z1.s, z1.h
+; CHECK-NEXT:    sunpklo z0.s, z0.h
+; CHECK-NEXT:    and p0.b, p1/z, p1.b, p0.b
+; CHECK-NEXT:    punpkhi p1.h, p0.b
+; CHECK-NEXT:    punpklo p0.h, p0.b
+; CHECK-NEXT:    sdivr z2.s, p1/m, z2.s, z3.s
+; CHECK-NEXT:    sdiv z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    uzp1 z0.h, z0.h, z2.h
+; CHECK-NEXT:    ret
+  %z = call <vscale x 8 x i16> @llvm.vp.sdiv(<vscale x 8 x i16> %x, <vscale x 8 x i16> %y, <vscale x 8 x i1> %mask, i32 %evl)
+  ret <vscale x 8 x i16> %z
+}
+
+define <vscale x 16 x i8> @sdiv_nxv16i8(<vscale x 16 x i8> %x, <vscale x 16 x i8> %y, <vscale x 16 x i1> %mask, i32 %evl) {
+; CHECK-LABEL: sdiv_nxv16i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    whilelo p1.h, wzr, w0
+; CHECK-NEXT:    sunpkhi z2.s, z1.h
+; CHECK-NEXT:    sunpkhi z3.s, z0.h
+; CHECK-NEXT:    sunpklo z1.s, z1.h
+; CHECK-NEXT:    sunpklo z0.s, z0.h
+; CHECK-NEXT:    and p0.b, p1/z, p1.b, p0.b
+; CHECK-NEXT:    punpkhi p1.h, p0.b
+; CHECK-NEXT:    punpklo p0.h, p0.b
+; CHECK-NEXT:    sdivr z2.s, p1/m, z2.s, z3.s
+; CHECK-NEXT:    sdiv z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    uzp1 z0.h, z0.h, z2.h
+; CHECK-NEXT:    ret
+  %z = call <vscale x 16 x i8> @llvm.vp.sdiv(<vscale x 16 x i8> %x, <vscale x 16 x i8> %y, <vscale x 16 x i1> %mask, i32 %evl)
+  ret <vscale x 16 x i8> %z
+}
+
+define <vscale x 2 x i64> @udiv_evl_max(<vscale x 2 x i64> %x, <vscale x 2 x i64> %y, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: udiv_evl_max:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    udiv z0.d, p0/m, z0.d, z1.d
+; CHECK-NEXT:    ret
+  %vscale = call i32 @llvm.vscale()
+  %evl = mul i32 %vscale, 2
+  %z = call <vscale x 2 x i64> @llvm.vp.udiv(<vscale x 2 x i64> %x, <vscale x 2 x i64> %y, <vscale x 2 x i1> %mask, i32 %evl)
+  ret <vscale x 2 x i64> %z
+}
+
+define <vscale x 2 x i64> @udiv_nxv2i64(<vscale x 2 x i64> %x, <vscale x 2 x i64> %y, <vscale x 2 x i1> %mask, i32 %evl) {
+; CHECK-LABEL: udiv_nxv2i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    whilelo p1.d, wzr, w0
+; CHECK-NEXT:    and p0.b, p1/z, p1.b, p0.b
+; CHECK-NEXT:    udiv z0.d, p0/m, z0.d, z1.d
+; CHECK-NEXT:    ret
+  %z = call <vscale x 2 x i64> @llvm.vp.udiv(<vscale x 2 x i64> %x, <vscale x 2 x i64> %y, <vscale x 2 x i1> %mask, i32 %evl)
+  ret <vscale x 2 x i64> %z
+}
+
+define <vscale x 4 x i32> @udiv_nxv4i32(<vscale x 4 x i32> %x, <vscale x 4 x i32> %y, <vscale x 4 x i1> %mask, i32 %evl) {
+; CHECK-LABEL: udiv_nxv4i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    whilelo p1.s, wzr, w0
+; CHECK-NEXT:    and p0.b, p1/z, p1.b, p0.b
+; CHECK-NEXT:    udiv z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    ret
+  %z = call <vscale x 4 x i32> @llvm.vp.udiv(<vscale x 4 x i32> %x, <vscale x 4 x i32> %y, <vscale x 4 x i1> %mask, i32 %evl)
+  ret <vscale x 4 x i32> %z
+}
+
+define <vscale x 8 x i16> @udiv_nxv8i16(<vscale x 8 x i16> %x, <vscale x 8 x i16> %y, <vscale x 8 x i1> %mask, i32 %evl) {
+; CHECK-LABEL: udiv_nxv8i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    whilelo p1.h, wzr, w0
+; CHECK-NEXT:    uunpkhi z2.s, z1.h
+; CHECK-NEXT:    uunpkhi z3.s, z0.h
+; CHECK-NEXT:    uunpklo z1.s, z1.h
+; CHECK-NEXT:    uunpklo z0.s, z0.h
+; CHECK-NEXT:    and p0.b, p1/z, p1.b, p0.b
+; CHECK-NEXT:    punpkhi p1.h, p0.b
+; CHECK-NEXT:    punpklo p0.h, p0.b
+; CHECK-NEXT:    udivr z2.s, p1/m, z2.s, z3.s
+; CHECK-NEXT:    udiv z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    uzp1 z0.h, z0.h, z2.h
+; CHECK-NEXT:    ret
+  %z = call <vscale x 8 x i16> @llvm.vp.udiv(<vscale x 8 x i16> %x, <vscale x 8 x i16> %y, <vscale x 8 x i1> %mask, i32 %evl)
+  ret <vscale x 8 x i16> %z
+}
+
+define <vscale x 16 x i8> @udiv_nxv16i8(<vscale x 16 x i8> %x, <vscale x 16 x i8> %y, <vscale x 16 x i1> %mask, i32 %evl) {
+; CHECK-LABEL: udiv_nxv16i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    whilelo p1.h, wzr, w0
+; CHECK-NEXT:    uunpkhi z2.s, z1.h
+; CHECK-NEXT:    uunpkhi z3.s, z0.h
+; CHECK-NEXT:    uunpklo z1.s, z1.h
+; CHECK-NEXT:    uunpklo z0.s, z0.h
+; CHECK-NEXT:    and p0.b, p1/z, p1.b, p0.b
+; CHECK-NEXT:    punpkhi p1.h, p0.b
+; CHECK-NEXT:    punpklo p0.h, p0.b
+; CHECK-NEXT:    udivr z2.s, p1/m, z2.s, z3.s
+; CHECK-NEXT:    udiv z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    uzp1 z0.h, z0.h, z2.h
+; CHECK-NEXT:    ret
+  %z = call <vscale x 16 x i8> @llvm.vp.udiv(<vscale x 16 x i8> %x, <vscale x 16 x i8> %y, <vscale x 16 x i1> %mask, i32 %evl)
+  ret <vscale x 16 x i8> %z
+}
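For context, the custom lowering added in LowerVP_DIV amounts to the following transformation, expressed in IR rather than SDNodes: unpack each half of the vector to the next wider element type (sign- or zero-extending to match the divide), split the mask and EVL accordingly, divide the halves, then narrow and re-concatenate the results (UZP1 after a reinterpreting cast in the actual node sequence). A hand-written approximation for the signed <vscale x 8 x i16> case; the value names and the umin/usub.sat modelling of DAG.SplitEVL are illustrative assumptions, not code from the patch:

define <vscale x 8 x i16> @vp_sdiv_nxv8i16_widened_sketch(<vscale x 8 x i16> %x, <vscale x 8 x i16> %y,
                                                          <vscale x 8 x i1> %m, i32 %evl) {
  ; Split operands and mask into low/high halves (SUNPKLO/SUNPKHI and
  ; PUNPKLO/PUNPKHI in the generated code).
  %x.lo = call <vscale x 4 x i16> @llvm.vector.extract.nxv4i16.nxv8i16(<vscale x 8 x i16> %x, i64 0)
  %x.hi = call <vscale x 4 x i16> @llvm.vector.extract.nxv4i16.nxv8i16(<vscale x 8 x i16> %x, i64 4)
  %y.lo = call <vscale x 4 x i16> @llvm.vector.extract.nxv4i16.nxv8i16(<vscale x 8 x i16> %y, i64 0)
  %y.hi = call <vscale x 4 x i16> @llvm.vector.extract.nxv4i16.nxv8i16(<vscale x 8 x i16> %y, i64 4)
  %m.lo = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv8i1(<vscale x 8 x i1> %m, i64 0)
  %m.hi = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv8i1(<vscale x 8 x i1> %m, i64 4)
  ; Split the EVL between the two halves.
  %vscale = call i32 @llvm.vscale.i32()
  %half = mul i32 %vscale, 4
  %evl.lo = call i32 @llvm.umin.i32(i32 %evl, i32 %half)
  %evl.hi = call i32 @llvm.usub.sat.i32(i32 %evl, i32 %half)
  ; Widen to i32 and divide under the split predicates.
  %x.lo.w = sext <vscale x 4 x i16> %x.lo to <vscale x 4 x i32>
  %x.hi.w = sext <vscale x 4 x i16> %x.hi to <vscale x 4 x i32>
  %y.lo.w = sext <vscale x 4 x i16> %y.lo to <vscale x 4 x i32>
  %y.hi.w = sext <vscale x 4 x i16> %y.hi to <vscale x 4 x i32>
  %d.lo = call <vscale x 4 x i32> @llvm.vp.sdiv.nxv4i32(<vscale x 4 x i32> %x.lo.w, <vscale x 4 x i32> %y.lo.w, <vscale x 4 x i1> %m.lo, i32 %evl.lo)
  %d.hi = call <vscale x 4 x i32> @llvm.vp.sdiv.nxv4i32(<vscale x 4 x i32> %x.hi.w, <vscale x 4 x i32> %y.hi.w, <vscale x 4 x i1> %m.hi, i32 %evl.hi)
  ; Narrow the results and put the halves back together.
  %r.lo = trunc <vscale x 4 x i32> %d.lo to <vscale x 4 x i16>
  %r.hi = trunc <vscale x 4 x i32> %d.hi to <vscale x 4 x i16>
  %tmp = call <vscale x 8 x i16> @llvm.vector.insert.nxv8i16.nxv4i16(<vscale x 8 x i16> poison, <vscale x 4 x i16> %r.lo, i64 0)
  %r = call <vscale x 8 x i16> @llvm.vector.insert.nxv8i16.nxv4i16(<vscale x 8 x i16> %tmp, <vscale x 4 x i16> %r.hi, i64 4)
  ret <vscale x 8 x i16> %r
}

The <vscale x 16 x i8> case is expected to go through this path twice: it is first widened to <vscale x 8 x i16>, whose VP divides are themselves marked Custom and get widened again to <vscale x 4 x i32>.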