[RISCV][llvm] Preliminary P extension codegen support #162668
Conversation
This is the initial support for P extension codegen; it only includes a small subset of the instructions: PADD_H, PADD_B, PSADD_H, PSADD_B, PAADD_H, PAADD_B, PSADDU_H, PSADDU_B, PAADDU_H, PAADDU_B, PSUB_H, PSUB_B, PDIF_H, PDIF_B, PSSUB_H, PSSUB_B, PASUB_H, PASUB_B, PDIFU_H, PDIFU_B, PSSUBU_H, PSSUBU_B, PASUBU_H, PASUBU_B.

@llvm/pr-subscribers-backend-risc-v
Author: Brandon Wu (4vtomat)
Changes: This is the initial support for P extension codegen; it only includes a small subset of the instructions.
Patch is 41.35 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/162668.diff
6 Files Affected:
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 7123a2d706787..1eb8c9457ee6a 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -279,6 +279,17 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
addRegisterClass(MVT::riscv_nxv32i8x2, &RISCV::VRN2M4RegClass);
}
+ // fixed vector is stored in GPRs for P extension packed operations
+ if (Subtarget.hasStdExtP()) {
+ addRegisterClass(MVT::v2i16, &RISCV::GPRRegClass);
+ addRegisterClass(MVT::v4i8, &RISCV::GPRRegClass);
+ if (Subtarget.is64Bit()) {
+ addRegisterClass(MVT::v2i32, &RISCV::GPRRegClass);
+ addRegisterClass(MVT::v4i16, &RISCV::GPRRegClass);
+ addRegisterClass(MVT::v8i8, &RISCV::GPRRegClass);
+ }
+ }
+
// Compute derived properties from the register classes.
computeRegisterProperties(STI.getRegisterInfo());
@@ -479,6 +490,24 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
ISD::FTRUNC, ISD::FRINT, ISD::FROUND,
ISD::FROUNDEVEN, ISD::FCANONICALIZE};
+ if (Subtarget.hasStdExtP()) {
+ // load/store are already handled by pattern matching
+ SmallVector<MVT, 2> VTs = {MVT::v2i16, MVT::v4i8};
+ if (Subtarget.is64Bit())
+ VTs.append({MVT::v2i32, MVT::v4i16, MVT::v8i8});
+ for (auto VT : VTs) {
+ setOperationAction(ISD::UADDSAT, VT, Legal);
+ setOperationAction(ISD::SADDSAT, VT, Legal);
+ setOperationAction(ISD::USUBSAT, VT, Legal);
+ setOperationAction(ISD::SSUBSAT, VT, Legal);
+ setOperationAction(ISD::SSHLSAT, VT, Legal);
+ setOperationAction(ISD::USHLSAT, VT, Legal);
+ setOperationAction(ISD::BITCAST, VT, Custom);
+ setOperationAction({ISD::AVGFLOORS, ISD::AVGFLOORU}, VT, Legal);
+ setOperationAction({ISD::ABDS, ISD::ABDU}, VT, Legal);
+ }
+ }
+
if (Subtarget.hasStdExtZfbfmin()) {
setOperationAction(ISD::BITCAST, MVT::i16, Custom);
setOperationAction(ISD::ConstantFP, MVT::bf16, Expand);
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
index 7d8a9192d9847..c5e2f12aafb1e 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
@@ -1455,3 +1455,127 @@ let Predicates = [HasStdExtP, IsRV32] in {
def PMAXU_DW : RVPPairBinaryExchanged_rr<0b1111, 0b01, "pmaxu.dw">;
def PMAXU_DB : RVPPairBinaryExchanged_rr<0b1111, 0b10, "pmaxu.db">;
} // Predicates = [HasStdExtP, IsRV32]
+
+let Predicates = [HasStdExtP, IsRV64] in {
+ // Basic arithmetic patterns for v4i16 (16-bit elements in 64-bit GPR)
+ def: Pat<(v4i16 (add v4i16:$rs1, v4i16:$rs2)), (!cast<Instruction>("PADD_H") GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(v4i16 (sub v4i16:$rs1, v4i16:$rs2)), (!cast<Instruction>("PSUB_H") GPR:$rs1, GPR:$rs2)>;
+
+ // Saturating add/sub patterns for v4i16
+ def: Pat<(v4i16 (saddsat v4i16:$rs1, v4i16:$rs2)), (!cast<Instruction>("PSADD_H") GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(v4i16 (uaddsat v4i16:$rs1, v4i16:$rs2)), (!cast<Instruction>("PSADDU_H") GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(v4i16 (ssubsat v4i16:$rs1, v4i16:$rs2)), (!cast<Instruction>("PSSUB_H") GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(v4i16 (usubsat v4i16:$rs1, v4i16:$rs2)), (!cast<Instruction>("PSSUBU_H") GPR:$rs1, GPR:$rs2)>;
+
+ // Averaging patterns for v4i16
+ def: Pat<(v4i16 (avgfloors v4i16:$rs1, v4i16:$rs2)), (!cast<Instruction>("PAADD_H") GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(v4i16 (avgflooru v4i16:$rs1, v4i16:$rs2)), (!cast<Instruction>("PAADDU_H") GPR:$rs1, GPR:$rs2)>;
+
+ // Averaging subtraction patterns for v4i16
+ // PASUB_H: signed (a - b) >> 1
+ def: Pat<(v4i16 (sra (sub v4i16:$rs1, v4i16:$rs2), (v4i16 (build_vector (XLenVT 1))))),
+ (!cast<Instruction>("PASUB_H") GPR:$rs1, GPR:$rs2)>;
+ // PASUBU_H: unsigned (a - b) >> 1
+ def: Pat<(v4i16 (srl (sub v4i16:$rs1, v4i16:$rs2), (v4i16 (build_vector (XLenVT 1))))),
+ (!cast<Instruction>("PASUBU_H") GPR:$rs1, GPR:$rs2)>;
+
+ // Absolute difference patterns for v4i16
+ def: Pat<(v4i16 (abds v4i16:$rs1, v4i16:$rs2)), (!cast<Instruction>("PDIF_H") GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(v4i16 (abdu v4i16:$rs1, v4i16:$rs2)), (!cast<Instruction>("PDIFU_H") GPR:$rs1, GPR:$rs2)>;
+
+ // Basic arithmetic patterns for v8i8 (8-bit elements in 64-bit GPR)
+ def: Pat<(v8i8 (add v8i8:$rs1, v8i8:$rs2)), (!cast<Instruction>("PADD_B") GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(v8i8 (sub v8i8:$rs1, v8i8:$rs2)), (!cast<Instruction>("PSUB_B") GPR:$rs1, GPR:$rs2)>;
+
+ // Saturating add/sub patterns for v8i8
+ def: Pat<(v8i8 (saddsat v8i8:$rs1, v8i8:$rs2)), (!cast<Instruction>("PSADD_B") GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(v8i8 (uaddsat v8i8:$rs1, v8i8:$rs2)), (!cast<Instruction>("PSADDU_B") GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(v8i8 (ssubsat v8i8:$rs1, v8i8:$rs2)), (!cast<Instruction>("PSSUB_B") GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(v8i8 (usubsat v8i8:$rs1, v8i8:$rs2)), (!cast<Instruction>("PSSUBU_B") GPR:$rs1, GPR:$rs2)>;
+
+ // Averaging patterns for v8i8
+ def: Pat<(v8i8 (avgfloors v8i8:$rs1, v8i8:$rs2)), (!cast<Instruction>("PAADD_B") GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(v8i8 (avgflooru v8i8:$rs1, v8i8:$rs2)), (!cast<Instruction>("PAADDU_B") GPR:$rs1, GPR:$rs2)>;
+
+ // Averaging subtraction patterns for v8i8
+ // PASUB_B: signed (a - b) >> 1
+ def: Pat<(v8i8 (sra (sub v8i8:$rs1, v8i8:$rs2), (v8i8 (build_vector (XLenVT 1))))),
+ (!cast<Instruction>("PASUB_B") GPR:$rs1, GPR:$rs2)>;
+ // PASUBU_B: unsigned (a - b) >> 1
+ def: Pat<(v8i8 (srl (sub v8i8:$rs1, v8i8:$rs2), (v8i8 (build_vector (XLenVT 1))))),
+ (!cast<Instruction>("PASUBU_B") GPR:$rs1, GPR:$rs2)>;
+
+ // Absolute difference patterns for v8i8
+ def: Pat<(v8i8 (abds v8i8:$rs1, v8i8:$rs2)), (!cast<Instruction>("PDIF_B") GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(v8i8 (abdu v8i8:$rs1, v8i8:$rs2)), (!cast<Instruction>("PDIFU_B") GPR:$rs1, GPR:$rs2)>;
+
+ // Load/Store patterns for v4i16 and v8i8 (use regular GPR load/store since they're in GPRs)
+ def : StPat<store, SD, GPR, v4i16>;
+ def : LdPat<load, LD, v4i16>;
+ def : StPat<store, SD, GPR, v8i8>;
+ def : LdPat<load, LD, v8i8>;
+
+ // Load/Store patterns for v2i32 (32-bit elements in 64-bit GPR)
+ def : StPat<store, SD, GPR, v2i32>;
+ def : LdPat<load, LD, v2i32>;
+} // Predicates = [HasStdExtP, IsRV64]
+
+let Predicates = [HasStdExtP, IsRV32] in {
+ // Basic arithmetic patterns for v2i16 (16-bit elements in 32-bit GPR)
+ def: Pat<(v2i16 (add v2i16:$rs1, v2i16:$rs2)), (!cast<Instruction>("PADD_H") GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(v2i16 (sub v2i16:$rs1, v2i16:$rs2)), (!cast<Instruction>("PSUB_H") GPR:$rs1, GPR:$rs2)>;
+
+ // Saturating add/sub patterns for v2i16
+ def: Pat<(v2i16 (saddsat v2i16:$rs1, v2i16:$rs2)), (!cast<Instruction>("PSADD_H") GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(v2i16 (uaddsat v2i16:$rs1, v2i16:$rs2)), (!cast<Instruction>("PSADDU_H") GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(v2i16 (ssubsat v2i16:$rs1, v2i16:$rs2)), (!cast<Instruction>("PSSUB_H") GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(v2i16 (usubsat v2i16:$rs1, v2i16:$rs2)), (!cast<Instruction>("PSSUBU_H") GPR:$rs1, GPR:$rs2)>;
+
+ // Averaging patterns for v2i16
+ def: Pat<(v2i16 (avgfloors v2i16:$rs1, v2i16:$rs2)), (!cast<Instruction>("PAADD_H") GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(v2i16 (avgflooru v2i16:$rs1, v2i16:$rs2)), (!cast<Instruction>("PAADDU_H") GPR:$rs1, GPR:$rs2)>;
+
+ // Averaging subtraction patterns for v2i16
+ // PASUB_H: signed (a - b) >> 1
+ def: Pat<(v2i16 (sra (sub v2i16:$rs1, v2i16:$rs2), (v2i16 (build_vector (XLenVT 1))))),
+ (!cast<Instruction>("PASUB_H") GPR:$rs1, GPR:$rs2)>;
+ // PASUBU_H: unsigned (a - b) >> 1
+ def: Pat<(v2i16 (srl (sub v2i16:$rs1, v2i16:$rs2), (v2i16 (build_vector (XLenVT 1))))),
+ (!cast<Instruction>("PASUBU_H") GPR:$rs1, GPR:$rs2)>;
+
+ // Absolute difference patterns for v2i16
+ def: Pat<(v2i16 (abds v2i16:$rs1, v2i16:$rs2)), (!cast<Instruction>("PDIF_H") GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(v2i16 (abdu v2i16:$rs1, v2i16:$rs2)), (!cast<Instruction>("PDIFU_H") GPR:$rs1, GPR:$rs2)>;
+
+ // Basic arithmetic patterns for v4i8 (8-bit elements in 32-bit GPR)
+ def: Pat<(v4i8 (add v4i8:$rs1, v4i8:$rs2)), (!cast<Instruction>("PADD_B") GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(v4i8 (sub v4i8:$rs1, v4i8:$rs2)), (!cast<Instruction>("PSUB_B") GPR:$rs1, GPR:$rs2)>;
+
+ // Saturating add/sub patterns for v4i8
+ def: Pat<(v4i8 (saddsat v4i8:$rs1, v4i8:$rs2)), (!cast<Instruction>("PSADD_B") GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(v4i8 (uaddsat v4i8:$rs1, v4i8:$rs2)), (!cast<Instruction>("PSADDU_B") GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(v4i8 (ssubsat v4i8:$rs1, v4i8:$rs2)), (!cast<Instruction>("PSSUB_B") GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(v4i8 (usubsat v4i8:$rs1, v4i8:$rs2)), (!cast<Instruction>("PSSUBU_B") GPR:$rs1, GPR:$rs2)>;
+
+ // Averaging patterns for v4i8
+ def: Pat<(v4i8 (avgfloors v4i8:$rs1, v4i8:$rs2)), (!cast<Instruction>("PAADD_B") GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(v4i8 (avgflooru v4i8:$rs1, v4i8:$rs2)), (!cast<Instruction>("PAADDU_B") GPR:$rs1, GPR:$rs2)>;
+
+ // Averaging subtraction patterns for v4i8
+ // PASUB_B: signed (a - b) >> 1
+ def: Pat<(v4i8 (sra (sub v4i8:$rs1, v4i8:$rs2), (v4i8 (build_vector (XLenVT 1))))),
+ (!cast<Instruction>("PASUB_B") GPR:$rs1, GPR:$rs2)>;
+ // PASUBU_B: unsigned (a - b) >> 1
+ def: Pat<(v4i8 (srl (sub v4i8:$rs1, v4i8:$rs2), (v4i8 (build_vector (XLenVT 1))))),
+ (!cast<Instruction>("PASUBU_B") GPR:$rs1, GPR:$rs2)>;
+
+ // Absolute difference patterns for v4i8
+ def: Pat<(v4i8 (abds v4i8:$rs1, v4i8:$rs2)), (!cast<Instruction>("PDIF_B") GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(v4i8 (abdu v4i8:$rs1, v4i8:$rs2)), (!cast<Instruction>("PDIFU_B") GPR:$rs1, GPR:$rs2)>;
+
+ // Load/Store patterns for v2i16 and v4i8 (use regular GPR load/store since they're in GPRs)
+ def : StPat<store, SW, GPR, v2i16>;
+ def : LdPat<load, LW, v2i16>;
+ def : StPat<store, SW, GPR, v4i8>;
+ def : LdPat<load, LW, v4i8>;
+} // Predicates = [HasStdExtP, IsRV32]
diff --git a/llvm/lib/Target/RISCV/RISCVRegisterInfo.td b/llvm/lib/Target/RISCV/RISCVRegisterInfo.td
index 6605a5ccdfde2..fcbb93a55375b 100644
--- a/llvm/lib/Target/RISCV/RISCVRegisterInfo.td
+++ b/llvm/lib/Target/RISCV/RISCVRegisterInfo.td
@@ -238,7 +238,11 @@ class RISCVRegisterClass<list<ValueType> regTypes, int align, dag regList>
}
class GPRRegisterClass<dag regList>
- : RISCVRegisterClass<[XLenVT, XLenFVT], 32, regList> {
+ : RISCVRegisterClass<[XLenVT, XLenFVT,
+ // P extension packed vector types:
+ // RV32: v2i16, v4i8
+ // RV64: v2i32, v4i16, v8i8
+ v2i16, v4i8, v2i32, v4i16, v8i8], 32, regList> {
let RegInfos = XLenRI;
}
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index 7bc0b5b394828..e669175a3d8e1 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -969,6 +969,13 @@ InstructionCost RISCVTTIImpl::getScalarizationOverhead(
if (isa<ScalableVectorType>(Ty))
return InstructionCost::getInvalid();
+ // TODO: Add proper cost model for P extension fixed vectors (e.g., v4i16)
+ // For now, skip all fixed vector cost analysis when P extension is available
+ // to avoid crashes in getMinRVVVectorSizeInBits()
+ if (ST->hasStdExtP() && isa<FixedVectorType>(Ty)) {
+ return 1; // Treat as single instruction cost for now
+ }
+
// A build_vector (which is m1 sized or smaller) can be done in no
// worse than one vslide1down.vx per element in the type. We could
// in theory do an explode_vector in the inverse manner, but our
@@ -1625,6 +1632,13 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
if (!IsVectorType)
return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
+ // TODO: Add proper cost model for P extension fixed vectors (e.g., v4i16)
+ // For now, skip all fixed vector cost analysis when P extension is available
+ // to avoid crashes in getMinRVVVectorSizeInBits()
+ if (ST->hasStdExtP() && (isa<FixedVectorType>(Dst) || isa<FixedVectorType>(Src))) {
+ return 1; // Treat as single instruction cost for now
+ }
+
// FIXME: Need to compute legalizing cost for illegal types. The current
// code handles only legal types and those which can be trivially
// promoted to legal.
@@ -2321,6 +2335,13 @@ InstructionCost RISCVTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
const Value *Op1) const {
assert(Val->isVectorTy() && "This must be a vector type");
+ // TODO: Add proper cost model for P extension fixed vectors (e.g., v4i16)
+ // For now, skip all fixed vector cost analysis when P extension is available
+ // to avoid crashes in getMinRVVVectorSizeInBits()
+ if (ST->hasStdExtP() && isa<FixedVectorType>(Val)) {
+ return 1; // Treat as single instruction cost for now
+ }
+
if (Opcode != Instruction::ExtractElement &&
Opcode != Instruction::InsertElement)
return BaseT::getVectorInstrCost(Opcode, Val, CostKind, Index, Op0, Op1);
diff --git a/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll b/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
new file mode 100644
index 0000000000000..8a4ab1d545f41
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
@@ -0,0 +1,426 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv32 -mattr=+experimental-p -verify-machineinstrs < %s | FileCheck %s
+
+; Test basic add/sub operations for v2i16
+define void @test_padd_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_padd_h:
+; CHECK: # %bb.0:
+; CHECK-NEXT: lw a1, 0(a1)
+; CHECK-NEXT: lw a2, 0(a2)
+; CHECK-NEXT: padd.h a1, a1, a2
+; CHECK-NEXT: sw a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <2 x i16>, ptr %a_ptr
+ %b = load <2 x i16>, ptr %b_ptr
+ %res = add <2 x i16> %a, %b
+ store <2 x i16> %res, ptr %ret_ptr
+ ret void
+}
+
+define void @test_psub_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_psub_h:
+; CHECK: # %bb.0:
+; CHECK-NEXT: lw a1, 0(a1)
+; CHECK-NEXT: lw a2, 0(a2)
+; CHECK-NEXT: psub.h a1, a1, a2
+; CHECK-NEXT: sw a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <2 x i16>, ptr %a_ptr
+ %b = load <2 x i16>, ptr %b_ptr
+ %res = sub <2 x i16> %a, %b
+ store <2 x i16> %res, ptr %ret_ptr
+ ret void
+}
+
+; Test basic add/sub operations for v4i8
+define void @test_padd_b(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_padd_b:
+; CHECK: # %bb.0:
+; CHECK-NEXT: lw a1, 0(a1)
+; CHECK-NEXT: lw a2, 0(a2)
+; CHECK-NEXT: padd.b a1, a1, a2
+; CHECK-NEXT: sw a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <4 x i8>, ptr %a_ptr
+ %b = load <4 x i8>, ptr %b_ptr
+ %res = add <4 x i8> %a, %b
+ store <4 x i8> %res, ptr %ret_ptr
+ ret void
+}
+
+define void @test_psub_b(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_psub_b:
+; CHECK: # %bb.0:
+; CHECK-NEXT: lw a1, 0(a1)
+; CHECK-NEXT: lw a2, 0(a2)
+; CHECK-NEXT: psub.b a1, a1, a2
+; CHECK-NEXT: sw a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <4 x i8>, ptr %a_ptr
+ %b = load <4 x i8>, ptr %b_ptr
+ %res = sub <4 x i8> %a, %b
+ store <4 x i8> %res, ptr %ret_ptr
+ ret void
+}
+
+; Test saturating add operations for v2i16
+define void @test_psadd_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_psadd_h:
+; CHECK: # %bb.0:
+; CHECK-NEXT: lw a1, 0(a1)
+; CHECK-NEXT: lw a2, 0(a2)
+; CHECK-NEXT: psadd.h a1, a1, a2
+; CHECK-NEXT: sw a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <2 x i16>, ptr %a_ptr
+ %b = load <2 x i16>, ptr %b_ptr
+ %res = call <2 x i16> @llvm.sadd.sat.v2i16(<2 x i16> %a, <2 x i16> %b)
+ store <2 x i16> %res, ptr %ret_ptr
+ ret void
+}
+
+define void @test_psaddu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_psaddu_h:
+; CHECK: # %bb.0:
+; CHECK-NEXT: lw a1, 0(a1)
+; CHECK-NEXT: lw a2, 0(a2)
+; CHECK-NEXT: psaddu.h a1, a1, a2
+; CHECK-NEXT: sw a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <2 x i16>, ptr %a_ptr
+ %b = load <2 x i16>, ptr %b_ptr
+ %res = call <2 x i16> @llvm.uadd.sat.v2i16(<2 x i16> %a, <2 x i16> %b)
+ store <2 x i16> %res, ptr %ret_ptr
+ ret void
+}
+
+; Test saturating sub operations for v2i16
+define void @test_pssub_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pssub_h:
+; CHECK: # %bb.0:
+; CHECK-NEXT: lw a1, 0(a1)
+; CHECK-NEXT: lw a2, 0(a2)
+; CHECK-NEXT: pssub.h a1, a1, a2
+; CHECK-NEXT: sw a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <2 x i16>, ptr %a_ptr
+ %b = load <2 x i16>, ptr %b_ptr
+ %res = call <2 x i16> @llvm.ssub.sat.v2i16(<2 x i16> %a, <2 x i16> %b)
+ store <2 x i16> %res, ptr %ret_ptr
+ ret void
+}
+
+define void @test_pssubu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pssubu_h:
+; CHECK: # %bb.0:
+; CHECK-NEXT: lw a1, 0(a1)
+; CHECK-NEXT: lw a2, 0(a2)
+; CHECK-NEXT: pssubu.h a1, a1, a2
+; CHECK-NEXT: sw a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <2 x i16>, ptr %a_ptr
+ %b = load <2 x i16>, ptr %b_ptr
+ %res = call <2 x i16> @llvm.usub.sat.v2i16(<2 x i16> %a, <2 x i16> %b)
+ store <2 x i16> %res, ptr %ret_ptr
+ ret void
+}
+
+; Test saturating add operations for v4i8
+define void @test_psadd_b(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_psadd_b:
+; CHECK: # %bb.0:
+; CHECK-NEXT: lw a1, 0(a1)
+; CHECK-NEXT: lw a2, 0(a2)
+; CHECK-NEXT: psadd.b a1, a1, a2
+; CHECK-NEXT: sw a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <4 x i8>, ptr %a_ptr
+ %b = load <4 x i8>, ptr %b_ptr
+ %res = call <4 x i8> @llvm.sadd.sat.v4i8(<4 x i8> %a, <4 x i8> %b)
+ store <4 x i8> %res, ptr %ret_ptr
+ ret void
+}
+
+define void @test_psaddu_b(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_psaddu_b:
+; CHECK: # %bb.0:
+; CHECK-NEXT: lw a1, 0(a1)
+; CHECK-NEXT: lw a2, 0(a2)
+; CHECK-NEXT: psaddu.b a1, a1, a2
+; CHECK-NEXT: sw a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <4 x i8>, ptr %a_ptr
+ %b = load <4 x i8>, ptr %b_ptr
+ %res = call <4 x i8> @llvm.uadd.sat.v4i8(<4 x i8> %a, <4 x i8> %b)
+ store <4 x i8> %res, ptr %ret_ptr
+ ret void
+}
+
+; Test saturating sub operations for v4i8
+define void @test_pssub_b(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pssub_b:
+; CHECK: # %bb.0:
+; CHECK-NEXT: lw a1, 0(a1)
+; CHECK-NEXT: lw a2, 0(a2)
+; CHECK-NEXT: pssub.b a1, a1, a2
+; CHECK-NEXT: sw a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <4 x i8>, ptr %a_ptr
+ %b = load <4 x i8>, ptr %b_ptr
+ %res = call <4 x i8> @llvm.ssub.sat.v4i8(<4 x i8> %a, <4 x i8> %b)
+ store <4 x i8> %res, ptr %ret_ptr
+ ret void
+}
+
+define void @test_pssubu_b(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pssubu_b:
+; CHECK: # %bb.0:
+; CHECK-NEXT: lw a1, 0(a1)
+; CHECK-NEXT: lw a2, 0(a2)
+; CHECK-NEXT: pssubu.b a1, a1, a2
+; CHECK-NEXT: sw a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <4 x i8>, ptr %a_ptr
+ %b = load <4 x i8>, ptr %b_ptr
+ %res = call <4 x i8> @llvm.usub.sat.v4i8(<4 x i8> %a, <4 x i8> %b)
+ store <4 x i8> %res, ptr %ret_ptr
+ ret void
+}
+
+; Test averaging floor signed operations for v2i16
+define void @test_paadd_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_paadd_h:
+; CHECK: # %bb.0:
+; CHECK-NEXT: lw a1, 0(a1)
+; CHECK-NEXT: lw a2, 0(a2)
+; CHECK-NEXT: paadd.h a1, a1, a2
+; CHECK-NEXT: sw a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <2 x i16>, ptr %a_ptr
+ %b = load <2 x i16>, ptr %b_ptr
+ %ext.a = sext <2 x i16> %a to <2 x i32>
+ %ext.b = sext <2 x i16> %b to <2 x i32>
+ %add = add nsw <2 x i32> %ext.a, %ext.b
+ %shift = ashr <2 x i32> %add, <i32 1, i32 1>
+ %res = trunc <2 x i32> %shift to <2 x i16>
+ store <2 x i16> %res, ptr %ret_ptr
+ ret void
+}
+
+; Test averaging floor unsigned operations for v2i16
+define void @test_paaddu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_paaddu_h:
+; CHECK: # %bb.0:
+; CHECK-NEXT: lw a1, 0(a1)
+; CHECK-NEXT: lw a2, 0(a2)
+; CHECK-NEXT: paaddu.h a1, a1, a2
+; CHECK-NEXT: sw a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = l...
[truncated]
You can test this locally with the following command:
git-clang-format --diff origin/main HEAD --extensions h,cpp -- llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp llvm/lib/Target/RISCV/RISCVISelLowering.cpp llvm/lib/Target/RISCV/RISCVISelLowering.h llvm/lib/Target/RISCV/RISCVInstrInfo.cpp llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp --diff_from_common_commit

View the diff from clang-format here.
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index e3c98df74..94ad1d387 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -15023,10 +15023,10 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
MVT NewVT = MVT::v4i16;
if (VT == MVT::v4i8)
NewVT = MVT::v8i8;
- Op0 = DAG.getBitcast(MVT::i32, Op0);
+ Op0 = DAG.getBitcast(MVT::i32, Op0);
Op0 = DAG.getSExtOrTrunc(Op0, DL, MVT::i64);
Op0 = DAG.getBitcast(NewVT, Op0);
- Op1 = DAG.getBitcast(MVT::i32, Op1);
+ Op1 = DAG.getBitcast(MVT::i32, Op1);
Op1 = DAG.getSExtOrTrunc(Op1, DL, MVT::i64);
Op1 = DAG.getBitcast(NewVT, Op1);
Results.push_back(DAG.getNode(N->getOpcode(), DL, NewVT, {Op0, Op1}));
I don't think we can enable only partial vector support with just…
bool isSImm5() const { return isSImm<5>(); }
bool isSImm6() const { return isSImm<6>(); }
bool isSImm8() const { return isSImm<8>(); }
Why do you need this change?
TargetLoweringBase::LegalizeTypeAction
RISCVTargetLowering::getPreferredVectorAction(MVT VT) const {
  if (Subtarget.hasStdExtP() && Subtarget.is64Bit())
Missing EnablePExtCodeGen
// Check immediate range based on vector type
if (VT == MVT::v8i8 || VT == MVT::v4i8)
  // PLI_B uses 8-bit unsigned immediate
Whether the immediate is signed or unsigned doesn't really matter since it fills the whole element, so I think you can accept isInt<8> || isUInt<8> here. But the description of simm8_unsigned says that the canonical form is [-128,127], so you need to convert the UInt8 value to Int8.
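A minimal sketch of that check and canonicalization, assuming the immediate is already held in an int64_t; the helper name is illustrative, not from the PR:

```cpp
#include "llvm/Support/MathExtras.h"

// Accept any value that fits in 8 bits, signed or unsigned, then canonicalize
// to the signed range [-128, 127] that simm8_unsigned documents.
static bool matchPLIBImm(int64_t Imm, int64_t &Canonical) {
  if (!llvm::isInt<8>(Imm) && !llvm::isUInt<8>(Imm))
    return false;
  Canonical = llvm::SignExtend64<8>(Imm); // e.g. 255 canonicalizes to -1
  return true;
}
```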
// TODO: Add proper cost model for P extension fixed vectors (e.g., v4i16)
// For now, skip all fixed vector cost analysis when P extension is available
// to avoid crashes in getMinRVVVectorSizeInBits()
if (ST->hasStdExtP() && isa<FixedVectorType>(Ty)) {
This needs to check EnablePExtCodeGen.
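A sketch of the suggested guard, assuming EnablePExtCodeGen is the codegen flag named in this review and is visible in this file (both assumptions):

```cpp
// Hypothetical: take the P-extension cost-model shortcut only when P codegen
// is actually enabled, not merely when the extension is present.
if (ST->hasStdExtP() && EnablePExtCodeGen && isa<FixedVectorType>(Ty))
  return 1; // Treat as single instruction cost for now.
```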
def PMAXU_DB : RVPPairBinaryExchanged_rr<0b1111, 0b10, "pmaxu.db">;
} // Predicates = [HasStdExtP, IsRV32]

def SDT_RISCVPLI : SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisInt<1>]>;
Needs SDTCisVec<0> too.
  Results.push_back(DAG.getBitcast(MVT::v4i16, ExtLoad));
else if (N->getValueType(0) == MVT::v4i8)
  Results.push_back(DAG.getBitcast(MVT::v8i8, ExtLoad));
Results.push_back(ExtLoad.getValue(1));
You can't push the chain if the type isn't v2i16 or v4i8.
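A hedged sketch of that fix, reusing the variable names from the quoted snippet and assuming the surrounding ReplaceNodeResults context:

```cpp
// Push the result and its chain only for the types this path handles;
// for anything else, fall through without touching Results.
if (N->getValueType(0) == MVT::v2i16) {
  Results.push_back(DAG.getBitcast(MVT::v4i16, ExtLoad));
  Results.push_back(ExtLoad.getValue(1)); // chain
} else if (N->getValueType(0) == MVT::v4i8) {
  Results.push_back(DAG.getBitcast(MVT::v8i8, ExtLoad));
  Results.push_back(ExtLoad.getValue(1)); // chain
}
```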
SDValue StoredVal = Store->getValue();
EVT VT = StoredVal.getValueType();
if (Subtarget.hasStdExtP()) {
  if (VT == MVT::v2i16 || VT == MVT::v4i8) {
I'm not sure why you need to do this. Shouldn't the type legalizer do this?
MVT NewVT = MVT::v4i16;
if (VT == MVT::v4i8)
  NewVT = MVT::v8i8;
Op0 = DAG.getBitcast(MVT::i32, Op0);
This should be using CONCAT_VECTORS with ISD::UNDEF to widen the inputs. You should not go through scalar.
Curious why we can't go through scalar; isn't the cast simply a no-op?
The CONCAT_VECTORS should get removed by the type legalizer when it widens the surrounding operations, leaving just v8i8 or v4i16 vector operations except for loads/stores. If you go through a bitcast to scalar, the type legalizer can't delete them.
I see, thanks for clarifying!
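A minimal sketch of the CONCAT_VECTORS approach the reviewer describes, assuming the same DAG, operands, and NewVT as the quoted snippet (widening v2i16 to v4i16, or v4i8 to v8i8):

```cpp
// Widen each narrow input by concatenating it with an undef vector of the
// same narrow type; the type legalizer can later fold these concats away,
// which it cannot do for round-trips through a scalar bitcast.
SDValue Undef = DAG.getUNDEF(VT); // VT is the narrow type, e.g. v2i16
Op0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, NewVT, Op0, Undef);
Op1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, NewVT, Op1, Undef);
Results.push_back(DAG.getNode(N->getOpcode(), DL, NewVT, Op0, Op1));
```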
// Determine the instruction based on type and signedness
unsigned Opc;
MVT VecVT = VT.getSimpleVT();
if (VecVT == MVT::v4i16 || VecVT == MVT::v2i16 || VecVT == MVT::v8i8 ||
Check type at the beginning?
                         const RISCVSubtarget &Subtarget) {
SDValue N0 = N->getOperand(0);
EVT VT = N->getValueType(0);
if (!Subtarget.hasStdExtP() || !VT.isFixedLengthVector())
Didn't you already check hasStdExtP at the caller?