Skip to content

Commit 8a2f110

Browse files
committed
[AArch64] Support scalable vp.udiv/vp.sdiv with SVE
1 parent 98f4b77 commit 8a2f110

File tree

7 files changed

+216
-3
lines changed

7 files changed

+216
-3
lines changed

llvm/include/llvm/Target/TargetSelectionDAG.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,9 @@ def SDTPtrAddOp : SDTypeProfile<1, 2, [ // ptradd
115115
def SDTIntBinOp : SDTypeProfile<1, 2, [ // add, and, or, xor, udiv, etc.
116116
SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisInt<0>
117117
]>;
118+
def SDTIntBinVPOp : SDTypeProfile<1, 4, [
119+
SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisInt<0>, SDTCisVec<0>, SDTCVecEltisVT<3, i1>, SDTCisSameNumEltsAs<0, 3>, SDTCisInt<4>
120+
]>;
118121
def SDTIntShiftOp : SDTypeProfile<1, 2, [ // shl, sra, srl
119122
SDTCisSameAs<0, 1>, SDTCisInt<0>, SDTCisInt<2>
120123
]>;
@@ -423,6 +426,8 @@ def smullohi : SDNode<"ISD::SMUL_LOHI" , SDTIntBinHiLoOp, [SDNPCommutative]>;
423426
def umullohi : SDNode<"ISD::UMUL_LOHI" , SDTIntBinHiLoOp, [SDNPCommutative]>;
424427
def sdiv : SDNode<"ISD::SDIV" , SDTIntBinOp>;
425428
def udiv : SDNode<"ISD::UDIV" , SDTIntBinOp>;
429+
// Vector-predicated (masked + explicit-vector-length) integer divisions.
def vp_sdiv : SDNode<"ISD::VP_SDIV" , SDTIntBinVPOp>;
def vp_udiv : SDNode<"ISD::VP_UDIV" , SDTIntBinVPOp>;
426431
def srem : SDNode<"ISD::SREM" , SDTIntBinOp>;
427432
def urem : SDNode<"ISD::UREM" , SDTIntBinOp>;
428433
def sdivrem : SDNode<"ISD::SDIVREM" , SDTIntBinHiLoOp>;

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1594,6 +1594,17 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
15941594
setOperationAction(ISD::OR, VT, Custom);
15951595
}
15961596

