Skip to content

Commit bf4b7f1

Browse files
committed
[RISCV] Add intrinsics for strided segment stores with fixed vectors
1 parent 6e94dc3 commit bf4b7f1

File tree

3 files changed

+192
-50
lines changed

3 files changed

+192
-50
lines changed

llvm/include/llvm/IR/IntrinsicsRISCV.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1736,6 +1736,17 @@ let TargetPrefix = "riscv" in {
17361736
[llvm_anyptr_ty, LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>,
17371737
llvm_anyint_ty]),
17381738
[NoCapture<ArgIndex<nf>>, IntrWriteMem]>;
1739+
1740+
// Input: (<stored values>..., pointer, stride, mask, vl)
1741+
def int_riscv_sseg # nf # _store_mask
1742+
: DefaultAttrsIntrinsic<[],
1743+
!listconcat([llvm_anyvector_ty],
1744+
!listsplat(LLVMMatchType<0>,
1745+
!add(nf, -1)),
1746+
[llvm_anyptr_ty, llvm_anyint_ty,
1747+
LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>,
1748+
llvm_anyint_ty]),
1749+
[NoCapture<ArgIndex<nf>>, IntrWriteMem]>;
17391750
}
17401751

17411752
} // TargetPrefix = "riscv"

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 109 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1839,6 +1839,17 @@ bool RISCVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
18391839
return SetRVVLoadStoreInfo(/*PtrOp*/ I.arg_size() - 3,
18401840
/*IsStore*/ true,
18411841
/*IsUnitStrided*/ false, /*UsePtrVal*/ true);
1842+
case Intrinsic::riscv_sseg2_store_mask:
1843+
case Intrinsic::riscv_sseg3_store_mask:
1844+
case Intrinsic::riscv_sseg4_store_mask:
1845+
case Intrinsic::riscv_sseg5_store_mask:
1846+
case Intrinsic::riscv_sseg6_store_mask:
1847+
case Intrinsic::riscv_sseg7_store_mask:
1848+
case Intrinsic::riscv_sseg8_store_mask:
1849+
// Operands are (vec, ..., vec, ptr, stride, mask, vl)
1850+
return SetRVVLoadStoreInfo(/*PtrOp*/ I.arg_size() - 4,
1851+
/*IsStore*/ true,
1852+
/*IsUnitStrided*/ false, /*UsePtrVal*/ true);
18421853
case Intrinsic::riscv_vlm:
18431854
return SetRVVLoadStoreInfo(/*PtrOp*/ 0,
18441855
/*IsStore*/ false,
@@ -11077,69 +11088,117 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
1107711088
return lowerVectorIntrinsicScalars(Op, DAG, Subtarget);
1107811089
}
1107911090

