Skip to content

Commit 3e5fafd

Browse files
authored
[RISCV][llvm] Select splat_vector(constant) with PLI (#168204)
The default DAG combiner combines a BUILD_VECTOR with identical elements into a SPLAT_VECTOR, so we can just map a constant splat to PLI when possible.
1 parent fde2aad commit 3e5fafd

File tree

5 files changed

+55
-41
lines changed

5 files changed

+55
-41
lines changed

llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ void RISCVDAGToDAGISel::PreprocessISelDAG() {
5151
SDValue Result;
5252
switch (N->getOpcode()) {
5353
case ISD::SPLAT_VECTOR: {
54+
if (Subtarget->enablePExtCodeGen())
55+
break;
5456
// Convert integer SPLAT_VECTOR to VMV_V_X_VL and floating-point
5557
// SPLAT_VECTOR to VFMV_V_F_VL to reduce isel burden.
5658
MVT VT = N->getSimpleValueType(0);

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 1 addition & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
526526
setOperationAction(ISD::SSUBSAT, VTs, Legal);
527527
setOperationAction({ISD::AVGFLOORS, ISD::AVGFLOORU}, VTs, Legal);
528528
setOperationAction({ISD::ABDS, ISD::ABDU}, VTs, Legal);
529-
setOperationAction(ISD::BUILD_VECTOR, VTs, Custom);
529+
setOperationAction(ISD::SPLAT_VECTOR, VTs, Legal);
530530
setOperationAction(ISD::BITCAST, VTs, Custom);
531531
setOperationAction(ISD::EXTRACT_VECTOR_ELT, VTs, Custom);
532532
}
@@ -4433,37 +4433,6 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG,
44334433
MVT XLenVT = Subtarget.getXLenVT();
44344434

44354435
SDLoc DL(Op);
4436-
// Handle P extension packed vector BUILD_VECTOR with PLI for splat constants
4437-
if (Subtarget.enablePExtCodeGen()) {
4438-
bool IsPExtVector =
4439-
(VT == MVT::v2i16 || VT == MVT::v4i8) ||
4440-
(Subtarget.is64Bit() &&
4441-
(VT == MVT::v4i16 || VT == MVT::v8i8 || VT == MVT::v2i32));
4442-
if (IsPExtVector) {
4443-
if (SDValue SplatValue = cast<BuildVectorSDNode>(Op)->getSplatValue()) {
4444-
if (auto *C = dyn_cast<ConstantSDNode>(SplatValue)) {
4445-
int64_t SplatImm = C->getSExtValue();
4446-
bool IsValidImm = false;
4447-
4448-
// Check immediate range based on vector type
4449-
if (VT == MVT::v8i8 || VT == MVT::v4i8) {
4450-
// PLI_B uses 8-bit unsigned or unsigned immediate
4451-
IsValidImm = isUInt<8>(SplatImm) || isInt<8>(SplatImm);
4452-
if (isUInt<8>(SplatImm))
4453-
SplatImm = (int8_t)SplatImm;
4454-
} else {
4455-
// PLI_H and PLI_W use 10-bit signed immediate
4456-
IsValidImm = isInt<10>(SplatImm);
4457-
}
4458-
4459-
if (IsValidImm) {
4460-
SDValue Imm = DAG.getSignedTargetConstant(SplatImm, DL, XLenVT);
4461-
return DAG.getNode(RISCVISD::PLI, DL, VT, Imm);
4462-
}
4463-
}
4464-
}
4465-
}
4466-
}
44674436

44684437
// Proper support for f16 requires Zvfh. bf16 always requires special
44694438
// handling. We need to cast the scalar to integer and create an integer

llvm/lib/Target/RISCV/RISCVInstrInfoP.td

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,15 @@
1818
// Operand and SDNode transformation definitions.
1919
//===----------------------------------------------------------------------===//
2020

21-
def simm10 : RISCVSImmOp<10>, TImmLeaf<XLenVT, "return isInt<10>(Imm);">;
21+
def simm10 : RISCVSImmOp<10>, ImmLeaf<XLenVT, "return isInt<10>(Imm);">;
2222

2323
def SImm8UnsignedAsmOperand : SImmAsmOperand<8, "Unsigned"> {
2424
let RenderMethod = "addSImm8UnsignedOperands";
2525
}
2626

2727
// An 8-bit signed immediate allowing range [-128, 255]
2828
// but represented as [-128, 127].
29-
def simm8_unsigned : RISCVOp, TImmLeaf<XLenVT, "return isInt<8>(Imm);"> {
29+
def simm8_unsigned : RISCVOp, ImmLeaf<XLenVT, "return isInt<8>(Imm);"> {
3030
let ParserMatchClass = SImm8UnsignedAsmOperand;
3131
let EncoderMethod = "getImmOpValue";
3232
let DecoderMethod = "decodeSImmOperand<8>";
@@ -1463,10 +1463,6 @@ let Predicates = [HasStdExtP, IsRV32] in {
14631463

14641464
def riscv_absw : RVSDNode<"ABSW", SDTIntUnaryOp>;
14651465

1466-
def SDT_RISCVPLI : SDTypeProfile<1, 1, [SDTCisVec<0>,
1467-
SDTCisInt<0>,
1468-
SDTCisInt<1>]>;
1469-
def riscv_pli : RVSDNode<"PLI", SDT_RISCVPLI>;
14701466
def SDT_RISCVPASUB : SDTypeProfile<1, 2, [SDTCisVec<0>,
14711467
SDTCisInt<0>,
14721468
SDTCisSameAs<0, 1>,
@@ -1519,10 +1515,13 @@ let Predicates = [HasStdExtP] in {
15191515

15201516

15211517
// 8-bit PLI SD node pattern
1522-
def: Pat<(XLenVecI8VT (riscv_pli simm8_unsigned:$imm8)), (PLI_B simm8_unsigned:$imm8)>;
1518+
def: Pat<(XLenVecI8VT (splat_vector simm8_unsigned:$imm8)), (PLI_B simm8_unsigned:$imm8)>;
15231519
// 16-bit PLI SD node pattern
1524-
def: Pat<(XLenVecI16VT (riscv_pli simm10:$imm10)), (PLI_H simm10:$imm10)>;
1520+
def: Pat<(XLenVecI16VT (splat_vector simm10:$imm10)), (PLI_H simm10:$imm10)>;
15251521

1522+
// splat pattern
1523+
def: Pat<(XLenVecI8VT (splat_vector (XLenVT GPR:$rs2))), (PADD_BS (XLenVT X0), GPR:$rs2)>;
1524+
def: Pat<(XLenVecI16VT (splat_vector (XLenVT GPR:$rs2))), (PADD_HS (XLenVT X0), GPR:$rs2)>;
15261525
} // Predicates = [HasStdExtP]
15271526

15281527
let Predicates = [HasStdExtP, IsRV32] in {
@@ -1537,7 +1536,7 @@ let Predicates = [HasStdExtP, IsRV64] in {
15371536
def : PatGpr<riscv_absw, ABSW>;
15381537

15391538
// 32-bit PLI SD node pattern
1540-
def: Pat<(v2i32 (riscv_pli simm10:$imm10)), (PLI_W simm10:$imm10)>;
1539+
def: Pat<(v2i32 (splat_vector simm10:$imm10)), (PLI_W simm10:$imm10)>;
15411540

15421541
// Basic 32-bit arithmetic patterns
15431542
def: Pat<(v2i32 (add GPR:$rs1, GPR:$rs2)), (PADD_W GPR:$rs1, GPR:$rs2)>;
@@ -1557,6 +1556,9 @@ let Predicates = [HasStdExtP, IsRV64] in {
15571556
def: Pat<(v2i32 (riscv_pasub GPR:$rs1, GPR:$rs2)), (PASUB_W GPR:$rs1, GPR:$rs2)>;
15581557
def: Pat<(v2i32 (riscv_pasubu GPR:$rs1, GPR:$rs2)), (PASUBU_W GPR:$rs1, GPR:$rs2)>;
15591558

1559+
// splat pattern
1560+
def: Pat<(v2i32 (splat_vector (XLenVT GPR:$rs2))), (PADD_WS (XLenVT X0), GPR:$rs2)>;
1561+
15601562
// Load/Store patterns
15611563
def : StPat<store, SD, GPR, v8i8>;
15621564
def : StPat<store, SD, GPR, v4i16>;

llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,33 @@ define void @test_extract_vector_8(ptr %ret_ptr, ptr %a_ptr) {
496496
ret void
497497
}
498498

499+
; Test for splat
500+
define void @test_non_const_splat_i8(ptr %ret_ptr, ptr %a_ptr, i8 %elt) {
501+
; CHECK-LABEL: test_non_const_splat_i8:
502+
; CHECK: # %bb.0:
503+
; CHECK-NEXT: padd.bs a1, zero, a2
504+
; CHECK-NEXT: sw a1, 0(a0)
505+
; CHECK-NEXT: ret
506+
%a = load <4 x i8>, ptr %a_ptr
507+
%insert = insertelement <4 x i8> poison, i8 %elt, i32 0
508+
%splat = shufflevector <4 x i8> %insert, <4 x i8> poison, <4 x i32> zeroinitializer
509+
store <4 x i8> %splat, ptr %ret_ptr
510+
ret void
511+
}
512+
513+
define void @test_non_const_splat_i16(ptr %ret_ptr, ptr %a_ptr, i16 %elt) {
514+
; CHECK-LABEL: test_non_const_splat_i16:
515+
; CHECK: # %bb.0:
516+
; CHECK-NEXT: padd.hs a1, zero, a2
517+
; CHECK-NEXT: sw a1, 0(a0)
518+
; CHECK-NEXT: ret
519+
%a = load <2 x i16>, ptr %a_ptr
520+
%insert = insertelement <2 x i16> poison, i16 %elt, i32 0
521+
%splat = shufflevector <2 x i16> %insert, <2 x i16> poison, <2 x i32> zeroinitializer
522+
store <2 x i16> %splat, ptr %ret_ptr
523+
ret void
524+
}
525+
499526
; Intrinsic declarations
500527
declare <2 x i16> @llvm.sadd.sat.v2i16(<2 x i16>, <2 x i16>)
501528
declare <2 x i16> @llvm.uadd.sat.v2i16(<2 x i16>, <2 x i16>)

llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,20 @@ define void @test_pasubu_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
671671
ret void
672672
}
673673

674+
; Test for splat
675+
define void @test_non_const_splat_i32(ptr %ret_ptr, ptr %a_ptr, i32 %elt) {
676+
; CHECK-LABEL: test_non_const_splat_i32:
677+
; CHECK: # %bb.0:
678+
; CHECK-NEXT: padd.ws a1, zero, a2
679+
; CHECK-NEXT: sd a1, 0(a0)
680+
; CHECK-NEXT: ret
681+
%a = load <2 x i32>, ptr %a_ptr
682+
%insert = insertelement <2 x i32> poison, i32 %elt, i32 0
683+
%splat = shufflevector <2 x i32> %insert, <2 x i32> poison, <2 x i32> zeroinitializer
684+
store <2 x i32> %splat, ptr %ret_ptr
685+
ret void
686+
}
687+
674688
; Intrinsic declarations
675689
declare <4 x i16> @llvm.sadd.sat.v4i16(<4 x i16>, <4 x i16>)
676690
declare <4 x i16> @llvm.uadd.sat.v4i16(<4 x i16>, <4 x i16>)

0 commit comments

Comments
 (0)