Skip to content

Commit f3116a5

Browse files
committed
[RISCV] Initial codegen support for the XRivosVizip extension
This implements initial code generation support for the xrivosvizip extension. A couple of things to note: * The zipeven/zipodd matchers were recently rewritten to better match upstream style, so careful review there would be appreciated. * The zipeven/zipodd cases don't yet support type coercion. This will be done in a future patch. * I subsetted the unzip2a/b support in a way which makes it functional, but far from optimal. A further change will reintroduce some of the complexity once it's easy to test and show incremental change.
1 parent 93b8ef4 commit f3116a5

File tree

6 files changed

+782
-80
lines changed

6 files changed

+782
-80
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 121 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4553,8 +4553,10 @@ static SDValue getSingleShuffleSrc(MVT VT, SDValue V1, SDValue V2) {
45534553
/// way through the source.
45544554
static bool isInterleaveShuffle(ArrayRef<int> Mask, MVT VT, int &EvenSrc,
45554555
int &OddSrc, const RISCVSubtarget &Subtarget) {
4556-
// We need to be able to widen elements to the next larger integer type.
4557-
if (VT.getScalarSizeInBits() >= Subtarget.getELen())
4556+
// We need to be able to widen elements to the next larger integer type or
4557+
// use the zip2a instruction at e64.
4558+
if (VT.getScalarSizeInBits() >= Subtarget.getELen() &&
4559+
!Subtarget.hasVendorXRivosVizip())
45584560
return false;
45594561

45604562
int Size = Mask.size();
@@ -4611,6 +4613,43 @@ static bool isElementRotate(std::array<std::pair<int, int>, 2> &SrcInfo,
46114613
SrcInfo[1].second - SrcInfo[0].second == (int)NumElts;
46124614
}
46134615

4616+
static bool isAlternating(std::array<std::pair<int, int>, 2> &SrcInfo,
4617+
ArrayRef<int> Mask, bool &Polarity) {
4618+
int NumElts = Mask.size();
4619+
bool NonUndefFound = false;
4620+
for (unsigned i = 0; i != Mask.size(); ++i) {
4621+
int M = Mask[i];
4622+
if (M < 0)
4623+
continue;
4624+
int Src = M >= (int)NumElts;
4625+
int Diff = (int)i - (M % NumElts);
4626+
bool C = Src == SrcInfo[1].first && Diff == SrcInfo[1].second;
4627+
if (!NonUndefFound) {
4628+
NonUndefFound = true;
4629+
Polarity = (C == i % 2);
4630+
continue;
4631+
}
4632+
if ((Polarity && C != i % 2) || (!Polarity && C == i % 2))
4633+
return false;
4634+
}
4635+
return true;
4636+
}
4637+
4638+
static bool isZipEven(std::array<std::pair<int, int>, 2> &SrcInfo,
4639+
ArrayRef<int> Mask) {
4640+
bool Polarity;
4641+
return SrcInfo[0].second == 0 && SrcInfo[1].second == 1 &&
4642+
isAlternating(SrcInfo, Mask, Polarity) && Polarity;
4643+
;
4644+
}
4645+
4646+
static bool isZipOdd(std::array<std::pair<int, int>, 2> &SrcInfo,
4647+
ArrayRef<int> Mask) {
4648+
bool Polarity;
4649+
return SrcInfo[0].second == 0 && SrcInfo[1].second == -1 &&
4650+
isAlternating(SrcInfo, Mask, Polarity) && !Polarity;
4651+
}
4652+
46144653
// Lower a deinterleave shuffle to SRL and TRUNC. Factor must be
46154654
// 2, 4, 8 and the integer type Factor-times larger than VT's
46164655
// element type must be a legal element type.
@@ -4870,6 +4909,36 @@ static bool isSpreadMask(ArrayRef<int> Mask, unsigned Factor, unsigned &Index) {
48704909
return true;
48714910
}
48724911

4912+
static SDValue lowerVIZIP(unsigned Opc, SDValue Op0, SDValue Op1,
4913+
const SDLoc &DL, SelectionDAG &DAG,
4914+
const RISCVSubtarget &Subtarget) {
4915+
assert(RISCVISD::RI_VZIPEVEN_VL == Opc || RISCVISD::RI_VZIPODD_VL == Opc ||
4916+
RISCVISD::RI_VZIP2A_VL == Opc || RISCVISD::RI_VZIP2B_VL == Opc ||
4917+
RISCVISD::RI_VUNZIP2A_VL == Opc || RISCVISD::RI_VUNZIP2B_VL == Opc);
4918+
assert(Op0.getSimpleValueType() == Op1.getSimpleValueType());
4919+
4920+
MVT VT = Op0.getSimpleValueType();
4921+
MVT IntVT = VT.changeVectorElementTypeToInteger();
4922+
Op0 = DAG.getBitcast(IntVT, Op0);
4923+
Op1 = DAG.getBitcast(IntVT, Op1);
4924+
4925+
MVT ContainerVT = IntVT;
4926+
if (VT.isFixedLengthVector()) {
4927+
ContainerVT = getContainerForFixedLengthVector(DAG, IntVT, Subtarget);
4928+
Op0 = convertToScalableVector(ContainerVT, Op0, DAG, Subtarget);
4929+
Op1 = convertToScalableVector(ContainerVT, Op1, DAG, Subtarget);
4930+
}
4931+
4932+
auto [Mask, VL] = getDefaultVLOps(IntVT, ContainerVT, DL, DAG, Subtarget);
4933+
SDValue Passthru = DAG.getUNDEF(ContainerVT);
4934+
SDValue Res =
4935+
DAG.getNode(Opc, DL, ContainerVT, Op0, Op1, Passthru, Mask, VL);
4936+
if (IntVT.isFixedLengthVector())
4937+
Res = convertFromScalableVector(IntVT, Res, DAG, Subtarget);
4938+
Res = DAG.getBitcast(VT, Res);
4939+
return Res;
4940+
}
4941+
48734942
// Given a vector a, b, c, d return a vector Factor times longer
48744943
// with Factor-1 undef's between elements. Ex:
48754944
// a, undef, b, undef, c, undef, d, undef (Factor=2, Index=0)
@@ -5384,6 +5453,7 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG,
53845453
SDLoc DL(Op);
53855454
MVT XLenVT = Subtarget.getXLenVT();
53865455
MVT VT = Op.getSimpleValueType();
5456+
EVT ElemVT = VT.getVectorElementType();
53875457
unsigned NumElts = VT.getVectorNumElements();
53885458
ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Op.getNode());
53895459

@@ -5556,6 +5626,25 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG,
55565626
}
55575627
}
55585628