11080-
SDValue RISCVTargetLowering::LowerINTRINSIC_VOID(SDValue Op,
11081-
SelectionDAG &DAG) const {
11082-
unsigned IntNo = Op.getConstantOperandVal(1);
11091+
static SDValue
11092+
lowerFixedVectorSegStoreIntrinsics(unsigned IntNo, SDValue Op,
11093+
const RISCVSubtarget &Subtarget,
11094+
SelectionDAG &DAG) {
11095+
bool IsStrided;
1108311096
switch (IntNo) {
11084-
default:
11085-
break;
1108611097
case Intrinsic::riscv_seg2_store_mask:
1108711098
case Intrinsic::riscv_seg3_store_mask:
1108811099
case Intrinsic::riscv_seg4_store_mask:
1108911100
case Intrinsic::riscv_seg5_store_mask:
1109011101
case Intrinsic::riscv_seg6_store_mask:
1109111102
case Intrinsic::riscv_seg7_store_mask:
11092-
case Intrinsic::riscv_seg8_store_mask: {
11093-
SDLoc DL(Op);
11094-
static const Intrinsic::ID VssegInts[] = {
11095-
Intrinsic::riscv_vsseg2_mask, Intrinsic::riscv_vsseg3_mask,
11096-
Intrinsic::riscv_vsseg4_mask, Intrinsic::riscv_vsseg5_mask,
11097-
Intrinsic::riscv_vsseg6_mask, Intrinsic::riscv_vsseg7_mask,
11098-
Intrinsic::riscv_vsseg8_mask};
11103+
case Intrinsic::riscv_seg8_store_mask:
11104+
IsStrided = false;
11105+
break;
11106+
case Intrinsic::riscv_sseg2_store_mask:
11107+
case Intrinsic::riscv_sseg3_store_mask:
11108+
case Intrinsic::riscv_sseg4_store_mask:
11109+
case Intrinsic::riscv_sseg5_store_mask:
11110+
case Intrinsic::riscv_sseg6_store_mask:
11111+
case Intrinsic::riscv_sseg7_store_mask:
11112+
case Intrinsic::riscv_sseg8_store_mask:
11113+
IsStrided = true;
11114+
break;
11115+
default:
11116+
llvm_unreachable("unexpected intrinsic ID");
11117+
}
1109911118

11100-
// Operands: (chain, int_id, vec*, ptr, mask, vl)
11101-
unsigned NF = Op->getNumOperands() - 5;
11102-
assert(NF >= 2 && NF <= 8 && "Unexpected seg number");
11103-
MVT XLenVT = Subtarget.getXLenVT();
11104-
MVT VT = Op->getOperand(2).getSimpleValueType();
11105-
MVT ContainerVT = getContainerForFixedLengthVector(VT);
11106-
unsigned Sz = NF * ContainerVT.getVectorMinNumElements() *
11107-
ContainerVT.getScalarSizeInBits();
11108-
EVT VecTupTy = MVT::getRISCVVectorTupleVT(Sz, NF);
11119+
SDLoc DL(Op);
11120+
static const Intrinsic::ID VssegInts[] = {
11121+
Intrinsic::riscv_vsseg2_mask, Intrinsic::riscv_vsseg3_mask,
11122+
Intrinsic::riscv_vsseg4_mask, Intrinsic::riscv_vsseg5_mask,
11123+
Intrinsic::riscv_vsseg6_mask, Intrinsic::riscv_vsseg7_mask,
11124+
Intrinsic::riscv_vsseg8_mask};
11125+
static const Intrinsic::ID VsssegInts[] = {
11126+
Intrinsic::riscv_vssseg2_mask, Intrinsic::riscv_vssseg3_mask,
11127+
Intrinsic::riscv_vssseg4_mask, Intrinsic::riscv_vssseg5_mask,
11128+
Intrinsic::riscv_vssseg6_mask, Intrinsic::riscv_vssseg7_mask,
11129+
Intrinsic::riscv_vssseg8_mask};
11130+
11131+
// Operands: (chain, int_id, vec*, ptr, mask, vl) or
11132+
// (chain, int_id, vec*, ptr, stride, mask, vl)
11133+
unsigned NF = Op->getNumOperands() - (IsStrided ? 6 : 5);
11134+
assert(NF >= 2 && NF <= 8 && "Unexpected seg number");
11135+
MVT XLenVT = Subtarget.getXLenVT();
11136+
MVT VT = Op->getOperand(2).getSimpleValueType();
11137+
MVT ContainerVT = ::getContainerForFixedLengthVector(DAG, VT, Subtarget);
11138+
unsigned Sz = NF * ContainerVT.getVectorMinNumElements() *
11139+
ContainerVT.getScalarSizeInBits();
11140+
EVT VecTupTy = MVT::getRISCVVectorTupleVT(Sz, NF);
11141+
11142+
SDValue VL = Op.getOperand(Op.getNumOperands() - 1);
11143+
SDValue Mask = Op.getOperand(Op.getNumOperands() - 2);
11144+
MVT MaskVT = Mask.getSimpleValueType();
11145+
MVT MaskContainerVT =
11146+
::getContainerForFixedLengthVector(DAG, MaskVT, Subtarget);
11147+
Mask = convertToScalableVector(MaskContainerVT, Mask, DAG, Subtarget);
1110911148

11110-
SDValue VL = Op.getOperand(Op.getNumOperands() - 1);
11111-
SDValue Mask = Op.getOperand(Op.getNumOperands() - 2);
11112-
MVT MaskVT = Mask.getSimpleValueType();
11113-
MVT MaskContainerVT =
11114-
::getContainerForFixedLengthVector(DAG, MaskVT, Subtarget);
11115-
Mask = convertToScalableVector(MaskContainerVT, Mask, DAG, Subtarget);
11149+
SDValue IntID = DAG.getTargetConstant(
11150+
IsStrided ? VsssegInts[NF - 2] : VssegInts[NF - 2], DL, XLenVT);
11151+
SDValue Ptr = Op->getOperand(NF + 2);
1111611152

11117-
SDValue IntID = DAG.getTargetConstant(VssegInts[NF - 2], DL, XLenVT);
11118-
SDValue Ptr = Op->getOperand(NF + 2);
11153+
auto *FixedIntrinsic = cast<MemIntrinsicSDNode>(Op);
1111911154

11120-
auto *FixedIntrinsic = cast<MemIntrinsicSDNode>(Op);
11155+
SDValue StoredVal = DAG.getUNDEF(VecTupTy);
11156+
for (unsigned i = 0; i < NF; i++)
11157+
StoredVal = DAG.getNode(
11158+
RISCVISD::TUPLE_INSERT, DL, VecTupTy, StoredVal,
11159+
convertToScalableVector(ContainerVT, FixedIntrinsic->getOperand(2 + i),
11160+
DAG, Subtarget),
11161+
DAG.getTargetConstant(i, DL, MVT::i32));
1112111162

11122-
SDValue StoredVal = DAG.getUNDEF(VecTupTy);
11123-
for (unsigned i = 0; i < NF; i++)
11124-
StoredVal = DAG.getNode(
11125-
RISCVISD::TUPLE_INSERT, DL, VecTupTy, StoredVal,
11126-
convertToScalableVector(
11127-
ContainerVT, FixedIntrinsic->getOperand(2 + i), DAG, Subtarget),
11128-
DAG.getTargetConstant(i, DL, MVT::i32));
11163+
SmallVector<SDValue, 10> Ops = {
11164+
FixedIntrinsic->getChain(),
11165+
IntID,
11166+
StoredVal,
11167+
Ptr,
11168+
Mask,
11169+
VL,
11170+
DAG.getTargetConstant(Log2_64(VT.getScalarSizeInBits()), DL, XLenVT)};
11171+
if (IsStrided)
11172+
Ops.insert(std::next(Ops.begin(), 4),
11173+
Op.getOperand(Op.getNumOperands() - 3));
11174+
11175+
return DAG.getMemIntrinsicNode(
11176+
ISD::INTRINSIC_VOID, DL, DAG.getVTList(MVT::Other), Ops,
11177+
FixedIntrinsic->getMemoryVT(), FixedIntrinsic->getMemOperand());
11178+
}
11179+
11180+
SDValue RISCVTargetLowering::LowerINTRINSIC_VOID(SDValue Op,
11181+
SelectionDAG &DAG) const {
11182+
unsigned IntNo = Op.getConstantOperandVal(1);
11183+
switch (IntNo) {
11184+
default:
11185+
break;
11186+
case Intrinsic::riscv_seg2_store_mask:
11187+
case Intrinsic::riscv_seg3_store_mask:
11188+
case Intrinsic::riscv_seg4_store_mask:
11189+
case Intrinsic::riscv_seg5_store_mask:
11190+
case Intrinsic::riscv_seg6_store_mask:
11191+
case Intrinsic::riscv_seg7_store_mask:
11192+
case Intrinsic::riscv_seg8_store_mask:
11193+
case Intrinsic::riscv_sseg2_store_mask:
11194+
case Intrinsic::riscv_sseg3_store_mask:
11195+
case Intrinsic::riscv_sseg4_store_mask:
11196+
case Intrinsic::riscv_sseg5_store_mask:
11197+
case Intrinsic::riscv_sseg6_store_mask:
11198+
case Intrinsic::riscv_sseg7_store_mask:
11199+
case Intrinsic::riscv_sseg8_store_mask:
11200+
return lowerFixedVectorSegStoreIntrinsics(IntNo, Op, Subtarget, DAG);
1112911201

11130-
SDValue Ops[] = {
11131-
FixedIntrinsic->getChain(),
11132-
IntID,
11133-
StoredVal,
11134-
Ptr,
11135-
Mask,
11136-
VL,
11137-
DAG.getTargetConstant(Log2_64(VT.getScalarSizeInBits()), DL, XLenVT)};
11138-
11139-
return DAG.getMemIntrinsicNode(
11140-
ISD::INTRINSIC_VOID, DL, DAG.getVTList(MVT::Other), Ops,
11141-
FixedIntrinsic->getMemoryVT(), FixedIntrinsic->getMemOperand());
11142-
}
1114311202
case Intrinsic::riscv_sf_vc_xv_se:
1114411203
return getVCIXISDNodeVOID(Op, DAG, RISCVISD::SF_VC_XV_SE);
1114511204
case Intrinsic::riscv_sf_vc_iv_se:
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s
3+
4+
define void @store_factor2(<8 x i8> %v0, <8 x i8> %v1, ptr %ptr, i64 %stride) {
5+
; CHECK-LABEL: store_factor2:
6+
; CHECK: # %bb.0:
7+
; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
8+
; CHECK-NEXT: vssseg2e8.v v8, (a0), a1
9+
; CHECK-NEXT: ret
10+
call void @llvm.riscv.sseg2.store.mask.v8i8.i64.i64(<8 x i8> %v0, <8 x i8> %v1, ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
11+
ret void
12+
}
13+
14+
define void @store_factor3(<8 x i8> %v0, <8 x i8> %v1, <8 x i8> %v2, ptr %ptr, i64 %stride) {
15+
; CHECK-LABEL: store_factor3:
16+
; CHECK: # %bb.0:
17+
; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
18+
; CHECK-NEXT: vssseg3e8.v v8, (a0), a1
19+
; CHECK-NEXT: ret
20+
call void @llvm.riscv.sseg3.store.mask.v8i8.i64.i64(<8 x i8> %v0, <8 x i8> %v1, <8 x i8> %v2, ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
21+
ret void
22+
}
23+
24+
define void @store_factor4(<8 x i8> %v0, <8 x i8> %v1, <8 x i8> %v2, <8 x i8> %v3, ptr %ptr, i64 %stride) {
25+
; CHECK-LABEL: store_factor4:
26+
; CHECK: # %bb.0:
27+
; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
28+
; CHECK-NEXT: vssseg4e8.v v8, (a0), a1
29+
; CHECK-NEXT: ret
30+
call void @llvm.riscv.sseg4.store.mask.v8i8.i64.i64(<8 x i8> %v0, <8 x i8> %v1, <8 x i8> %v2, <8 x i8> %v3, ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
31+
ret void
32+
}
33+
34+
define void @store_factor5(<8 x i8> %v0, <8 x i8> %v1, <8 x i8> %v2, <8 x i8> %v3, <8 x i8> %v4, ptr %ptr, i64 %stride) {
35+
; CHECK-LABEL: store_factor5:
36+
; CHECK: # %bb.0:
37+
; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
38+
; CHECK-NEXT: vssseg5e8.v v8, (a0), a1
39+
; CHECK-NEXT: ret
40+
call void @llvm.riscv.sseg5.store.mask.v8i8.i64.i64(<8 x i8> %v0, <8 x i8> %v1, <8 x i8> %v2, <8 x i8> %v3, <8 x i8> %v4, ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
41+
ret void
42+
}
43+
44+
define void @store_factor6(<8 x i8> %v0, <8 x i8> %v1, <8 x i8> %v2, <8 x i8> %v3, <8 x i8> %v4, <8 x i8> %v5, ptr %ptr, i64 %stride) {
45+
; CHECK-LABEL: store_factor6:
46+
; CHECK: # %bb.0:
47+
; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
48+
; CHECK-NEXT: vssseg6e8.v v8, (a0), a1
49+
; CHECK-NEXT: ret
50+
call void @llvm.riscv.sseg6.store.mask.v8i8.i64.i64(<8 x i8> %v0, <8 x i8> %v1, <8 x i8> %v2, <8 x i8> %v3, <8 x i8> %v4, <8 x i8> %v5, ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
51+
ret void
52+
}
53+
54+
define void @store_factor7(<8 x i8> %v0, <8 x i8> %v1, <8 x i8> %v2, <8 x i8> %v3, <8 x i8> %v4, <8 x i8> %v5, <8 x i8> %v6, ptr %ptr, i64 %stride) {
55+
; CHECK-LABEL: store_factor7:
56+
; CHECK: # %bb.0:
57+
; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
58+
; CHECK-NEXT: vssseg7e8.v v8, (a0), a1
59+
; CHECK-NEXT: ret
60+
call void @llvm.riscv.sseg7.store.mask.v8i8.i64.i64(<8 x i8> %v0, <8 x i8> %v1, <8 x i8> %v2, <8 x i8> %v3, <8 x i8> %v4, <8 x i8> %v5, <8 x i8> %v6, ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
61+
ret void
62+
}
63+
64+
define void @store_factor8(<8 x i8> %v0, <8 x i8> %v1, <8 x i8> %v2, <8 x i8> %v3, <8 x i8> %v4, <8 x i8> %v5, <8 x i8> %v6, <8 x i8> %v7, ptr %ptr, i64 %stride) {
65+
; CHECK-LABEL: store_factor8:
66+
; CHECK: # %bb.0:
67+
; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
68+
; CHECK-NEXT: vssseg8e8.v v8, (a0), a1
69+
; CHECK-NEXT: ret
70+
call void @llvm.riscv.sseg8.store.mask.v8i8.i64.i64(<8 x i8> %v0, <8 x i8> %v1, <8 x i8> %v2, <8 x i8> %v3, <8 x i8> %v4, <8 x i8> %v5, <8 x i8> %v6, <8 x i8> %v7, ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
71+
ret void
72+
}

0 commit comments

Comments
 (0)