Skip to content

Commit 659db66

Browse files
committed
[RISCV] Add intrinsics for strided segment loads with fixed vectors
1 parent d2361e4 commit 659db66

File tree

3 files changed

+189
-50
lines changed

3 files changed

+189
-50
lines changed

llvm/include/llvm/IR/IntrinsicsRISCV.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1717,6 +1717,16 @@ let TargetPrefix = "riscv" in {
17171717
llvm_anyint_ty],
17181718
[NoCapture<ArgIndex<0>>, IntrReadMem]>;
17191719

1720+
// Masked strided segment load for fixed vectors.
// Returns nf vectors of the same type (the first result type is matched by
// the remaining nf-1 via LLVMMatchType<0>).
// Input: (pointer, offset, mask, vl)
//   - offset is the byte stride between consecutive segments
//   - mask has the same element count as the result vectors (i1 elements)
def int_riscv_sseg # nf # _load_mask
    : DefaultAttrsIntrinsic<!listconcat([llvm_anyvector_ty],
                                        !listsplat(LLVMMatchType<0>,
                                                   !add(nf, -1))),
                            [llvm_anyptr_ty, llvm_anyint_ty,
                             LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>,
                             llvm_anyint_ty],
                            [NoCapture<ArgIndex<0>>, IntrReadMem]>;
1729+
17201730
// Input: (<stored values>..., pointer, mask, vl)
17211731
def int_riscv_seg # nf # _store_mask
17221732
: DefaultAttrsIntrinsic<[],

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 107 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1819,6 +1819,13 @@ bool RISCVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
18191819
case Intrinsic::riscv_seg6_load_mask:
18201820
case Intrinsic::riscv_seg7_load_mask:
18211821
case Intrinsic::riscv_seg8_load_mask:
1822+
case Intrinsic::riscv_sseg2_load_mask:
1823+
case Intrinsic::riscv_sseg3_load_mask:
1824+
case Intrinsic::riscv_sseg4_load_mask:
1825+
case Intrinsic::riscv_sseg5_load_mask:
1826+
case Intrinsic::riscv_sseg6_load_mask:
1827+
case Intrinsic::riscv_sseg7_load_mask:
1828+
case Intrinsic::riscv_sseg8_load_mask:
18221829
return SetRVVLoadStoreInfo(/*PtrOp*/ 0, /*IsStore*/ false,
18231830
/*IsUnitStrided*/ false, /*UsePtrVal*/ true);
18241831
case Intrinsic::riscv_seg2_store_mask:
@@ -10959,6 +10966,97 @@ static inline SDValue getVCIXISDNodeVOID(SDValue &Op, SelectionDAG &DAG,
1095910966
return DAG.getNode(Type, SDLoc(Op), Op.getValueType(), Operands);
1096010967
}
1096110968