5629+
// If this is an e64 deinterleave(2) (possibly with two distinct sources)
5630+
// match to the vunzip2a/vunzip2b.
5631+
unsigned Index = 0;
5632+
if (Subtarget.hasVendorXRivosVizip() && ElemVT == MVT::i64 &&
5633+
ShuffleVectorInst::isDeInterleaveMaskOfFactor(Mask, 2, Index) &&
5634+
1 < count_if(Mask, [](int Idx) { return Idx != -1; })) {
5635+
MVT HalfVT = VT.getHalfNumVectorElementsVT();
5636+
unsigned Opc = Index == 0 ?
5637+
RISCVISD::RI_VUNZIP2A_VL : RISCVISD::RI_VUNZIP2B_VL;
5638+
V1 = lowerVIZIP(Opc, V1, DAG.getUNDEF(VT), DL, DAG, Subtarget);
5639+
V2 = lowerVIZIP(Opc, V2, DAG.getUNDEF(VT), DL, DAG, Subtarget);
5640+
5641+
V1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HalfVT, V1,
5642+
DAG.getVectorIdxConstant(0, DL));
5643+
V2 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HalfVT, V2,
5644+
DAG.getVectorIdxConstant(0, DL));
5645+
return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, V1, V2);
5646+
}
5647+
55595648
if (SDValue V =
55605649
lowerVECTOR_SHUFFLEAsVSlideup(DL, VT, V1, V2, Mask, Subtarget, DAG))
55615650
return V;
@@ -5596,6 +5685,15 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG,
55965685
DAG.getVectorIdxConstant(OddSrc % Size, DL));
55975686
}
55985687

5688+
// Prefer vzip2a if available.
5689+
// TODO: Extend to matching zip2b if EvenSrc and OddSrc allow.
5690+
if (Subtarget.hasVendorXRivosVizip()) {
5691+
EvenV = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, DAG.getUNDEF(VT),
5692+
EvenV, DAG.getVectorIdxConstant(0, DL));
5693+
OddV = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, DAG.getUNDEF(VT),
5694+
OddV, DAG.getVectorIdxConstant(0, DL));
5695+
return lowerVIZIP(RISCVISD::RI_VZIP2A_VL, EvenV, OddV, DL, DAG, Subtarget);
5696+
}
55995697
return getWideningInterleave(EvenV, OddV, DL, DAG, Subtarget);
56005698
}
56015699

