Skip to content
Merged
105 changes: 100 additions & 5 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4553,8 +4553,10 @@ static SDValue getSingleShuffleSrc(MVT VT, SDValue V1, SDValue V2) {
/// way through the source.
static bool isInterleaveShuffle(ArrayRef<int> Mask, MVT VT, int &EvenSrc,
int &OddSrc, const RISCVSubtarget &Subtarget) {
// We need to be able to widen elements to the next larger integer type.
if (VT.getScalarSizeInBits() >= Subtarget.getELen())
// We need to be able to widen elements to the next larger integer type or
// use the zip2a instruction at e64.
if (VT.getScalarSizeInBits() >= Subtarget.getELen() &&
!Subtarget.hasVendorXRivosVizip())
return false;

int Size = Mask.size();
Expand Down Expand Up @@ -4611,6 +4613,43 @@ static bool isElementRotate(std::array<std::pair<int, int>, 2> &SrcInfo,
SrcInfo[1].second - SrcInfo[0].second == (int)NumElts;
}

static bool isAlternating(std::array<std::pair<int, int>, 2> &SrcInfo,
ArrayRef<int> Mask, bool &Polarity) {
int NumElts = Mask.size();
bool NonUndefFound = false;
for (unsigned i = 0; i != Mask.size(); ++i) {
int M = Mask[i];
if (M < 0)
continue;
int Src = M >= (int)NumElts;
int Diff = (int)i - (M % NumElts);
bool C = Src == SrcInfo[1].first && Diff == SrcInfo[1].second;
if (!NonUndefFound) {
NonUndefFound = true;
Polarity = (C == i % 2);
continue;
}
if ((Polarity && C != i % 2) || (!Polarity && C == i % 2))
return false;
}
return true;
}

static bool isZipEven(std::array<std::pair<int, int>, 2> &SrcInfo,
ArrayRef<int> Mask) {
bool Polarity;
return SrcInfo[0].second == 0 && SrcInfo[1].second == 1 &&
isAlternating(SrcInfo, Mask, Polarity) && Polarity;
;
}

static bool isZipOdd(std::array<std::pair<int, int>, 2> &SrcInfo,
ArrayRef<int> Mask) {
bool Polarity;
return SrcInfo[0].second == 0 && SrcInfo[1].second == -1 &&
isAlternating(SrcInfo, Mask, Polarity) && !Polarity;
}

// Lower a deinterleave shuffle to SRL and TRUNC. Factor must be
// 2, 4, 8 and the integer type Factor-times larger than VT's
// element type must be a legal element type.
Expand Down Expand Up @@ -4870,6 +4909,34 @@ static bool isSpreadMask(ArrayRef<int> Mask, unsigned Factor, unsigned &Index) {
return true;
}

static SDValue lowerVIZIP(unsigned Opc, SDValue Op0, SDValue Op1,
const SDLoc &DL, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
assert(RISCVISD::RI_VZIPEVEN_VL == Opc || RISCVISD::RI_VZIPODD_VL == Opc ||
RISCVISD::RI_VZIP2A_VL == Opc);
assert(Op0.getSimpleValueType() == Op1.getSimpleValueType());

MVT VT = Op0.getSimpleValueType();
MVT IntVT = VT.changeVectorElementTypeToInteger();
Op0 = DAG.getBitcast(IntVT, Op0);
Op1 = DAG.getBitcast(IntVT, Op1);

MVT ContainerVT = IntVT;
if (VT.isFixedLengthVector()) {
ContainerVT = getContainerForFixedLengthVector(DAG, IntVT, Subtarget);
Op0 = convertToScalableVector(ContainerVT, Op0, DAG, Subtarget);
Op1 = convertToScalableVector(ContainerVT, Op1, DAG, Subtarget);
}

auto [Mask, VL] = getDefaultVLOps(IntVT, ContainerVT, DL, DAG, Subtarget);
SDValue Passthru = DAG.getUNDEF(ContainerVT);
SDValue Res = DAG.getNode(Opc, DL, ContainerVT, Op0, Op1, Passthru, Mask, VL);
if (IntVT.isFixedLengthVector())
Res = convertFromScalableVector(IntVT, Res, DAG, Subtarget);
Res = DAG.getBitcast(VT, Res);
return Res;
}

// Given a vector a, b, c, d return a vector Factor times longer
// with Factor-1 undef's between elements. Ex:
// a, undef, b, undef, c, undef, d, undef (Factor=2, Index=0)
Expand Down Expand Up @@ -5556,6 +5623,7 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG,
}
}