10969+
/// Lower a fixed-vector segment-load intrinsic (riscv.segN.load.mask, or the
/// strided riscv.ssegN.load.mask variant) to the corresponding masked
/// vlsegN/vlssegN intrinsic operating on a RISC-V vector tuple type, then
/// extract each of the NF fields back out as a fixed-length vector.
///
/// \p IntNo identifies which (s)segN_load_mask intrinsic \p Op is; the
/// returned merge node carries NF result vectors followed by the chain.
static SDValue
convertFixedVectorSegLoadIntrinsics(unsigned IntNo, SDValue Op,
                                    const RISCVSubtarget &Subtarget,
                                    SelectionDAG &DAG) {
  bool IsStrided;
  switch (IntNo) {
  case Intrinsic::riscv_seg2_load_mask:
  case Intrinsic::riscv_seg3_load_mask:
  case Intrinsic::riscv_seg4_load_mask:
  case Intrinsic::riscv_seg5_load_mask:
  case Intrinsic::riscv_seg6_load_mask:
  case Intrinsic::riscv_seg7_load_mask:
  case Intrinsic::riscv_seg8_load_mask:
    IsStrided = false;
    break;
  case Intrinsic::riscv_sseg2_load_mask:
  case Intrinsic::riscv_sseg3_load_mask:
  case Intrinsic::riscv_sseg4_load_mask:
  case Intrinsic::riscv_sseg5_load_mask:
  case Intrinsic::riscv_sseg6_load_mask:
  case Intrinsic::riscv_sseg7_load_mask:
  case Intrinsic::riscv_sseg8_load_mask:
    IsStrided = true;
    break;
  default:
    llvm_unreachable("unexpected intrinsic ID");
  }

  // Target intrinsics for each NF in [2, 8], indexed by NF - 2.
  static const Intrinsic::ID VlsegInts[7] = {
      Intrinsic::riscv_vlseg2_mask, Intrinsic::riscv_vlseg3_mask,
      Intrinsic::riscv_vlseg4_mask, Intrinsic::riscv_vlseg5_mask,
      Intrinsic::riscv_vlseg6_mask, Intrinsic::riscv_vlseg7_mask,
      Intrinsic::riscv_vlseg8_mask};
  static const Intrinsic::ID VlssegInts[7] = {
      Intrinsic::riscv_vlsseg2_mask, Intrinsic::riscv_vlsseg3_mask,
      Intrinsic::riscv_vlsseg4_mask, Intrinsic::riscv_vlsseg5_mask,
      Intrinsic::riscv_vlsseg6_mask, Intrinsic::riscv_vlsseg7_mask,
      Intrinsic::riscv_vlsseg8_mask};

  SDLoc DL(Op);
  // The node's results are the NF segment vectors plus the chain.
  unsigned NF = Op->getNumValues() - 1;
  assert(NF >= 2 && NF <= 8 && "Unexpected seg number");
  MVT XLenVT = Subtarget.getXLenVT();
  MVT VT = Op->getSimpleValueType(0);
  MVT ContainerVT = ::getContainerForFixedLengthVector(DAG, VT, Subtarget);
  // Minimum size in bits of the tuple holding all NF scalable containers.
  unsigned Sz = NF * ContainerVT.getVectorMinNumElements() *
                ContainerVT.getScalarSizeInBits();
  EVT VecTupTy = MVT::getRISCVVectorTupleVT(Sz, NF);

  // Operands: (chain, int_id, pointer, mask, vl) or
  // (chain, int_id, pointer, stride, mask, vl).
  // Mask and VL are taken from the back so both layouts are handled.
  SDValue VL = Op.getOperand(Op.getNumOperands() - 1);
  SDValue Mask = Op.getOperand(Op.getNumOperands() - 2);
  MVT MaskVT = Mask.getSimpleValueType();
  MVT MaskContainerVT =
      ::getContainerForFixedLengthVector(DAG, MaskVT, Subtarget);
  Mask = convertToScalableVector(MaskContainerVT, Mask, DAG, Subtarget);

  SDValue IntID = DAG.getTargetConstant(
      IsStrided ? VlssegInts[NF - 2] : VlsegInts[NF - 2], DL, XLenVT);
  auto *Load = cast<MemIntrinsicSDNode>(Op);

  SDVTList VTs = DAG.getVTList({VecTupTy, MVT::Other});
  SmallVector<SDValue, 9> Ops = {
      Load->getChain(),
      IntID,
      DAG.getUNDEF(VecTupTy), // Passthru: no merge value, tail is agnostic.
      Op.getOperand(2),       // Pointer.
      Mask,
      VL,
      DAG.getTargetConstant(
          RISCVVType::TAIL_AGNOSTIC | RISCVVType::MASK_AGNOSTIC, DL, XLenVT),
      DAG.getTargetConstant(Log2_64(VT.getScalarSizeInBits()), DL, XLenVT)};
  // Insert the stride operand between the pointer and the mask.
  if (IsStrided)
    Ops.insert(std::next(Ops.begin(), 4), Op.getOperand(3));

  SDValue Result =
      DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops,
                              Load->getMemoryVT(), Load->getMemOperand());
  // Pull each field out of the tuple and shrink it back to a fixed vector.
  SmallVector<SDValue, 9> Results;
  for (unsigned RetIdx = 0; RetIdx < NF; RetIdx++) {
    SDValue SubVec = DAG.getNode(RISCVISD::TUPLE_EXTRACT, DL, ContainerVT,
                                 Result.getValue(0),
                                 DAG.getTargetConstant(RetIdx, DL, MVT::i32));
    Results.push_back(convertFromScalableVector(VT, SubVec, DAG, Subtarget));
  }
  Results.push_back(Result.getValue(1));
  return DAG.getMergeValues(Results, DL);
}
11059+
1096211060
SDValue RISCVTargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
1096311061
SelectionDAG &DAG) const {
1096411062
unsigned IntNo = Op.getConstantOperandVal(1);
@@ -10971,57 +11069,16 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
1097111069
case Intrinsic::riscv_seg5_load_mask:
1097211070
case Intrinsic::riscv_seg6_load_mask:
1097311071
case Intrinsic::riscv_seg7_load_mask:
10974-
case Intrinsic::riscv_seg8_load_mask: {
10975-
SDLoc DL(Op);
10976-
static const Intrinsic::ID VlsegInts[7] = {
10977-
Intrinsic::riscv_vlseg2_mask, Intrinsic::riscv_vlseg3_mask,
10978-
Intrinsic::riscv_vlseg4_mask, Intrinsic::riscv_vlseg5_mask,
10979-
Intrinsic::riscv_vlseg6_mask, Intrinsic::riscv_vlseg7_mask,
10980-
Intrinsic::riscv_vlseg8_mask};
10981-
unsigned NF = Op->getNumValues() - 1;
10982-
assert(NF >= 2 && NF <= 8 && "Unexpected seg number");
10983-
MVT XLenVT = Subtarget.getXLenVT();
10984-
MVT VT = Op->getSimpleValueType(0);
10985-
MVT ContainerVT = getContainerForFixedLengthVector(VT);
10986-
unsigned Sz = NF * ContainerVT.getVectorMinNumElements() *
10987-
ContainerVT.getScalarSizeInBits();
10988-
EVT VecTupTy = MVT::getRISCVVectorTupleVT(Sz, NF);
10989-
10990-
// Operands: (chain, int_id, pointer, mask, vl)
10991-
SDValue VL = Op.getOperand(Op.getNumOperands() - 1);
10992-
SDValue Mask = Op.getOperand(3);
10993-
MVT MaskVT = Mask.getSimpleValueType();
10994-
MVT MaskContainerVT =
10995-
::getContainerForFixedLengthVector(DAG, MaskVT, Subtarget);
10996-
Mask = convertToScalableVector(MaskContainerVT, Mask, DAG, Subtarget);
10997-
10998-
SDValue IntID = DAG.getTargetConstant(VlsegInts[NF - 2], DL, XLenVT);
10999-
auto *Load = cast<MemIntrinsicSDNode>(Op);
11072+
case Intrinsic::riscv_seg8_load_mask:
11073+
case Intrinsic::riscv_sseg2_load_mask:
11074+
case Intrinsic::riscv_sseg3_load_mask:
11075+
case Intrinsic::riscv_sseg4_load_mask:
11076+
case Intrinsic::riscv_sseg5_load_mask:
11077+
case Intrinsic::riscv_sseg6_load_mask:
11078+
case Intrinsic::riscv_sseg7_load_mask:
11079+
case Intrinsic::riscv_sseg8_load_mask:
11080+
return convertFixedVectorSegLoadIntrinsics(IntNo, Op, Subtarget, DAG);
1100011081

11001-
SDVTList VTs = DAG.getVTList({VecTupTy, MVT::Other});
11002-
SDValue Ops[] = {
11003-
Load->getChain(),
11004-
IntID,
11005-
DAG.getUNDEF(VecTupTy),
11006-
Op.getOperand(2),
11007-
Mask,
11008-
VL,
11009-
DAG.getTargetConstant(
11010-
RISCVVType::TAIL_AGNOSTIC | RISCVVType::MASK_AGNOSTIC, DL, XLenVT),
11011-
DAG.getTargetConstant(Log2_64(VT.getScalarSizeInBits()), DL, XLenVT)};
11012-
SDValue Result =
11013-
DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops,
11014-
Load->getMemoryVT(), Load->getMemOperand());
11015-
SmallVector<SDValue, 9> Results;
11016-
for (unsigned int RetIdx = 0; RetIdx < NF; RetIdx++) {
11017-
SDValue SubVec = DAG.getNode(RISCVISD::TUPLE_EXTRACT, DL, ContainerVT,
11018-
Result.getValue(0),
11019-
DAG.getTargetConstant(RetIdx, DL, MVT::i32));
11020-
Results.push_back(convertFromScalableVector(VT, SubVec, DAG, Subtarget));
11021-
}
11022-
Results.push_back(Result.getValue(1));
11023-
return DAG.getMergeValues(Results, DL);
11024-
}
1102511082
case Intrinsic::riscv_sf_vc_v_x_se:
1102611083
return getVCIXISDNodeWCHAIN(Op, DAG, RISCVISD::SF_VC_V_X_SE);
1102711084
case Intrinsic::riscv_sf_vc_v_i_se:
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
2+
; RUN: llc -mtriple riscv64 -mattr=+zve64x,+zvl128b < %s | FileCheck %s
3+
4+
; Strided 2-field segment load: all-ones mask, VL = 8, stride in a1.
define {<8 x i8>, <8 x i8>} @load_factor2(ptr %ptr, i64 %stride) {
; CHECK-LABEL: load_factor2:
; CHECK:       # %bb.0:
; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
; CHECK-NEXT:    vlsseg2e8.v v8, (a0), a1
; CHECK-NEXT:    ret
  %1 = call { <8 x i8>, <8 x i8> } @llvm.riscv.sseg2.load.mask.v8i8.i64.i64(ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
  ret {<8 x i8>, <8 x i8>} %1
}
13+
14+
; Strided 3-field segment load: all-ones mask, VL = 8, stride in a1.
define {<8 x i8>, <8 x i8>, <8 x i8>} @load_factor3(ptr %ptr, i64 %stride) {
; CHECK-LABEL: load_factor3:
; CHECK:       # %bb.0:
; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
; CHECK-NEXT:    vlsseg3e8.v v8, (a0), a1
; CHECK-NEXT:    ret
  %1 = call { <8 x i8>, <8 x i8>, <8 x i8> } @llvm.riscv.sseg3.load.mask.v8i8.i64.i64(ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
  ret { <8 x i8>, <8 x i8>, <8 x i8> } %1
}
23+
24+
; Strided 4-field segment load: all-ones mask, VL = 8, stride in a1.
define {<8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>} @load_factor4(ptr %ptr, i64 %stride) {
; CHECK-LABEL: load_factor4:
; CHECK:       # %bb.0:
; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
; CHECK-NEXT:    vlsseg4e8.v v8, (a0), a1
; CHECK-NEXT:    ret
  %1 = call { <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8> } @llvm.riscv.sseg4.load.mask.v8i8.i64.i64(ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
  ret { <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8> } %1
}
33+
34+
; Strided 5-field segment load: all-ones mask, VL = 8, stride in a1.
define {<8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>} @load_factor5(ptr %ptr, i64 %stride) {
; CHECK-LABEL: load_factor5:
; CHECK:       # %bb.0:
; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
; CHECK-NEXT:    vlsseg5e8.v v8, (a0), a1
; CHECK-NEXT:    ret
  %1 = call { <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8> } @llvm.riscv.sseg5.load.mask.v8i8.i64.i64(ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
  ret { <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8> } %1
}
43+
44+
; Strided 6-field segment load: all-ones mask, VL = 8, stride in a1.
define {<8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>} @load_factor6(ptr %ptr, i64 %stride) {
; CHECK-LABEL: load_factor6:
; CHECK:       # %bb.0:
; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
; CHECK-NEXT:    vlsseg6e8.v v8, (a0), a1
; CHECK-NEXT:    ret
  %1 = call { <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8> } @llvm.riscv.sseg6.load.mask.v8i8.i64.i64(ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
  ret { <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8> } %1
}
53+
54+
; Strided 7-field segment load: all-ones mask, VL = 8, stride in a1.
define {<8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>} @load_factor7(ptr %ptr, i64 %stride) {
; CHECK-LABEL: load_factor7:
; CHECK:       # %bb.0:
; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
; CHECK-NEXT:    vlsseg7e8.v v8, (a0), a1
; CHECK-NEXT:    ret
  %1 = call { <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8> } @llvm.riscv.sseg7.load.mask.v8i8.i64.i64(ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
  ret { <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8> } %1
}
63+
64+
; Strided 8-field segment load: all-ones mask, VL = 8, stride in a1.
define {<8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>} @load_factor8(ptr %ptr, i64 %stride) {
; CHECK-LABEL: load_factor8:
; CHECK:       # %bb.0:
; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
; CHECK-NEXT:    vlsseg8e8.v v8, (a0), a1
; CHECK-NEXT:    ret
  %1 = call { <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8> } @llvm.riscv.sseg8.load.mask.v8i8.i64.i64(ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
  ret { <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8> } %1
}

0 commit comments

Comments
 (0)