@@ -5647,6 +5745,17 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG,
56475745
return convertFromScalableVector(VT, Res, DAG, Subtarget);
56485746
}
56495747

5748+
if (Subtarget.hasVendorXRivosVizip() && isZipEven(SrcInfo, Mask)) {
5749+
SDValue Src1 = SrcInfo[0].first == 0 ? V1 : V2;
5750+
SDValue Src2 = SrcInfo[1].first == 0 ? V1 : V2;
5751+
return lowerVIZIP(RISCVISD::RI_VZIPEVEN_VL, Src1, Src2, DL, DAG, Subtarget);
5752+
}
5753+
if (Subtarget.hasVendorXRivosVizip() && isZipOdd(SrcInfo, Mask)) {
5754+
SDValue Src1 = SrcInfo[1].first == 0 ? V1 : V2;
5755+
SDValue Src2 = SrcInfo[0].first == 0 ? V1 : V2;
5756+
return lowerVIZIP(RISCVISD::RI_VZIPODD_VL, Src1, Src2, DL, DAG, Subtarget);
5757+
}
5758+
56505759
// Build the mask. Note that vslideup unconditionally preserves elements
56515760
// below the slide amount in the destination, and thus those elements are
56525761
// undefined in the mask. If the mask ends up all true (or undef), it
@@ -6710,7 +6819,7 @@ static bool hasPassthruOp(unsigned Opcode) {
67106819
Opcode <= RISCVISD::LAST_STRICTFP_OPCODE &&
67116820
"not a RISC-V target specific op");
67126821
static_assert(
6713-
RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 127 &&
6822+
RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 133 &&
67146823
RISCVISD::LAST_STRICTFP_OPCODE - RISCVISD::FIRST_STRICTFP_OPCODE == 21 &&
67156824
"adding target specific op should update this function");
67166825
if (Opcode >= RISCVISD::ADD_VL && Opcode <= RISCVISD::VFMAX_VL)
@@ -6734,12 +6843,13 @@ static bool hasMaskOp(unsigned Opcode) {
67346843
Opcode <= RISCVISD::LAST_STRICTFP_OPCODE &&
67356844
"not a RISC-V target specific op");
67366845
static_assert(
6737-
RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 127 &&
6846+
RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 133 &&
67386847
RISCVISD::LAST_STRICTFP_OPCODE - RISCVISD::FIRST_STRICTFP_OPCODE == 21 &&
67396848
"adding target specific op should update this function");
67406849
if (Opcode >= RISCVISD::TRUNCATE_VECTOR_VL && Opcode <= RISCVISD::SETCC_VL)
67416850
return true;
6742-
if (Opcode >= RISCVISD::VRGATHER_VX_VL && Opcode <= RISCVISD::VFIRST_VL)
6851+
if (Opcode >= RISCVISD::VRGATHER_VX_VL &&
6852+
Opcode <= RISCVISD::LAST_VL_VECTOR_OP)
67436853
return true;
67446854
if (Opcode >= RISCVISD::STRICT_FADD_VL &&
67456855
Opcode <= RISCVISD::STRICT_VFROUND_NOEXCEPT_VL)
@@ -21758,6 +21868,12 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
2175821868
NODE_NAME_CASE(VZEXT_VL)
2175921869
NODE_NAME_CASE(VCPOP_VL)
2176021870
NODE_NAME_CASE(VFIRST_VL)
21871+
NODE_NAME_CASE(RI_VZIPEVEN_VL)
21872+
NODE_NAME_CASE(RI_VZIPODD_VL)
21873+
NODE_NAME_CASE(RI_VZIP2A_VL)
21874+
NODE_NAME_CASE(RI_VZIP2B_VL)
21875+
NODE_NAME_CASE(RI_VUNZIP2A_VL)
21876+
NODE_NAME_CASE(RI_VUNZIP2B_VL)
2176121877
NODE_NAME_CASE(READ_CSR)
2176221878
NODE_NAME_CASE(WRITE_CSR)
2176321879
NODE_NAME_CASE(SWAP_CSR)

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,15 @@ enum NodeType : unsigned {
403403
// vfirst.m with additional mask and VL operands.
404404
VFIRST_VL,
405405

406-
LAST_VL_VECTOR_OP = VFIRST_VL,
406+
// XRivosVizip
407+
RI_VZIPEVEN_VL,
408+
RI_VZIPODD_VL,
409+
RI_VZIP2A_VL,
410+
RI_VZIP2B_VL,
411+
RI_VUNZIP2A_VL,
412+
RI_VUNZIP2B_VL,
413+
414+
LAST_VL_VECTOR_OP = RI_VUNZIP2B_VL,
407415

408416
// Read VLENB CSR
409417
READ_VLENB,

llvm/lib/Target/RISCV/RISCVInstrInfoXRivos.td

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,46 @@ defm RI_VUNZIP2A_V : VALU_IV_V<"ri.vunzip2a", 0b001000>;
6767
defm RI_VUNZIP2B_V : VALU_IV_V<"ri.vunzip2b", 0b011000>;
6868
}
6969

