
Commit ab89cfd (1 parent: 0f339e6)

[RISCV] Use vwsll.vi/vx + vwaddu.wv to lower vector.interleave when Zvbb enabled. (#67521)

The replacement avoids an assignment to a GPR when the element type is i8 or i16, and avoids the vwmaccu.vx of the old sequence, which may have a higher cost than vwsll.vi/vx.
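For reference, llvm.experimental.vector.interleave2 zips two equal-length vectors lane by lane. A minimal scalar model of the semantics being lowered (illustrative C++, not part of this commit):

#include <cstdint>
#include <vector>

// Reference semantics of vector.interleave2: result lane 2*i comes from
// the first ("even") operand, lane 2*i+1 from the second ("odd") operand.
std::vector<uint8_t> interleave2(const std::vector<uint8_t> &even,
                                 const std::vector<uint8_t> &odd) {
  std::vector<uint8_t> out;
  out.reserve(2 * even.size());
  for (size_t i = 0; i < even.size(); ++i) {
    out.push_back(even[i]);
    out.push_back(odd[i]);
  }
  return out;
}

Both lowerings below build each adjacent pair as one double-width lane, (odd << SEW) + even, then bitcast back to the narrow element type; on little-endian RISC-V the low half of that lane holds the even element, so the lane order comes out right.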

6 files changed (+411, -20 lines)

llvm/lib/Target/RISCV/RISCVISelLowering.cpp (34 additions & 20 deletions)
@@ -4272,24 +4272,37 @@ static SDValue getWideningInterleave(SDValue EvenV, SDValue OddV,
   auto [Mask, VL] = getDefaultVLOps(VecVT, VecContainerVT, DL, DAG, Subtarget);
   SDValue Passthru = DAG.getUNDEF(WideContainerVT);
 
-  // Widen EvenV and OddV with 0s and add one copy of OddV to EvenV with
-  // vwaddu.vv
-  SDValue Interleaved = DAG.getNode(RISCVISD::VWADDU_VL, DL, WideContainerVT,
-                                    EvenV, OddV, Passthru, Mask, VL);
-
-  // Then get OddV * by 2^(VecVT.getScalarSizeInBits() - 1)
-  SDValue AllOnesVec = DAG.getSplatVector(
-      VecContainerVT, DL, DAG.getAllOnesConstant(DL, Subtarget.getXLenVT()));
-  SDValue OddsMul = DAG.getNode(RISCVISD::VWMULU_VL, DL, WideContainerVT, OddV,
-                                AllOnesVec, Passthru, Mask, VL);
-
-  // Add the two together so we get
-  //   (OddV * 0xff...ff) + (OddV + EvenV)
-  // = (OddV * 0x100...00) + EvenV
-  // = (OddV << VecVT.getScalarSizeInBits()) + EvenV
-  // Note the ADD_VL and VLMULU_VL should get selected as vwmaccu.vx
-  Interleaved = DAG.getNode(RISCVISD::ADD_VL, DL, WideContainerVT, Interleaved,
-                            OddsMul, Passthru, Mask, VL);
+  SDValue Interleaved;
+  if (Subtarget.hasStdExtZvbb()) {
+    // Interleaved = (OddV << VecVT.getScalarSizeInBits()) + EvenV.
+    SDValue OffsetVec =
+        DAG.getSplatVector(VecContainerVT, DL,
+                           DAG.getConstant(VecVT.getScalarSizeInBits(), DL,
+                                           Subtarget.getXLenVT()));
+    Interleaved = DAG.getNode(RISCVISD::VWSLL_VL, DL, WideContainerVT, OddV,
+                              OffsetVec, Passthru, Mask, VL);
+    Interleaved = DAG.getNode(RISCVISD::VWADDU_W_VL, DL, WideContainerVT,
+                              Interleaved, EvenV, Passthru, Mask, VL);
+  } else {
+    // Widen EvenV and OddV with 0s and add one copy of OddV to EvenV with
+    // vwaddu.vv
+    Interleaved = DAG.getNode(RISCVISD::VWADDU_VL, DL, WideContainerVT, EvenV,
+                              OddV, Passthru, Mask, VL);
+
+    // Then get OddV * by 2^(VecVT.getScalarSizeInBits() - 1)
+    SDValue AllOnesVec = DAG.getSplatVector(
+        VecContainerVT, DL, DAG.getAllOnesConstant(DL, Subtarget.getXLenVT()));
+    SDValue OddsMul = DAG.getNode(RISCVISD::VWMULU_VL, DL, WideContainerVT,
+                                  OddV, AllOnesVec, Passthru, Mask, VL);
+
+    // Add the two together so we get
+    //   (OddV * 0xff...ff) + (OddV + EvenV)
+    // = (OddV * 0x100...00) + EvenV
+    // = (OddV << VecVT.getScalarSizeInBits()) + EvenV
+    // Note the ADD_VL and VLMULU_VL should get selected as vwmaccu.vx
+    Interleaved = DAG.getNode(RISCVISD::ADD_VL, DL, WideContainerVT,
+                              Interleaved, OddsMul, Passthru, Mask, VL);
+  }
 
   // Bitcast from <vscale x n * ty*2> to <vscale x 2*n x ty>
   MVT ResultContainerVT = MVT::getVectorVT(
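The comments in the else branch derive the identity (OddV * 0xff...ff) + (OddV + EvenV) = (OddV << SEW) + EvenV; the new Zvbb branch computes the right-hand side directly with a widening shift and a widening add. A scalar check of one SEW=8 lane under both sequences (a standalone sketch; the helper names are invented for illustration):

#include <cassert>
#include <cstdint>

// Zvbb path: vwsll.vi (widening shift by SEW) + vwaddu.wv.
uint16_t lane_zvbb(uint8_t even, uint8_t odd) {
  return ((uint16_t)odd << 8) + even;
}

// Pre-Zvbb path: vwaddu.vv + vwmaccu.vx with an all-ones multiplier.
uint16_t lane_vwmaccu(uint8_t even, uint8_t odd) {
  uint16_t sum = (uint16_t)even + (uint16_t)odd; // vwaddu.vv
  return sum + (uint16_t)odd * 0xFF; // odd*0xFF + (odd+even) = (odd<<8)+even
}

int main() {
  // Exhaustively confirm the two lowerings agree on every byte pair.
  for (unsigned e = 0; e < 256; ++e)
    for (unsigned o = 0; o < 256; ++o)
      assert(lane_zvbb(e, o) == lane_vwmaccu(e, o));
  return 0;
}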
@@ -5323,7 +5336,7 @@ static bool hasMergeOp(unsigned Opcode) {
          Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
          "not a RISC-V target specific op");
   static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
-                    124 &&
+                    125 &&
                 RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
                         ISD::FIRST_TARGET_STRICTFP_OPCODE ==
                     21 &&
@@ -5347,7 +5360,7 @@ static bool hasMaskOp(unsigned Opcode) {
          Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
          "not a RISC-V target specific op");
   static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
-                    124 &&
+                    125 &&
                 RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
                         ISD::FIRST_TARGET_STRICTFP_OPCODE ==
                     21 &&
@@ -17579,6 +17592,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(VWADDU_W_VL)
   NODE_NAME_CASE(VWSUB_W_VL)
   NODE_NAME_CASE(VWSUBU_W_VL)
+  NODE_NAME_CASE(VWSLL_VL)
   NODE_NAME_CASE(VFWMUL_VL)
   NODE_NAME_CASE(VFWADD_VL)
   NODE_NAME_CASE(VFWSUB_VL)

llvm/lib/Target/RISCV/RISCVISelLowering.h (1 addition & 0 deletions)
@@ -309,6 +309,7 @@ enum NodeType : unsigned {
   VWADDU_W_VL,
   VWSUB_W_VL,
   VWSUBU_W_VL,
+  VWSLL_VL,
 
   VFWMUL_VL,
   VFWADD_VL,

llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td (1 addition & 0 deletions)
@@ -409,6 +409,7 @@ def riscv_vwadd_vl : SDNode<"RISCVISD::VWADD_VL", SDT_RISCVVWIntBinOp_VL, [SDNPCommutative]>;
 def riscv_vwaddu_vl : SDNode<"RISCVISD::VWADDU_VL", SDT_RISCVVWIntBinOp_VL, [SDNPCommutative]>;
 def riscv_vwsub_vl : SDNode<"RISCVISD::VWSUB_VL", SDT_RISCVVWIntBinOp_VL, []>;
 def riscv_vwsubu_vl : SDNode<"RISCVISD::VWSUBU_VL", SDT_RISCVVWIntBinOp_VL, []>;
+def riscv_vwsll_vl : SDNode<"RISCVISD::VWSLL_VL", SDT_RISCVVWIntBinOp_VL, []>;
 
 def SDT_RISCVVWIntTernOp_VL : SDTypeProfile<1, 5, [SDTCisVec<0>, SDTCisInt<0>,
                                                    SDTCisInt<1>,

llvm/lib/Target/RISCV/RISCVInstrInfoZvk.td (27 additions & 0 deletions)
@@ -641,6 +641,33 @@ foreach vtiToWti = AllWidenableIntVectors in {
             (!cast<Instruction>("PseudoVWSLL_VI_"#vti.LMul.MX#"_MASK")
                wti.RegClass:$merge, vti.RegClass:$rs2, uimm5:$rs1,
                (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
+
+  def : Pat<(riscv_vwsll_vl
+               (vti.Vector vti.RegClass:$rs2),
+               (vti.Vector vti.RegClass:$rs1),
+               (wti.Vector wti.RegClass:$merge),
+               (vti.Mask V0), VLOpFrag),
+            (!cast<Instruction>("PseudoVWSLL_VV_"#vti.LMul.MX#"_MASK")
+               wti.RegClass:$merge, vti.RegClass:$rs2, vti.RegClass:$rs1,
+               (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
+
+  def : Pat<(riscv_vwsll_vl
+               (vti.Vector vti.RegClass:$rs2),
+               (vti.Vector (Low8BitsSplatPat (XLenVT GPR:$rs1))),
+               (wti.Vector wti.RegClass:$merge),
+               (vti.Mask V0), VLOpFrag),
+            (!cast<Instruction>("PseudoVWSLL_VX_"#vti.LMul.MX#"_MASK")
+               wti.RegClass:$merge, vti.RegClass:$rs2, GPR:$rs1,
+               (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
+
+  def : Pat<(riscv_vwsll_vl
+               (vti.Vector vti.RegClass:$rs2),
+               (vti.Vector (SplatPat_uimm5 uimm5:$rs1)),
+               (wti.Vector wti.RegClass:$merge),
+               (vti.Mask V0), VLOpFrag),
+            (!cast<Instruction>("PseudoVWSLL_VI_"#vti.LMul.MX#"_MASK")
+               wti.RegClass:$merge, vti.RegClass:$rs2, uimm5:$rs1,
+               (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
 }
}
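These patterns select the masked PseudoVWSLL forms for vector (VV), scalar (VX), and 5-bit-immediate (VI) shift amounts. As a rough scalar model of what vwsll computes at SEW=8 (a sketch assuming, per the Zvbb draft spec, that only the low log2(2*SEW) bits of the shift amount are used):

#include <cstdint>

// vwsll at SEW=8: the shift is performed in the 2*SEW-wide destination
// type, so amounts up to 15 are meaningful; the amount is masked to
// log2(2*SEW) = 4 bits (an assumption taken from the Zvbb draft spec).
uint16_t vwsll_e8(uint8_t rs2, unsigned shamt) {
  return (uint16_t)rs2 << (shamt & 15);
}

The VI form's uimm5 covers amounts 0 through 31, which is why the e8/e16 interleaves in the tests below use vwsll.vi (shift by 8 or 16) while the e32 cases must materialize 32 in a GPR and use vwsll.vx.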

llvm/test/CodeGen/RISCV/rvv/vector-interleave-fixed.ll (78 additions & 0 deletions)
@@ -1,6 +1,8 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
 ; RUN: llc < %s -mtriple=riscv32 -mattr=+v,+zfh,+zvfh | FileCheck -check-prefixes=CHECK,RV32 %s
 ; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zfh,+zvfh | FileCheck -check-prefixes=CHECK,RV64 %s
+; RUN: llc < %s -mtriple=riscv32 -mattr=+v,+experimental-zvbb,+zfh,+zvfh | FileCheck %s --check-prefix=ZVBB
+; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+experimental-zvbb,+zfh,+zvfh | FileCheck %s --check-prefix=ZVBB
 
 ; Integers

@@ -22,6 +24,23 @@ define <32 x i1> @vector_interleave_v32i1_v16i1(<16 x i1> %a, <16 x i1> %b) {
 ; CHECK-NEXT:    vsetvli zero, a0, e8, m2, ta, ma
 ; CHECK-NEXT:    vmsne.vi v0, v12, 0
 ; CHECK-NEXT:    ret
+;
+; ZVBB-LABEL: vector_interleave_v32i1_v16i1:
+; ZVBB:       # %bb.0:
+; ZVBB-NEXT:    vsetivli zero, 4, e8, mf4, ta, ma
+; ZVBB-NEXT:    vslideup.vi v0, v8, 2
+; ZVBB-NEXT:    li a0, 32
+; ZVBB-NEXT:    vsetvli zero, a0, e8, m2, ta, ma
+; ZVBB-NEXT:    vmv.v.i v8, 0
+; ZVBB-NEXT:    vmerge.vim v8, v8, 1, v0
+; ZVBB-NEXT:    vsetivli zero, 16, e8, m2, ta, ma
+; ZVBB-NEXT:    vslidedown.vi v10, v8, 16
+; ZVBB-NEXT:    vsetivli zero, 16, e8, m1, ta, ma
+; ZVBB-NEXT:    vwsll.vi v12, v10, 8
+; ZVBB-NEXT:    vwaddu.wv v12, v12, v8
+; ZVBB-NEXT:    vsetvli zero, a0, e8, m2, ta, ma
+; ZVBB-NEXT:    vmsne.vi v0, v12, 0
+; ZVBB-NEXT:    ret
   %res = call <32 x i1> @llvm.experimental.vector.interleave2.v32i1(<16 x i1> %a, <16 x i1> %b)
   ret <32 x i1> %res
 }
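There is no widening arithmetic at SEW=1, so the ZVBB output above first expands each mask to 0/1 bytes (vmv.v.i + vmerge.vim), interleaves them at e8 with the new vwsll/vwaddu.wv sequence, and narrows back to a mask with vmsne.vi. A scalar sketch of that round-trip (illustrative only, not from the commit):

#include <cstdint>
#include <vector>

// Interleave two i1 vectors by staging through bytes, mirroring the
// vmerge.vim -> vwsll/vwaddu -> vmsne.vi sequence above.
std::vector<bool> interleave2_mask(const std::vector<bool> &even,
                                   const std::vector<bool> &odd) {
  std::vector<bool> out;
  for (size_t i = 0; i < even.size(); ++i) {
    uint8_t e = even[i] ? 1 : 0;            // vmerge.vim: mask -> 0/1 byte
    uint8_t o = odd[i] ? 1 : 0;
    uint16_t lane = ((uint16_t)o << 8) + e; // vwsll.vi 8 + vwaddu.wv
    out.push_back((lane & 0xFF) != 0);      // vmsne.vi on the even byte
    out.push_back((lane >> 8) != 0);        // vmsne.vi on the odd byte
  }
  return out;
}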
@@ -35,6 +54,14 @@ define <16 x i16> @vector_interleave_v16i16_v8i16(<8 x i16> %a, <8 x i16> %b) {
 ; CHECK-NEXT:    vwmaccu.vx v10, a0, v9
 ; CHECK-NEXT:    vmv2r.v v8, v10
 ; CHECK-NEXT:    ret
+;
+; ZVBB-LABEL: vector_interleave_v16i16_v8i16:
+; ZVBB:       # %bb.0:
+; ZVBB-NEXT:    vsetivli zero, 8, e16, m1, ta, ma
+; ZVBB-NEXT:    vwsll.vi v10, v9, 16
+; ZVBB-NEXT:    vwaddu.wv v10, v10, v8
+; ZVBB-NEXT:    vmv2r.v v8, v10
+; ZVBB-NEXT:    ret
   %res = call <16 x i16> @llvm.experimental.vector.interleave2.v16i16(<8 x i16> %a, <8 x i16> %b)
   ret <16 x i16> %res
 }
@@ -48,6 +75,15 @@ define <8 x i32> @vector_interleave_v8i32_v4i32(<4 x i32> %a, <4 x i32> %b) {
 ; CHECK-NEXT:    vwmaccu.vx v10, a0, v9
 ; CHECK-NEXT:    vmv2r.v v8, v10
 ; CHECK-NEXT:    ret
+;
+; ZVBB-LABEL: vector_interleave_v8i32_v4i32:
+; ZVBB:       # %bb.0:
+; ZVBB-NEXT:    li a0, 32
+; ZVBB-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; ZVBB-NEXT:    vwsll.vx v10, v9, a0
+; ZVBB-NEXT:    vwaddu.wv v10, v10, v8
+; ZVBB-NEXT:    vmv2r.v v8, v10
+; ZVBB-NEXT:    ret
   %res = call <8 x i32> @llvm.experimental.vector.interleave2.v8i32(<4 x i32> %a, <4 x i32> %b)
   ret <8 x i32> %res
 }
@@ -102,6 +138,14 @@ define <4 x half> @vector_interleave_v4f16_v2f16(<2 x half> %a, <2 x half> %b) {
 ; CHECK-NEXT:    vwmaccu.vx v10, a0, v9
 ; CHECK-NEXT:    vmv1r.v v8, v10
 ; CHECK-NEXT:    ret
+;
+; ZVBB-LABEL: vector_interleave_v4f16_v2f16:
+; ZVBB:       # %bb.0:
+; ZVBB-NEXT:    vsetivli zero, 2, e16, mf4, ta, ma
+; ZVBB-NEXT:    vwsll.vi v10, v9, 16
+; ZVBB-NEXT:    vwaddu.wv v10, v10, v8
+; ZVBB-NEXT:    vmv1r.v v8, v10
+; ZVBB-NEXT:    ret
   %res = call <4 x half> @llvm.experimental.vector.interleave2.v4f16(<2 x half> %a, <2 x half> %b)
   ret <4 x half> %res
 }
@@ -115,6 +159,14 @@ define <8 x half> @vector_interleave_v8f16_v4f16(<4 x half> %a, <4 x half> %b) {
 ; CHECK-NEXT:    vwmaccu.vx v10, a0, v9
 ; CHECK-NEXT:    vmv1r.v v8, v10
 ; CHECK-NEXT:    ret
+;
+; ZVBB-LABEL: vector_interleave_v8f16_v4f16:
+; ZVBB:       # %bb.0:
+; ZVBB-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; ZVBB-NEXT:    vwsll.vi v10, v9, 16
+; ZVBB-NEXT:    vwaddu.wv v10, v10, v8
+; ZVBB-NEXT:    vmv1r.v v8, v10
+; ZVBB-NEXT:    ret
   %res = call <8 x half> @llvm.experimental.vector.interleave2.v8f16(<4 x half> %a, <4 x half> %b)
   ret <8 x half> %res
 }
@@ -128,6 +180,15 @@ define <4 x float> @vector_interleave_v4f32_v2f32(<2 x float> %a, <2 x float> %b
 ; CHECK-NEXT:    vwmaccu.vx v10, a0, v9
 ; CHECK-NEXT:    vmv1r.v v8, v10
 ; CHECK-NEXT:    ret
+;
+; ZVBB-LABEL: vector_interleave_v4f32_v2f32:
+; ZVBB:       # %bb.0:
+; ZVBB-NEXT:    li a0, 32
+; ZVBB-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; ZVBB-NEXT:    vwsll.vx v10, v9, a0
+; ZVBB-NEXT:    vwaddu.wv v10, v10, v8
+; ZVBB-NEXT:    vmv1r.v v8, v10
+; ZVBB-NEXT:    ret
   %res = call <4 x float> @llvm.experimental.vector.interleave2.v4f32(<2 x float> %a, <2 x float> %b)
   ret <4 x float> %res
 }
@@ -141,6 +202,14 @@ define <16 x half> @vector_interleave_v16f16_v8f16(<8 x half> %a, <8 x half> %b)
 ; CHECK-NEXT:    vwmaccu.vx v10, a0, v9
 ; CHECK-NEXT:    vmv2r.v v8, v10
 ; CHECK-NEXT:    ret
+;
+; ZVBB-LABEL: vector_interleave_v16f16_v8f16:
+; ZVBB:       # %bb.0:
+; ZVBB-NEXT:    vsetivli zero, 8, e16, m1, ta, ma
+; ZVBB-NEXT:    vwsll.vi v10, v9, 16
+; ZVBB-NEXT:    vwaddu.wv v10, v10, v8
+; ZVBB-NEXT:    vmv2r.v v8, v10
+; ZVBB-NEXT:    ret
   %res = call <16 x half> @llvm.experimental.vector.interleave2.v16f16(<8 x half> %a, <8 x half> %b)
   ret <16 x half> %res
 }
@@ -154,6 +223,15 @@ define <8 x float> @vector_interleave_v8f32_v4f32(<4 x float> %a, <4 x float> %b
 ; CHECK-NEXT:    vwmaccu.vx v10, a0, v9
 ; CHECK-NEXT:    vmv2r.v v8, v10
 ; CHECK-NEXT:    ret
+;
+; ZVBB-LABEL: vector_interleave_v8f32_v4f32:
+; ZVBB:       # %bb.0:
+; ZVBB-NEXT:    li a0, 32
+; ZVBB-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; ZVBB-NEXT:    vwsll.vx v10, v9, a0
+; ZVBB-NEXT:    vwaddu.wv v10, v10, v8
+; ZVBB-NEXT:    vmv2r.v v8, v10
+; ZVBB-NEXT:    ret
   %res = call <8 x float> @llvm.experimental.vector.interleave2.v8f32(<4 x float> %a, <4 x float> %b)
   ret <8 x float> %res
 }