if (SDValue V =
lowerVECTOR_SHUFFLEAsVSlideup(DL, VT, V1, V2, Mask, Subtarget, DAG))
return V;
Expand Down Expand Up @@ -5596,6 +5664,16 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG,
DAG.getVectorIdxConstant(OddSrc % Size, DL));
}

// Prefer vzip2a if available.
// TODO: Extend to matching zip2b if EvenSrc and OddSrc allow.
if (Subtarget.hasVendorXRivosVizip()) {
EvenV = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, DAG.getUNDEF(VT),
EvenV, DAG.getVectorIdxConstant(0, DL));
OddV = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, DAG.getUNDEF(VT), OddV,
DAG.getVectorIdxConstant(0, DL));
return lowerVIZIP(RISCVISD::RI_VZIP2A_VL, EvenV, OddV, DL, DAG,
Subtarget);
}
return getWideningInterleave(EvenV, OddV, DL, DAG, Subtarget);
}

Expand Down Expand Up @@ -5647,6 +5725,19 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG,
return convertFromScalableVector(VT, Res, DAG, Subtarget);
}

if (Subtarget.hasVendorXRivosVizip() && isZipEven(SrcInfo, Mask)) {
SDValue Src1 = SrcInfo[0].first == 0 ? V1 : V2;
SDValue Src2 = SrcInfo[1].first == 0 ? V1 : V2;
return lowerVIZIP(RISCVISD::RI_VZIPEVEN_VL, Src1, Src2, DL, DAG,
Subtarget);
}
if (Subtarget.hasVendorXRivosVizip() && isZipOdd(SrcInfo, Mask)) {
SDValue Src1 = SrcInfo[1].first == 0 ? V1 : V2;
SDValue Src2 = SrcInfo[0].first == 0 ? V1 : V2;
return lowerVIZIP(RISCVISD::RI_VZIPODD_VL, Src1, Src2, DL, DAG,
Subtarget);
}