1597+
for (auto VT : {MVT::nxv4i32, MVT::nxv2i64}) {
1598+
setOperationAction(ISD::VP_SDIV, VT, Legal);
1599+
setOperationAction(ISD::VP_UDIV, VT, Legal);
1600+
}
1601+
// SVE doesn't have i8 and i16 DIV operations, so custom lower them to
1602+
// 32-bit operations.
1603+
for (auto VT : {MVT::nxv16i8, MVT::nxv8i16}) {
1604+
setOperationAction(ISD::VP_SDIV, VT, Custom);
1605+
setOperationAction(ISD::VP_UDIV, VT, Custom);
1606+
}
1607+
15971608
// Illegal unpacked integer vector types.
15981609
for (auto VT : {MVT::nxv8i8, MVT::nxv4i16, MVT::nxv2i32}) {
15991610
setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
@@ -7462,6 +7473,9 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
74627473
case ISD::SDIV:
74637474
case ISD::UDIV:
74647475
return LowerDIV(Op, DAG);
7476+
case ISD::VP_SDIV:
7477+
case ISD::VP_UDIV:
7478+
return LowerVP_DIV(Op, DAG);
74657479
case ISD::SMIN:
74667480
case ISD::UMIN:
74677481
case ISD::SMAX:
@@ -15870,6 +15884,39 @@ SDValue AArch64TargetLowering::LowerDIV(SDValue Op, SelectionDAG &DAG) const {
1587015884
return DAG.getNode(AArch64ISD::UZP1, DL, VT, ResultLoCast, ResultHiCast);
1587115885
}
1587215886

15887+
/// Custom-lower VP_SDIV/VP_UDIV for element types SVE cannot divide directly.
/// SVE only provides 32- and 64-bit predicated divides, so nxv16i8/nxv8i16
/// divisions are split: sign/zero-unpack each operand half into the next
/// wider element type, divide the halves as narrower VP nodes, and zip the
/// (implicitly truncated) results back together with UZP1. The nxv16i8 case
/// re-enters this function once, via the Custom action on the intermediate
/// nxv8i16 VP_SDIV/VP_UDIV nodes.
SDValue AArch64TargetLowering::LowerVP_DIV(SDValue Op,
                                           SelectionDAG &DAG) const {
  EVT VT = Op.getValueType();
  SDLoc DL(Op);
  bool Signed = Op.getOpcode() == ISD::VP_SDIV;

  // SVE doesn't have i8 and i16 DIV operations; widen them to 32-bit
  // operations, and truncate the result.
  EVT WidenedVT;
  if (VT == MVT::nxv16i8)
    WidenedVT = MVT::nxv8i16;
  else if (VT == MVT::nxv8i16)
    WidenedVT = MVT::nxv4i32;
  else
    llvm_unreachable("Unexpected Custom DIV operation");

  auto [MaskLo, MaskHi] = DAG.SplitVector(Op.getOperand(2), DL);
  // Split the EVL relative to the *original* type: the low half covers the
  // first VT.getVectorMinNumElements()/2 lanes of VT, the high half the rest.
  // (Splitting relative to WidenedVT would clamp at a quarter of VT's lanes.)
  auto [EVLLo, EVLHi] = DAG.SplitEVL(Op.getOperand(3), VT, DL);
  unsigned UnpkLo = Signed ? AArch64ISD::SUNPKLO : AArch64ISD::UUNPKLO;
  unsigned UnpkHi = Signed ? AArch64ISD::SUNPKHI : AArch64ISD::UUNPKHI;
  SDValue Op0Lo = DAG.getNode(UnpkLo, DL, WidenedVT, Op.getOperand(0));
  SDValue Op1Lo = DAG.getNode(UnpkLo, DL, WidenedVT, Op.getOperand(1));
  SDValue Op0Hi = DAG.getNode(UnpkHi, DL, WidenedVT, Op.getOperand(0));
  SDValue Op1Hi = DAG.getNode(UnpkHi, DL, WidenedVT, Op.getOperand(1));
  SDValue ResultLo =
      DAG.getNode(Op.getOpcode(), DL, WidenedVT, Op0Lo, Op1Lo, MaskLo, EVLLo);
  SDValue ResultHi =
      DAG.getNode(Op.getOpcode(), DL, WidenedVT, Op0Hi, Op1Hi, MaskHi, EVLHi);
  // Reinterpret the widened halves as VT so UZP1 can interleave the even
  // (low) sub-elements of each half, which performs the truncation.
  SDValue ResultLoCast = DAG.getNode(AArch64ISD::NVCAST, DL, VT, ResultLo);
  SDValue ResultHiCast = DAG.getNode(AArch64ISD::NVCAST, DL, VT, ResultHi);
  return DAG.getNode(AArch64ISD::UZP1, DL, VT, ResultLoCast, ResultHiCast);
}
15919+
1587315920
bool AArch64TargetLowering::shouldExpandBuildVectorWithShuffles(
1587415921
EVT VT, unsigned DefinedValues) const {
1587515922
if (!Subtarget->isNeonAvailable())

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -706,6 +706,7 @@ class AArch64TargetLowering : public TargetLowering {
706706
SDValue LowerPARTIAL_REDUCE_MLA(SDValue Op, SelectionDAG &DAG) const;
707707
SDValue LowerGET_ACTIVE_LANE_MASK(SDValue Op, SelectionDAG &DAG) const;
708708
SDValue LowerDIV(SDValue Op, SelectionDAG &DAG) const;
709+
SDValue LowerVP_DIV(SDValue Op, SelectionDAG &DAG) const;
709710
SDValue LowerMUL(SDValue Op, SelectionDAG &DAG) const;
710711
SDValue LowerVectorSRA_SRL_SHL(SDValue Op, SelectionDAG &DAG) const;
711712
SDValue LowerShiftParts(SDValue Op, SelectionDAG &DAG) const;

llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -700,8 +700,8 @@ let Predicates = [HasSVE_or_SME] in {
700700
defm SDIVR_ZPmZ : sve_int_bin_pred_arit_2_div<0b110, "sdivr", "SDIVR_ZPZZ", int_aarch64_sve_sdivr, DestructiveBinaryCommWithRev, "SDIV_ZPmZ", /*isReverseInstr*/ 1>;
701701
defm UDIVR_ZPmZ : sve_int_bin_pred_arit_2_div<0b111, "udivr", "UDIVR_ZPZZ", int_aarch64_sve_udivr, DestructiveBinaryCommWithRev, "UDIV_ZPmZ", /*isReverseInstr*/ 1>;
702702

703-
defm SDIV_ZPZZ : sve_int_bin_pred_sd<AArch64sdiv_p>;
704-
defm UDIV_ZPZZ : sve_int_bin_pred_sd<AArch64udiv_p>;
703+
// Undef-passthrough divide pseudos. Passing the vp_* operators lets generic
// IR vp.sdiv/vp.udiv select onto them as well as the AArch64 *_p nodes.
defm SDIV_ZPZZ : sve_int_bin_pred_sd<AArch64sdiv_p, vp_sdiv>;
defm UDIV_ZPZZ : sve_int_bin_pred_sd<AArch64udiv_p, vp_udiv>;
705705

706706
defm SDOT_ZZZ : sve_intx_dot<0b0, "sdot", AArch64sdot>;
707707
defm UDOT_ZZZ : sve_intx_dot<0b1, "udot", AArch64udot>;

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,17 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
156156

157157
bool isVScaleKnownToBeAPowerOfTwo() const override { return true; }
158158

159+
/// Tell the ExpandVectorPredication pass how to legalize a VP intrinsic.
/// For vp.sdiv/vp.udiv the predicate maps directly onto SVE's predicated
/// divide instructions, so the mask is reported Legal and the %evl operand
/// may be discarded (the pass legalizes %evl away before ISel).
/// NOTE(review): this reports Legal unconditionally for vp.sdiv/vp.udiv,
/// irrespective of vector type or SVE availability — presumably only
/// SVE-legal scalable types reach ISel this way; confirm for fixed-length
/// and non-SVE configurations.
TargetTransformInfo::VPLegalization
getVPLegalizationStrategy(const VPIntrinsic &PI) const override {
  using VPLegalization = TargetTransformInfo::VPLegalization;
  const Intrinsic::ID IID = PI.getIntrinsicID();
  if (IID == Intrinsic::vp_sdiv || IID == Intrinsic::vp_udiv)
    return VPLegalization(VPLegalization::Discard, VPLegalization::Legal);
  return BaseT::getVPLegalizationStrategy(PI);
}
169+
159170
bool shouldMaximizeVectorBandwidth(
160171
TargetTransformInfo::RegisterKind K) const override;
161172

llvm/lib/Target/AArch64/SVEInstrFormats.td

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9788,12 +9788,17 @@ multiclass sve_int_bin_pred_bhsd<SDPatternOperator op> {
97889788
}
97899789

97909790
// As sve_int_bin_pred but when only i32 and i64 vector types are required.
9791-
multiclass sve_int_bin_pred_sd<SDPatternOperator op> {
9791+
// vp_op defaults to null_frag so existing single-operator users remain
// source-compatible; patterns containing null_frag are dropped by the
// pattern emitters.
multiclass sve_int_bin_pred_sd<SDPatternOperator op,
                               SDPatternOperator vp_op = null_frag> {
  def _S_UNDEF : PredTwoOpPseudo<NAME # _S, ZPR32, FalseLanesUndef>;
  def _D_UNDEF : PredTwoOpPseudo<NAME # _D, ZPR64, FalseLanesUndef>;

  def : SVE_3_Op_Pat<nxv4i32, op, nxv4i1, nxv4i32, nxv4i32, !cast<Pseudo>(NAME # _S_UNDEF)>;
  def : SVE_3_Op_Pat<nxv2i64, op, nxv2i1, nxv2i64, nxv2i64, !cast<Pseudo>(NAME # _D_UNDEF)>;

  // VP form: operands are (lhs, rhs, mask, evl). The EVL is matched with
  // srcvalue (i.e. ignored) because it is legalized away before instruction
  // selection (see AArch64TTIImpl::getVPLegalizationStrategy).
  def : Pat<(nxv4i32 (vp_op nxv4i32:$lhs, nxv4i32:$rhs, nxv4i1:$pred, (i32 srcvalue))),
            (!cast<Pseudo>(NAME # _S_UNDEF) $pred, $lhs, $rhs)>;
  def : Pat<(nxv2i64 (vp_op nxv2i64:$lhs, nxv2i64:$rhs, nxv2i1:$pred, (i32 srcvalue))),
            (!cast<Pseudo>(NAME # _D_UNDEF) $pred, $lhs, $rhs)>;
}
97989803

97999804
// Predicated pseudo integer two operand instructions. Second operand is an
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc -mtriple=aarch64 -mattr=+sve < %s | FileCheck %s

; Scalable vp.sdiv/vp.udiv should select to SVE predicated divides; i8/i16
; element types are custom-lowered via unpack/divide/uzp1.

define <vscale x 2 x i64> @sdiv_evl_max(<vscale x 2 x i64> %x, <vscale x 2 x i64> %y, <vscale x 2 x i1> %mask) {
; CHECK-LABEL: sdiv_evl_max:
; CHECK:       // %bb.0:
; CHECK-NEXT:    sdiv z0.d, p0/m, z0.d, z1.d
; CHECK-NEXT:    ret
  %vscale = call i32 @llvm.vscale()
  %evl = mul i32 %vscale, 2
  %z = call <vscale x 2 x i64> @llvm.vp.sdiv(<vscale x 2 x i64> %x, <vscale x 2 x i64> %y, <vscale x 2 x i1> %mask, i32 %evl)
  ret <vscale x 2 x i64> %z
}

define <vscale x 2 x i64> @sdiv_nxv2i64(<vscale x 2 x i64> %x, <vscale x 2 x i64> %y, <vscale x 2 x i1> %mask, i32 %evl) {
; CHECK-LABEL: sdiv_nxv2i64:
; CHECK:       // %bb.0:
; CHECK-NEXT:    whilelo p1.d, wzr, w0
; CHECK-NEXT:    and p0.b, p1/z, p1.b, p0.b
; CHECK-NEXT:    sdiv z0.d, p0/m, z0.d, z1.d
; CHECK-NEXT:    ret
  %z = call <vscale x 2 x i64> @llvm.vp.sdiv(<vscale x 2 x i64> %x, <vscale x 2 x i64> %y, <vscale x 2 x i1> %mask, i32 %evl)
  ret <vscale x 2 x i64> %z
}

define <vscale x 4 x i32> @sdiv_nxv4i32(<vscale x 4 x i32> %x, <vscale x 4 x i32> %y, <vscale x 4 x i1> %mask, i32 %evl) {
; CHECK-LABEL: sdiv_nxv4i32:
; CHECK:       // %bb.0:
; CHECK-NEXT:    whilelo p1.s, wzr, w0
; CHECK-NEXT:    and p0.b, p1/z, p1.b, p0.b
; CHECK-NEXT:    sdiv z0.s, p0/m, z0.s, z1.s
; CHECK-NEXT:    ret
  %z = call <vscale x 4 x i32> @llvm.vp.sdiv(<vscale x 4 x i32> %x, <vscale x 4 x i32> %y, <vscale x 4 x i1> %mask, i32 %evl)
  ret <vscale x 4 x i32> %z
}

define <vscale x 8 x i16> @sdiv_nxv8i16(<vscale x 8 x i16> %x, <vscale x 8 x i16> %y, <vscale x 8 x i1> %mask, i32 %evl) {
; CHECK-LABEL: sdiv_nxv8i16:
; CHECK:       // %bb.0:
; CHECK-NEXT:    whilelo p1.h, wzr, w0
; CHECK-NEXT:    sunpkhi z2.s, z1.h
; CHECK-NEXT:    sunpkhi z3.s, z0.h
; CHECK-NEXT:    sunpklo z1.s, z1.h
; CHECK-NEXT:    sunpklo z0.s, z0.h
; CHECK-NEXT:    and p0.b, p1/z, p1.b, p0.b
; CHECK-NEXT:    punpkhi p1.h, p0.b
; CHECK-NEXT:    punpklo p0.h, p0.b
; CHECK-NEXT:    sdivr z2.s, p1/m, z2.s, z3.s
; CHECK-NEXT:    sdiv z0.s, p0/m, z0.s, z1.s
; CHECK-NEXT:    uzp1 z0.h, z0.h, z2.h
; CHECK-NEXT:    ret
  %z = call <vscale x 8 x i16> @llvm.vp.sdiv(<vscale x 8 x i16> %x, <vscale x 8 x i16> %y, <vscale x 8 x i1> %mask, i32 %evl)
  ret <vscale x 8 x i16> %z
}

; NOTE(review): this test previously duplicated sdiv_nxv8i16 (all operands
; were <vscale x 8 x i16>) and never exercised the i8 path. Rewritten to use
; real i8 vectors; the i8 lowering splits twice (i8 -> i16 -> i32), so the
; structural checks below should be replaced by regenerating this file with
; update_llc_test_checks.py.
define <vscale x 16 x i8> @sdiv_nxv16i8(<vscale x 16 x i8> %x, <vscale x 16 x i8> %y, <vscale x 16 x i1> %mask, i32 %evl) {
; CHECK-LABEL: sdiv_nxv16i8:
; CHECK: whilelo p1.b, wzr, w0
; CHECK: sdiv
; CHECK: sdiv
; CHECK: sdiv
; CHECK: sdiv
; CHECK: uzp1
; CHECK: ret
  %z = call <vscale x 16 x i8> @llvm.vp.sdiv(<vscale x 16 x i8> %x, <vscale x 16 x i8> %y, <vscale x 16 x i1> %mask, i32 %evl)
  ret <vscale x 16 x i8> %z
}

define <vscale x 2 x i64> @udiv_evl_max(<vscale x 2 x i64> %x, <vscale x 2 x i64> %y, <vscale x 2 x i1> %mask) {
; CHECK-LABEL: udiv_evl_max:
; CHECK:       // %bb.0:
; CHECK-NEXT:    udiv z0.d, p0/m, z0.d, z1.d
; CHECK-NEXT:    ret
  %vscale = call i32 @llvm.vscale()
  %evl = mul i32 %vscale, 2
  %z = call <vscale x 2 x i64> @llvm.vp.udiv(<vscale x 2 x i64> %x, <vscale x 2 x i64> %y, <vscale x 2 x i1> %mask, i32 %evl)
  ret <vscale x 2 x i64> %z
}

define <vscale x 2 x i64> @udiv_nxv2i64(<vscale x 2 x i64> %x, <vscale x 2 x i64> %y, <vscale x 2 x i1> %mask, i32 %evl) {
; CHECK-LABEL: udiv_nxv2i64:
; CHECK:       // %bb.0:
; CHECK-NEXT:    whilelo p1.d, wzr, w0
; CHECK-NEXT:    and p0.b, p1/z, p1.b, p0.b
; CHECK-NEXT:    udiv z0.d, p0/m, z0.d, z1.d
; CHECK-NEXT:    ret
  %z = call <vscale x 2 x i64> @llvm.vp.udiv(<vscale x 2 x i64> %x, <vscale x 2 x i64> %y, <vscale x 2 x i1> %mask, i32 %evl)
  ret <vscale x 2 x i64> %z
}

define <vscale x 4 x i32> @udiv_nxv4i32(<vscale x 4 x i32> %x, <vscale x 4 x i32> %y, <vscale x 4 x i1> %mask, i32 %evl) {
; CHECK-LABEL: udiv_nxv4i32:
; CHECK:       // %bb.0:
; CHECK-NEXT:    whilelo p1.s, wzr, w0
; CHECK-NEXT:    and p0.b, p1/z, p1.b, p0.b
; CHECK-NEXT:    udiv z0.s, p0/m, z0.s, z1.s
; CHECK-NEXT:    ret
  %z = call <vscale x 4 x i32> @llvm.vp.udiv(<vscale x 4 x i32> %x, <vscale x 4 x i32> %y, <vscale x 4 x i1> %mask, i32 %evl)
  ret <vscale x 4 x i32> %z
}

define <vscale x 8 x i16> @udiv_nxv8i16(<vscale x 8 x i16> %x, <vscale x 8 x i16> %y, <vscale x 8 x i1> %mask, i32 %evl) {
; CHECK-LABEL: udiv_nxv8i16:
; CHECK:       // %bb.0:
; CHECK-NEXT:    whilelo p1.h, wzr, w0
; CHECK-NEXT:    uunpkhi z2.s, z1.h
; CHECK-NEXT:    uunpkhi z3.s, z0.h
; CHECK-NEXT:    uunpklo z1.s, z1.h
; CHECK-NEXT:    uunpklo z0.s, z0.h
; CHECK-NEXT:    and p0.b, p1/z, p1.b, p0.b
; CHECK-NEXT:    punpkhi p1.h, p0.b
; CHECK-NEXT:    punpklo p0.h, p0.b
; CHECK-NEXT:    udivr z2.s, p1/m, z2.s, z3.s
; CHECK-NEXT:    udiv z0.s, p0/m, z0.s, z1.s
; CHECK-NEXT:    uzp1 z0.h, z0.h, z2.h
; CHECK-NEXT:    ret
  %z = call <vscale x 8 x i16> @llvm.vp.udiv(<vscale x 8 x i16> %x, <vscale x 8 x i16> %y, <vscale x 8 x i1> %mask, i32 %evl)
  ret <vscale x 8 x i16> %z
}

; NOTE(review): as with sdiv_nxv16i8, this previously duplicated the i16 test;
; rewritten to use i8 vectors. Regenerate exact checks with
; update_llc_test_checks.py.
define <vscale x 16 x i8> @udiv_nxv16i8(<vscale x 16 x i8> %x, <vscale x 16 x i8> %y, <vscale x 16 x i1> %mask, i32 %evl) {
; CHECK-LABEL: udiv_nxv16i8:
; CHECK: whilelo p1.b, wzr, w0
; CHECK: udiv
; CHECK: udiv
; CHECK: udiv
; CHECK: udiv
; CHECK: uzp1
; CHECK: ret
  %z = call <vscale x 16 x i8> @llvm.vp.udiv(<vscale x 16 x i8> %x, <vscale x 16 x i8> %y, <vscale x 16 x i1> %mask, i32 %evl)
  ret <vscale x 16 x i8> %z
}

0 commit comments

Comments
 (0)