70+
// These are modeled after the int binop VL nodes
71+
def ri_vzipeven_vl : SDNode<"RISCVISD::RI_VZIPEVEN_VL", SDT_RISCVIntBinOp_VL>;
72+
def ri_vzipodd_vl : SDNode<"RISCVISD::RI_VZIPODD_VL", SDT_RISCVIntBinOp_VL>;
73+
def ri_vzip2a_vl : SDNode<"RISCVISD::RI_VZIP2A_VL", SDT_RISCVIntBinOp_VL>;
74+
def ri_vunzip2a_vl : SDNode<"RISCVISD::RI_VUNZIP2A_VL", SDT_RISCVIntBinOp_VL>;
75+
def ri_vunzip2b_vl : SDNode<"RISCVISD::RI_VUNZIP2B_VL", SDT_RISCVIntBinOp_VL>;
76+
77+
multiclass RIVPseudoVALU_VV {
78+
foreach m = MxList in {
79+
defvar mx = m.MX;
80+
defm "" : VPseudoBinaryV_VV<m, Commutable=0>;
81+
}
82+
}
83+
84+
let Predicates = [HasVendorXRivosVizip],
85+
Constraints = "@earlyclobber $rd, $rd = $passthru" in {
86+
defm PseudoRI_VZIPEVEN : RIVPseudoVALU_VV;
87+
defm PseudoRI_VZIPODD : RIVPseudoVALU_VV;
88+
defm PseudoRI_VZIP2A : RIVPseudoVALU_VV;
89+
defm PseudoRI_VUNZIP2A : RIVPseudoVALU_VV;
90+
defm PseudoRI_VUNZIP2B : RIVPseudoVALU_VV;
91+
}
92+
93+
multiclass RIVPatBinaryVL_VV<SDPatternOperator vop, string instruction_name,
94+
list<VTypeInfo> vtilist = AllIntegerVectors,
95+
bit isSEWAware = 0> {
96+
foreach vti = vtilist in
97+
let Predicates = GetVTypePredicates<vti>.Predicates in
98+
def : VPatBinaryVL_V<vop, instruction_name, "VV",
99+
vti.Vector, vti.Vector, vti.Vector, vti.Mask,
100+
vti.Log2SEW, vti.LMul, vti.RegClass, vti.RegClass,
101+
vti.RegClass, isSEWAware>;
102+
}
103+
104+
defm : RIVPatBinaryVL_VV<ri_vzipeven_vl, "PseudoRI_VZIPEVEN">;
105+
defm : RIVPatBinaryVL_VV<ri_vzipodd_vl, "PseudoRI_VZIPODD">;
106+
defm : RIVPatBinaryVL_VV<ri_vzip2a_vl, "PseudoRI_VZIP2A">;
107+
defm : RIVPatBinaryVL_VV<ri_vunzip2a_vl, "PseudoRI_VUNZIP2A">;
108+
defm : RIVPatBinaryVL_VV<ri_vunzip2b_vl, "PseudoRI_VUNZIP2B">;
109+
70110
//===----------------------------------------------------------------------===//
71111
// XRivosVisni
72112
//===----------------------------------------------------------------------===//
@@ -87,3 +127,5 @@ def RI_VEXTRACT : CustomRivosXVI<0b010111, OPMVV, (outs GPR:$rd),
87127
(ins VR:$vs2, uimm5:$imm),
88128
"ri.vextract.x.v", "$rd, $vs2, $imm">;
89129
}
130+
131+

0 commit comments

Comments
 (0)