From 142db52bbbc33f608317dbd84a1650f2a0b356b9 Mon Sep 17 00:00:00 2001 From: Brandon Wu Date: Sat, 15 Nov 2025 04:18:35 -0800 Subject: [PATCH 1/2] [RISCV][llvm] Select splat_vector(constant) with PLI Default DAG combiner combine BUILD_VECTOR with same elements to SPLAT_VECTOR, we can just map constant splat to PLI if possible. --- llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp | 2 ++ llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 33 +-------------------- llvm/lib/Target/RISCV/RISCVInstrInfoP.td | 14 ++++----- 3 files changed, 8 insertions(+), 41 deletions(-) diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp index 1024e55f912c7..5025122db3681 100644 --- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp @@ -51,6 +51,8 @@ void RISCVDAGToDAGISel::PreprocessISelDAG() { SDValue Result; switch (N->getOpcode()) { case ISD::SPLAT_VECTOR: { + if (Subtarget->enablePExtCodeGen()) + break; // Convert integer SPLAT_VECTOR to VMV_V_X_VL and floating-point // SPLAT_VECTOR to VFMV_V_F_VL to reduce isel burden. MVT VT = N->getSimpleValueType(0); diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 38cce26e44af4..37c8d7c045443 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -525,7 +525,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, setOperationAction(ISD::SSUBSAT, VTs, Legal); setOperationAction({ISD::AVGFLOORS, ISD::AVGFLOORU}, VTs, Legal); setOperationAction({ISD::ABDS, ISD::ABDU}, VTs, Legal); - setOperationAction(ISD::BUILD_VECTOR, VTs, Custom); + setOperationAction(ISD::SPLAT_VECTOR, VTs, Legal); setOperationAction(ISD::BITCAST, VTs, Custom); setOperationAction(ISD::EXTRACT_VECTOR_ELT, VTs, Custom); } @@ -4437,37 +4437,6 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG, MVT XLenVT = Subtarget.getXLenVT(); SDLoc DL(Op); - // Handle P extension packed vector BUILD_VECTOR with PLI for splat constants - if (Subtarget.enablePExtCodeGen()) { - bool IsPExtVector = - (VT == MVT::v2i16 || VT == MVT::v4i8) || - (Subtarget.is64Bit() && - (VT == MVT::v4i16 || VT == MVT::v8i8 || VT == MVT::v2i32)); - if (IsPExtVector) { - if (SDValue SplatValue = cast(Op)->getSplatValue()) { - if (auto *C = dyn_cast(SplatValue)) { - int64_t SplatImm = C->getSExtValue(); - bool IsValidImm = false; - - // Check immediate range based on vector type - if (VT == MVT::v8i8 || VT == MVT::v4i8) { - // PLI_B uses 8-bit unsigned or unsigned immediate - IsValidImm = isUInt<8>(SplatImm) || isInt<8>(SplatImm); - if (isUInt<8>(SplatImm)) - SplatImm = (int8_t)SplatImm; - } else { - // PLI_H and PLI_W use 10-bit signed immediate - IsValidImm = isInt<10>(SplatImm); - } - - if (IsValidImm) { - SDValue Imm = DAG.getSignedTargetConstant(SplatImm, DL, XLenVT); - return DAG.getNode(RISCVISD::PLI, DL, VT, Imm); - } - } - } - } - } // Proper support for f16 requires Zvfh. bf16 always requires special // handling. We need to cast the scalar to integer and create an integer diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td index 126a39996c741..2f289f89e8859 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td @@ -18,7 +18,7 @@ // Operand and SDNode transformation definitions. //===----------------------------------------------------------------------===// -def simm10 : RISCVSImmOp<10>, TImmLeaf(Imm);">; +def simm10 : RISCVSImmOp<10>, ImmLeaf(Imm);">; def SImm8UnsignedAsmOperand : SImmAsmOperand<8, "Unsigned"> { let RenderMethod = "addSImm8UnsignedOperands"; @@ -26,7 +26,7 @@ def SImm8UnsignedAsmOperand : SImmAsmOperand<8, "Unsigned"> { // A 8-bit signed immediate allowing range [-128, 255] // but represented as [-128, 127]. -def simm8_unsigned : RISCVOp, TImmLeaf(Imm);"> { +def simm8_unsigned : RISCVOp, ImmLeaf(Imm);"> { let ParserMatchClass = SImm8UnsignedAsmOperand; let EncoderMethod = "getImmOpValue"; let DecoderMethod = "decodeSImmOperand<8>"; @@ -1463,10 +1463,6 @@ let Predicates = [HasStdExtP, IsRV32] in { def riscv_absw : RVSDNode<"ABSW", SDTIntUnaryOp>; -def SDT_RISCVPLI : SDTypeProfile<1, 1, [SDTCisVec<0>, - SDTCisInt<0>, - SDTCisInt<1>]>; -def riscv_pli : RVSDNode<"PLI", SDT_RISCVPLI>; def SDT_RISCVPASUB : SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisInt<0>, SDTCisSameAs<0, 1>, @@ -1519,9 +1515,9 @@ let Predicates = [HasStdExtP] in { // 8-bit PLI SD node pattern - def: Pat<(XLenVecI8VT (riscv_pli simm8_unsigned:$imm8)), (PLI_B simm8_unsigned:$imm8)>; + def: Pat<(XLenVecI8VT (splat_vector simm8_unsigned:$imm8)), (PLI_B simm8_unsigned:$imm8)>; // 16-bit PLI SD node pattern - def: Pat<(XLenVecI16VT (riscv_pli simm10:$imm10)), (PLI_H simm10:$imm10)>; + def: Pat<(XLenVecI16VT (splat_vector simm10:$imm10)), (PLI_H simm10:$imm10)>; } // Predicates = [HasStdExtP] @@ -1537,7 +1533,7 @@ let Predicates = [HasStdExtP, IsRV64] in { def : PatGpr; // 32-bit PLI SD node pattern - def: Pat<(v2i32 (riscv_pli simm10:$imm10)), (PLI_W simm10:$imm10)>; + def: Pat<(v2i32 (splat_vector simm10:$imm10)), (PLI_W simm10:$imm10)>; // Basic 32-bit arithmetic patterns def: Pat<(v2i32 (add GPR:$rs1, GPR:$rs2)), (PADD_W GPR:$rs1, GPR:$rs2)>; From b1843c75799188d754ca56796ecb83c72b48cb13 Mon Sep 17 00:00:00 2001 From: Brandon Wu Date: Mon, 17 Nov 2025 00:02:36 -0800 Subject: [PATCH 2/2] fixup! Add non-const splat pattern --- llvm/lib/Target/RISCV/RISCVInstrInfoP.td | 6 ++++++ llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll | 27 ++++++++++++++++++++++++ llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll | 14 ++++++++++++ 3 files changed, 47 insertions(+) diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td index 2f289f89e8859..764e3c9c58355 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td @@ -1519,6 +1519,9 @@ let Predicates = [HasStdExtP] in { // 16-bit PLI SD node pattern def: Pat<(XLenVecI16VT (splat_vector simm10:$imm10)), (PLI_H simm10:$imm10)>; + // // splat pattern + def: Pat<(XLenVecI8VT (splat_vector (XLenVT GPR:$rs2))), (PADD_BS (XLenVT X0), GPR:$rs2)>; + def: Pat<(XLenVecI16VT (splat_vector (XLenVT GPR:$rs2))), (PADD_HS (XLenVT X0), GPR:$rs2)>; } // Predicates = [HasStdExtP] let Predicates = [HasStdExtP, IsRV32] in { @@ -1553,6 +1556,9 @@ let Predicates = [HasStdExtP, IsRV64] in { def: Pat<(v2i32 (riscv_pasub GPR:$rs1, GPR:$rs2)), (PASUB_W GPR:$rs1, GPR:$rs2)>; def: Pat<(v2i32 (riscv_pasubu GPR:$rs1, GPR:$rs2)), (PASUBU_W GPR:$rs1, GPR:$rs2)>; + // splat pattern + def: Pat<(v2i32 (splat_vector (XLenVT GPR:$rs2))), (PADD_WS (XLenVT X0), GPR:$rs2)>; + // Load/Store patterns def : StPat; def : StPat; diff --git a/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll b/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll index 46d5e9f9a538f..bb3e691311cd8 100644 --- a/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll +++ b/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll @@ -496,6 +496,33 @@ define void @test_extract_vector_8(ptr %ret_ptr, ptr %a_ptr) { ret void } +; Test for splat +define void @test_non_const_splat_i8(ptr %ret_ptr, ptr %a_ptr, i8 %elt) { +; CHECK-LABEL: test_non_const_splat_i8: +; CHECK: # %bb.0: +; CHECK-NEXT: padd.bs a1, zero, a2 +; CHECK-NEXT: sw a1, 0(a0) +; CHECK-NEXT: ret + %a = load <4 x i8>, ptr %a_ptr + %insert = insertelement <4 x i8> poison, i8 %elt, i32 0 + %splat = shufflevector <4 x i8> %insert, <4 x i8> poison, <4 x i32> zeroinitializer + store <4 x i8> %splat, ptr %ret_ptr + ret void +} + +define void @test_non_const_splat_i16(ptr %ret_ptr, ptr %a_ptr, i16 %elt) { +; CHECK-LABEL: test_non_const_splat_i16: +; CHECK: # %bb.0: +; CHECK-NEXT: padd.hs a1, zero, a2 +; CHECK-NEXT: sw a1, 0(a0) +; CHECK-NEXT: ret + %a = load <2 x i16>, ptr %a_ptr + %insert = insertelement <2 x i16> poison, i16 %elt, i32 0 + %splat = shufflevector <2 x i16> %insert, <2 x i16> poison, <2 x i32> zeroinitializer + store <2 x i16> %splat, ptr %ret_ptr + ret void +} + ; Intrinsic declarations declare <2 x i16> @llvm.sadd.sat.v2i16(<2 x i16>, <2 x i16>) declare <2 x i16> @llvm.uadd.sat.v2i16(<2 x i16>, <2 x i16>) diff --git a/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll b/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll index 353039e9482e9..f989b025a12dc 100644 --- a/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll +++ b/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll @@ -671,6 +671,20 @@ define void @test_pasubu_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) { ret void } +; Test for splat +define void @test_non_const_splat_i32(ptr %ret_ptr, ptr %a_ptr, i32 %elt) { +; CHECK-LABEL: test_non_const_splat_i32: +; CHECK: # %bb.0: +; CHECK-NEXT: padd.ws a1, zero, a2 +; CHECK-NEXT: sd a1, 0(a0) +; CHECK-NEXT: ret + %a = load <2 x i32>, ptr %a_ptr + %insert = insertelement <2 x i32> poison, i32 %elt, i32 0 + %splat = shufflevector <2 x i32> %insert, <2 x i32> poison, <2 x i32> zeroinitializer + store <2 x i32> %splat, ptr %ret_ptr + ret void +} + ; Intrinsic declarations declare <4 x i16> @llvm.sadd.sat.v4i16(<4 x i16>, <4 x i16>) declare <4 x i16> @llvm.uadd.sat.v4i16(<4 x i16>, <4 x i16>)