// Build the mask. Note that vslideup unconditionally preserves elements
// below the slide amount in the destination, and thus those elements are
// undefined in the mask. If the mask ends up all true (or undef), it
Expand Down Expand Up @@ -6710,7 +6801,7 @@ static bool hasPassthruOp(unsigned Opcode) {
Opcode <= RISCVISD::LAST_STRICTFP_OPCODE &&
"not a RISC-V target specific op");
static_assert(
RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 127 &&
RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 130 &&
RISCVISD::LAST_STRICTFP_OPCODE - RISCVISD::FIRST_STRICTFP_OPCODE == 21 &&
"adding target specific op should update this function");
if (Opcode >= RISCVISD::ADD_VL && Opcode <= RISCVISD::VFMAX_VL)
Expand All @@ -6734,12 +6825,13 @@ static bool hasMaskOp(unsigned Opcode) {
Opcode <= RISCVISD::LAST_STRICTFP_OPCODE &&
"not a RISC-V target specific op");
static_assert(
RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 127 &&
RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 130 &&
RISCVISD::LAST_STRICTFP_OPCODE - RISCVISD::FIRST_STRICTFP_OPCODE == 21 &&
"adding target specific op should update this function");
if (Opcode >= RISCVISD::TRUNCATE_VECTOR_VL && Opcode <= RISCVISD::SETCC_VL)
return true;
if (Opcode >= RISCVISD::VRGATHER_VX_VL && Opcode <= RISCVISD::VFIRST_VL)
if (Opcode >= RISCVISD::VRGATHER_VX_VL &&
Opcode <= RISCVISD::LAST_VL_VECTOR_OP)
return true;
if (Opcode >= RISCVISD::STRICT_FADD_VL &&
Opcode <= RISCVISD::STRICT_VFROUND_NOEXCEPT_VL)
Expand Down Expand Up @@ -21758,6 +21850,9 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(VZEXT_VL)
NODE_NAME_CASE(VCPOP_VL)
NODE_NAME_CASE(VFIRST_VL)
NODE_NAME_CASE(RI_VZIPEVEN_VL)
NODE_NAME_CASE(RI_VZIPODD_VL)
NODE_NAME_CASE(RI_VZIP2A_VL)
NODE_NAME_CASE(READ_CSR)
NODE_NAME_CASE(WRITE_CSR)
NODE_NAME_CASE(SWAP_CSR)
Expand Down
7 changes: 6 additions & 1 deletion llvm/lib/Target/RISCV/RISCVISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,12 @@ enum NodeType : unsigned {
// vfirst.m with additional mask and VL operands.
VFIRST_VL,

LAST_VL_VECTOR_OP = VFIRST_VL,
// XRivosVizip
RI_VZIPEVEN_VL,
RI_VZIPODD_VL,
RI_VZIP2A_VL,

LAST_VL_VECTOR_OP = RI_VZIP2A_VL,

// Read VLENB CSR
READ_VLENB,
Expand Down
34 changes: 34 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfoXRivos.td
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,40 @@ defm RI_VUNZIP2A_V : VALU_IV_V<"ri.vunzip2a", 0b001000>;
defm RI_VUNZIP2B_V : VALU_IV_V<"ri.vunzip2b", 0b011000>;
}

// These are modeled after the int binop VL nodes
def ri_vzipeven_vl : SDNode<"RISCVISD::RI_VZIPEVEN_VL", SDT_RISCVIntBinOp_VL>;
def ri_vzipodd_vl : SDNode<"RISCVISD::RI_VZIPODD_VL", SDT_RISCVIntBinOp_VL>;
def ri_vzip2a_vl : SDNode<"RISCVISD::RI_VZIP2A_VL", SDT_RISCVIntBinOp_VL>;

multiclass RIVPseudoVALU_VV {
foreach m = MxList in {
defvar mx = m.MX;
defm "" : VPseudoBinaryV_VV<m, Commutable=0>;
}
}

let Predicates = [HasVendorXRivosVizip],
Constraints = "@earlyclobber $rd, $rd = $passthru" in {
defm PseudoRI_VZIPEVEN : RIVPseudoVALU_VV;
defm PseudoRI_VZIPODD : RIVPseudoVALU_VV;
defm PseudoRI_VZIP2A : RIVPseudoVALU_VV;
}

multiclass RIVPatBinaryVL_VV<SDPatternOperator vop, string instruction_name,
list<VTypeInfo> vtilist = AllIntegerVectors,
bit isSEWAware = 0> {
foreach vti = vtilist in
let Predicates = GetVTypePredicates<vti>.Predicates in
def : VPatBinaryVL_V<vop, instruction_name, "VV",
vti.Vector, vti.Vector, vti.Vector, vti.Mask,
vti.Log2SEW, vti.LMul, vti.RegClass, vti.RegClass,
vti.RegClass, isSEWAware>;
}

defm : RIVPatBinaryVL_VV<ri_vzipeven_vl, "PseudoRI_VZIPEVEN">;
defm : RIVPatBinaryVL_VV<ri_vzipodd_vl, "PseudoRI_VZIPODD">;
defm : RIVPatBinaryVL_VV<ri_vzip2a_vl, "PseudoRI_VZIP2A">;

//===----------------------------------------------------------------------===//
// XRivosVisni
//===----------------------------------------------------------------------===//
Expand Down
Loading
Loading