Skip to content

Commit 90f7684

Browse files
authored
[VP][RISCV] Add llvm.experimental.vp.reverse. (llvm#70405)
This is similar to vector.reverse, but only reverses the first EVL elements. I extracted this code from our downstream. Some of it may have come from https://repo.hca.bsc.es/gitlab/rferrer/llvm-epi/ originally.
1 parent b89aadf commit 90f7684

File tree

14 files changed

+1766
-3
lines changed

14 files changed

+1766
-3
lines changed

llvm/docs/LangRef.rst

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21756,6 +21756,42 @@ Examples:
2175621756
llvm.experimental.vp.splice(<A,B,C,D>, <E,F,G,H>, -2, 3, 2); ==> <B, C, poison, poison> trailing elements
2175721757

2175821758

21759+
.. _int_experimental_vp_reverse:
21760+
21761+
21762+
'``llvm.experimental.vp.reverse``' Intrinsic
21763+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
21764+
21765+
Syntax:
21766+
"""""""
21767+
This is an overloaded intrinsic.
21768+
21769+
::
21770+
21771+
declare <2 x double> @llvm.experimental.vp.reverse.v2f64(<2 x double> %vec, <2 x i1> %mask, i32 %evl)
21772+
declare <vscale x 4 x i32> @llvm.experimental.vp.reverse.nxv4i32(<vscale x 4 x i32> %vec, <vscale x 4 x i1> %mask, i32 %evl)
21773+
21774+
Overview:
21775+
"""""""""
21776+
21777+
The '``llvm.experimental.vp.reverse.*``' intrinsic is the vector length
21778+
predicated version of the '``llvm.experimental.vector.reverse.*``' intrinsic.
21779+
21780+
Arguments:
21781+
""""""""""
21782+
21783+
The result and the first argument ``vec`` are vectors with the same type.
21784+
The second argument ``mask`` is a vector mask and has the same number of
21785+
elements as the result. The third argument ``evl`` is the explicit vector length of
21786+
the operation.
21787+
21788+
Semantics:
21789+
""""""""""
21790+
21791+
This intrinsic reverses the order of the first ``evl`` elements in a vector.
21792+
The lanes in the result vector disabled by ``mask`` are ``poison``. The
21793+
elements past ``evl`` are ``poison``.
21794+
2175921795
.. _int_vp_load:
2176021796

2176121797
'``llvm.vp.load``' Intrinsic

llvm/include/llvm/IR/Intrinsics.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2196,6 +2196,13 @@ def int_experimental_vp_splice:
21962196
llvm_i32_ty, llvm_i32_ty],
21972197
[IntrNoMem, ImmArg<ArgIndex<2>>]>;
21982198

2199+
// llvm.experimental.vp.reverse: overloaded on its vector result type.
// Parameters, in order: the source vector (same type as the result), an i1
// mask with the same vector width as the result, and the i32 explicit
// vector length. The intrinsic is marked IntrNoMem.
def int_experimental_vp_reverse:
2200+
DefaultAttrsIntrinsic<[llvm_anyvector_ty],
2201+
[LLVMMatchType<0>,
2202+
LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>,
2203+
llvm_i32_ty],
2204+
[IntrNoMem]>;
2205+
21992206
def int_vp_is_fpclass:
22002207
DefaultAttrsIntrinsic<[ LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>],
22012208
[ llvm_anyvector_ty,

llvm/include/llvm/IR/VPIntrinsics.def

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -702,6 +702,12 @@ BEGIN_REGISTER_VP(experimental_vp_splice, 3, 5, EXPERIMENTAL_VP_SPLICE, -1)
702702
VP_PROPERTY_NO_FUNCTIONAL
703703
END_REGISTER_VP(experimental_vp_splice, EXPERIMENTAL_VP_SPLICE)
704704

705+
// llvm.experimental.vp.reverse(x,mask,vlen)
// Registers the intrinsic together with the EXPERIMENTAL_VP_REVERSE opcode.
// NOTE(review): the 1 and 2 below presumably are the operand positions of
// mask and vlen in the signature above - confirm against the
// BEGIN_REGISTER_VP macro definition.
BEGIN_REGISTER_VP(experimental_vp_reverse, 1, 2,
707+
EXPERIMENTAL_VP_REVERSE, -1)
708+
VP_PROPERTY_NO_FUNCTIONAL
709+
END_REGISTER_VP(experimental_vp_reverse, EXPERIMENTAL_VP_REVERSE)
710+
705711
///// } Shuffles
706712

707713
#undef BEGIN_REGISTER_VP

llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -889,6 +889,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
889889
void SplitVecRes_VECTOR_INTERLEAVE(SDNode *N);
890890
void SplitVecRes_VAARG(SDNode *N, SDValue &Lo, SDValue &Hi);
891891
void SplitVecRes_FP_TO_XINT_SAT(SDNode *N, SDValue &Lo, SDValue &Hi);
892+
void SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo, SDValue &Hi);
892893

893894
// Vector Operand Splitting: <128 x ty> -> 2 x <64 x ty>.
894895
bool SplitVectorOperand(SDNode *N, unsigned OpNo);

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1209,6 +1209,9 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
12091209
case ISD::UDIVFIXSAT:
12101210
SplitVecRes_FIX(N, Lo, Hi);
12111211
break;
1212+
case ISD::EXPERIMENTAL_VP_REVERSE:
1213+
SplitVecRes_VP_REVERSE(N, Lo, Hi);
1214+
break;
12121215
}
12131216

12141217
// If Lo/Hi is null, the sub-method took care of registering results etc.
@@ -2857,6 +2860,56 @@ void DAGTypeLegalizer::SplitVecRes_VECTOR_SPLICE(SDNode *N, SDValue &Lo,
28572860
DAG.getVectorIdxConstant(LoVT.getVectorMinNumElements(), DL));
28582861
}
28592862

2863+
// Split the result of an EXPERIMENTAL_VP_REVERSE node into two halves.
//
// The reversal itself is not split. The whole input vector is written to a
// stack temporary with a VP strided store that uses a negative stride, the
// stack slot is read back with a VP load, and the loaded vector is then cut
// into Lo and Hi with EXTRACT_SUBVECTOR.
//
// N  - the node being legalized; operand 0 is the vector, operand 1 the
//      mask and operand 2 the explicit vector length (EVL).
// Lo - receives the low part of the reversed vector.
// Hi - receives the high part of the reversed vector.
void DAGTypeLegalizer::SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo,
2864+
SDValue &Hi) {
2865+
EVT VT = N->getValueType(0);
2866+
SDValue Val = N->getOperand(0);
2867+
SDValue Mask = N->getOperand(1);
2868+
SDValue EVL = N->getOperand(2);
2869+
SDLoc DL(N);
2870+
2871+
// Fall back to VP_STRIDED_STORE to stack followed by VP_LOAD.
2872+
Align Alignment = DAG.getReducedAlign(VT, /*UseABI=*/false);
2873+
2874+
// Stack slot sized for a vector with the element type and element count of
// VT.
EVT MemVT = EVT::getVectorVT(*DAG.getContext(), VT.getVectorElementType(),
2875+
VT.getVectorElementCount());
2876+
SDValue StackPtr = DAG.CreateStackTemporary(MemVT.getStoreSize(), Alignment);
2877+
EVT PtrVT = StackPtr.getValueType();
2878+
auto &MF = DAG.getMachineFunction();
2879+
auto FrameIndex = cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex();
2880+
auto PtrInfo = MachinePointerInfo::getFixedStack(MF, FrameIndex);
2881+
2882+
// Both memory operands refer to the same fixed stack slot and are created
// with MemoryLocation::UnknownSize.
MachineMemOperand *StoreMMO = DAG.getMachineFunction().getMachineMemOperand(
2883+
PtrInfo, MachineMemOperand::MOStore, MemoryLocation::UnknownSize,
2884+
Alignment);
2885+
MachineMemOperand *LoadMMO = DAG.getMachineFunction().getMachineMemOperand(
2886+
PtrInfo, MachineMemOperand::MOLoad, MemoryLocation::UnknownSize,
2887+
Alignment);
2888+
2889+
// Element size in bytes.
// NOTE(review): the integer division yields 0 for element types narrower
// than 8 bits (for example i1), which would make the offset and the stride
// below zero - confirm such types cannot reach this function.
unsigned EltWidth = VT.getScalarSizeInBits() / 8;
2890+
// StorePtr = StackPtr + (EVL - 1) * EltWidth, i.e. the address of element
// EVL - 1 in the stack slot. EVL is zero-extended or truncated to the
// pointer type first.
SDValue NumElemMinus1 =
2891+
DAG.getNode(ISD::SUB, DL, PtrVT, DAG.getZExtOrTrunc(EVL, DL, PtrVT),
2892+
DAG.getConstant(1, DL, PtrVT));
2893+
SDValue StartOffset = DAG.getNode(ISD::MUL, DL, PtrVT, NumElemMinus1,
2894+
DAG.getConstant(EltWidth, DL, PtrVT));
2895+
SDValue StorePtr = DAG.getNode(ISD::ADD, DL, PtrVT, StackPtr, StartOffset);
2896+
// Stride of minus one element: lane i is written EltWidth * i bytes below
// StorePtr, so the first EVL lanes end up in reverse order starting at
// StackPtr.
SDValue Stride = DAG.getConstant(-(int64_t)EltWidth, DL, PtrVT);
2897+
2898+
// The store uses an all-true mask; the original Mask is applied only to the
// load below.
SDValue TrueMask = DAG.getBoolConstant(true, DL, Mask.getValueType(), VT);
2899+
SDValue Store = DAG.getStridedStoreVP(DAG.getEntryNode(), DL, Val, StorePtr,
2900+
DAG.getUNDEF(PtrVT), Stride, TrueMask,
2901+
EVL, MemVT, StoreMMO, ISD::UNINDEXED);
2902+
2903+
// Read the reversed data back from the start of the stack slot. The store
// node is passed as the chain operand of the load.
SDValue Load = DAG.getLoadVP(VT, DL, Store, StackPtr, Mask, EVL, LoadMMO);
2904+
2905+
// Cut the loaded vector into the two result halves; Hi starts at the
// minimum element count of LoVT.
auto [LoVT, HiVT] = DAG.GetSplitDestVTs(VT);
2906+
Lo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, LoVT, Load,
2907+
DAG.getVectorIdxConstant(0, DL));
2908+
Hi =
2909+
DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HiVT, Load,
2910+
DAG.getVectorIdxConstant(LoVT.getVectorMinNumElements(), DL));
2911+
}
2912+
28602913
void DAGTypeLegalizer::SplitVecRes_VECTOR_DEINTERLEAVE(SDNode *N) {
28612914

28622915
SDValue Op0Lo, Op0Hi, Op1Lo, Op1Hi;

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 132 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -662,7 +662,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
662662
ISD::VP_FP_TO_UINT, ISD::VP_SETCC, ISD::VP_SIGN_EXTEND,
663663
ISD::VP_ZERO_EXTEND, ISD::VP_TRUNCATE, ISD::VP_SMIN,
664664
ISD::VP_SMAX, ISD::VP_UMIN, ISD::VP_UMAX,
665-
ISD::VP_ABS};
665+
ISD::VP_ABS, ISD::EXPERIMENTAL_VP_REVERSE};
666666

667667
static const unsigned FloatingPointVPOps[] = {
668668
ISD::VP_FADD, ISD::VP_FSUB, ISD::VP_FMUL,
@@ -674,7 +674,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
674674
ISD::VP_SQRT, ISD::VP_FMINNUM, ISD::VP_FMAXNUM,
675675
ISD::VP_FCEIL, ISD::VP_FFLOOR, ISD::VP_FROUND,
676676
ISD::VP_FROUNDEVEN, ISD::VP_FCOPYSIGN, ISD::VP_FROUNDTOZERO,
677-
ISD::VP_FRINT, ISD::VP_FNEARBYINT, ISD::VP_IS_FPCLASS};
677+
ISD::VP_FRINT, ISD::VP_FNEARBYINT, ISD::VP_IS_FPCLASS,
678+
ISD::EXPERIMENTAL_VP_REVERSE};
678679

679680
static const unsigned IntegerVecReduceOps[] = {
680681
ISD::VECREDUCE_ADD, ISD::VECREDUCE_AND, ISD::VECREDUCE_OR,
@@ -759,6 +760,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
759760

760761
setOperationAction(ISD::VECTOR_REVERSE, VT, Custom);
761762

763+
setOperationAction(ISD::EXPERIMENTAL_VP_REVERSE, VT, Custom);
764+
762765
setOperationPromotedToType(
763766
ISD::VECTOR_SPLICE, VT,
764767
MVT::getVectorVT(MVT::i8, VT.getVectorElementCount()));
@@ -1129,6 +1132,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
11291132
setOperationAction({ISD::VP_FP_TO_SINT, ISD::VP_FP_TO_UINT,
11301133
ISD::VP_SETCC, ISD::VP_TRUNCATE},
11311134
VT, Custom);
1135+
1136+
setOperationAction(ISD::EXPERIMENTAL_VP_REVERSE, VT, Custom);
11321137
continue;
11331138
}
11341139

@@ -1383,7 +1388,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
13831388
setTargetDAGCombine({ISD::FCOPYSIGN, ISD::MGATHER, ISD::MSCATTER,
13841389
ISD::VP_GATHER, ISD::VP_SCATTER, ISD::SRA, ISD::SRL,
13851390
ISD::SHL, ISD::STORE, ISD::SPLAT_VECTOR,
1386-
ISD::BUILD_VECTOR, ISD::CONCAT_VECTORS});
1391+
ISD::BUILD_VECTOR, ISD::CONCAT_VECTORS,
1392+
ISD::EXPERIMENTAL_VP_REVERSE});
13871393
if (Subtarget.hasVendorXTHeadMemPair())
13881394
setTargetDAGCombine({ISD::LOAD, ISD::STORE});
13891395
if (Subtarget.useRVVForFixedLengthVectors())
@@ -6518,6 +6524,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
65186524
!Subtarget.hasVInstructionsF16()))
65196525
return SplitVPOp(Op, DAG);
65206526
return lowerVectorFTRUNC_FCEIL_FFLOOR_FROUND(Op, DAG, Subtarget);
6527+
case ISD::EXPERIMENTAL_VP_REVERSE:
6528+
return lowerVPReverseExperimental(Op, DAG);
65216529
}
65226530
}
65236531

@@ -10378,6 +10386,127 @@ SDValue RISCVTargetLowering::lowerVPFPIntConvOp(SDValue Op,
1037810386
return convertFromScalableVector(VT, Result, DAG, Subtarget);
1037910387
}
1038010388

10389+
// Custom lowering of ISD::EXPERIMENTAL_VP_REVERSE.
//
// Op has three operands: the vector to reverse (0), the mask (1) and the
// explicit vector length EVL (2).
//
// Outline:
//  - Fixed-length inputs are converted to their scalable container type and
//    the result is converted back at the end.
//  - i1 vectors are expanded to i8 vectors of ones and zeros before the
//    shuffle and compared against zero afterwards to get a mask back.
//  - The general path builds the index vector (EVL - 1) - vid and feeds it
//    to a gather node.
//  - For 8-bit elements with MaxVLMAX above 256 the indices are widened to
//    i16 and the ei16 gather opcode is used, except when the gather type is
//    already 8 * RVVBitsPerBlock wide; that case is split in half, each half
//    is fully reversed, the halves are concatenated in swapped order and the
//    result is slid down by VLMAX - EVL.
//
// Returns the lowered value, with type VT.
SDValue
10390+
RISCVTargetLowering::lowerVPReverseExperimental(SDValue Op,
10391+
SelectionDAG &DAG) const {
10392+
SDLoc DL(Op);
10393+
MVT VT = Op.getSimpleValueType();
10394+
MVT XLenVT = Subtarget.getXLenVT();
10395+
10396+
SDValue Op1 = Op.getOperand(0);
10397+
SDValue Mask = Op.getOperand(1);
10398+
SDValue EVL = Op.getOperand(2);
10399+
10400+
// Work on the scalable container type when the input is a fixed-length
// vector; both the data operand and the mask are converted.
MVT ContainerVT = VT;
10401+
if (VT.isFixedLengthVector()) {
10402+
ContainerVT = getContainerForFixedLengthVector(VT);
10403+
Op1 = convertToScalableVector(ContainerVT, Op1, DAG, Subtarget);
10404+
MVT MaskVT = getMaskTypeFor(ContainerVT);
10405+
Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
10406+
}
10407+
10408+
// GatherVT is the type the shuffle is performed in; IndicesVT is the integer
// type used for the gather indices.
MVT GatherVT = ContainerVT;
10409+
MVT IndicesVT = ContainerVT.changeVectorElementTypeToInteger();
10410+
// Check if we are working with mask vectors
10411+
bool IsMaskVector = ContainerVT.getVectorElementType() == MVT::i1;
10412+
if (IsMaskVector) {
10413+
GatherVT = IndicesVT = ContainerVT.changeVectorElementType(MVT::i8);
10414+
10415+
// Expand input operand
// Select between a splat of 1 and a splat of 0 using the i1 input as the
// condition, producing an i8 vector that can be shuffled.
10416+
SDValue SplatOne = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, IndicesVT,
10417+
DAG.getUNDEF(IndicesVT),
10418+
DAG.getConstant(1, DL, XLenVT), EVL);
10419+
SDValue SplatZero = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, IndicesVT,
10420+
DAG.getUNDEF(IndicesVT),
10421+
DAG.getConstant(0, DL, XLenVT), EVL);
10422+
Op1 = DAG.getNode(RISCVISD::VSELECT_VL, DL, IndicesVT, Op1, SplatOne,
10423+
SplatZero, EVL);
10424+
}
10425+
10426+
// Upper bound on VLMAX for GatherVT, computed from the subtarget's real
// maximum VLEN.
unsigned EltSize = GatherVT.getScalarSizeInBits();
10427+
unsigned MinSize = GatherVT.getSizeInBits().getKnownMinValue();
10428+
unsigned VectorBitsMax = Subtarget.getRealMaxVLen();
10429+
unsigned MaxVLMAX =
10430+
RISCVTargetLowering::computeVLMAX(VectorBitsMax, EltSize, MinSize);
10431+
10432+
unsigned GatherOpc = RISCVISD::VRGATHER_VV_VL;
10433+
// If this is SEW=8 and VLMAX is unknown or more than 256, we need
10434+
// to use vrgatherei16.vv.
10435+
// TODO: It's also possible to use vrgatherei16.vv for other types to
10436+
// decrease register width for the index calculation.
10437+
// NOTE: This code assumes VLMAX <= 65536 for LMUL=8 SEW=16.
10438+
if (MaxVLMAX > 256 && EltSize == 8) {
10439+
// If this is LMUL=8, we have to split before using vrgatherei16.vv.
10440+
// Split the vector in half and reverse each half using a full register
10441+
// reverse.
10442+
// Swap the halves and concatenate them.
10443+
// Slide the concatenated result by (VLMax - VL).
10444+
if (MinSize == (8 * RISCV::RVVBitsPerBlock)) {
10445+
auto [LoVT, HiVT] = DAG.GetSplitDestVTs(GatherVT);
10446+
auto [Lo, Hi] = DAG.SplitVector(Op1, DL);
10447+
10448+
SDValue LoRev = DAG.getNode(ISD::VECTOR_REVERSE, DL, LoVT, Lo);
10449+
SDValue HiRev = DAG.getNode(ISD::VECTOR_REVERSE, DL, HiVT, Hi);
10450+
10451+
// Reassemble the low and high pieces reversed.
10452+
// NOTE: this Result is unmasked (because we do not need masks for
10453+
// shuffles). If in the future this has to change, we can use a SELECT_VL
10454+
// between Result and UNDEF using the mask originally passed to VP_REVERSE
10455+
SDValue Result =
10456+
DAG.getNode(ISD::CONCAT_VECTORS, DL, GatherVT, HiRev, LoRev);
10457+
10458+
// Slide off any elements from past EVL that were reversed into the low
10459+
// elements.
// VLMax is vscale * MinElts; the slide amount is VLMax - EVL.
10460+
unsigned MinElts = GatherVT.getVectorMinNumElements();
10461+
SDValue VLMax = DAG.getNode(ISD::VSCALE, DL, XLenVT,
10462+
DAG.getConstant(MinElts, DL, XLenVT));
10463+
SDValue Diff = DAG.getNode(ISD::SUB, DL, XLenVT, VLMax, EVL);
10464+
10465+
Result = getVSlidedown(DAG, Subtarget, DL, GatherVT,
10466+
DAG.getUNDEF(GatherVT), Result, Diff, Mask, EVL);
10467+
10468+
if (IsMaskVector) {
10469+
// Truncate Result back to a mask vector
10470+
Result =
10471+
DAG.getNode(RISCVISD::SETCC_VL, DL, ContainerVT,
10472+
{Result, DAG.getConstant(0, DL, GatherVT),
10473+
DAG.getCondCode(ISD::SETNE),
10474+
DAG.getUNDEF(getMaskTypeFor(ContainerVT)), Mask, EVL});
10475+
}
10476+
10477+
// Scalable results are returned as is; fixed-length results are converted
// back from the container type.
if (!VT.isFixedLengthVector())
10478+
return Result;
10479+
return convertFromScalableVector(VT, Result, DAG, Subtarget);
10480+
}
10481+
10482+
// Just promote the int type to i16 which will double the LMUL.
10483+
IndicesVT = MVT::getVectorVT(MVT::i16, IndicesVT.getVectorElementCount());
10484+
GatherOpc = RISCVISD::VRGATHEREI16_VV_VL;
10485+
}
10486+
10487+
// Build the gather indices: VRSUB = splat(EVL - 1) - VID, so result lane i
// reads source lane (EVL - 1) - i.
SDValue VID = DAG.getNode(RISCVISD::VID_VL, DL, IndicesVT, Mask, EVL);
10488+
SDValue VecLen =
10489+
DAG.getNode(ISD::SUB, DL, XLenVT, EVL, DAG.getConstant(1, DL, XLenVT));
10490+
SDValue VecLenSplat = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, IndicesVT,
10491+
DAG.getUNDEF(IndicesVT), VecLen, EVL);
10492+
SDValue VRSUB = DAG.getNode(RISCVISD::SUB_VL, DL, IndicesVT, VecLenSplat, VID,
10493+
DAG.getUNDEF(IndicesVT), Mask, EVL);
10494+
SDValue Result = DAG.getNode(GatherOpc, DL, GatherVT, Op1, VRSUB,
10495+
DAG.getUNDEF(GatherVT), Mask, EVL);
10496+
10497+
if (IsMaskVector) {
10498+
// Truncate Result back to a mask vector
10499+
Result = DAG.getNode(
10500+
RISCVISD::SETCC_VL, DL, ContainerVT,
10501+
{Result, DAG.getConstant(0, DL, GatherVT), DAG.getCondCode(ISD::SETNE),
10502+
DAG.getUNDEF(getMaskTypeFor(ContainerVT)), Mask, EVL});
10503+
}
10504+
10505+
// Scalable results are returned as is; fixed-length results are converted
// back from the container type.
if (!VT.isFixedLengthVector())
10506+
return Result;
10507+
return convertFromScalableVector(VT, Result, DAG, Subtarget);
10508+
}
10509+
1038110510
SDValue RISCVTargetLowering::lowerLogicVPOp(SDValue Op,
1038210511
SelectionDAG &DAG) const {
1038310512
MVT VT = Op.getSimpleValueType();

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -901,6 +901,7 @@ class RISCVTargetLowering : public TargetLowering {
901901
SDValue lowerLogicVPOp(SDValue Op, SelectionDAG &DAG) const;
902902
SDValue lowerVPExtMaskOp(SDValue Op, SelectionDAG &DAG) const;
903903
SDValue lowerVPSetCCMaskOp(SDValue Op, SelectionDAG &DAG) const;
904+
SDValue lowerVPReverseExperimental(SDValue Op, SelectionDAG &DAG) const;
904905
SDValue lowerVPFPIntConvOp(SDValue Op, SelectionDAG &DAG) const;
905906
SDValue lowerVPStridedLoad(SDValue Op, SelectionDAG &DAG) const;
906907
SDValue lowerVPStridedStore(SDValue Op, SelectionDAG &DAG) const;

0 commit comments

Comments
 (0)