Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,17 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
addRegisterClass(MVT::riscv_nxv32i8x2, &RISCV::VRN2M4RegClass);
}

// Fixed-length vectors are held in GPRs for P extension packed operations.
if (Subtarget.hasStdExtP()) {
addRegisterClass(MVT::v2i16, &RISCV::GPRRegClass);
addRegisterClass(MVT::v4i8, &RISCV::GPRRegClass);
if (Subtarget.is64Bit()) {
addRegisterClass(MVT::v2i32, &RISCV::GPRRegClass);
addRegisterClass(MVT::v4i16, &RISCV::GPRRegClass);
addRegisterClass(MVT::v8i8, &RISCV::GPRRegClass);
}
}

// Compute derived properties from the register classes.
computeRegisterProperties(STI.getRegisterInfo());

Expand Down Expand Up @@ -479,6 +490,24 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
ISD::FTRUNC, ISD::FRINT, ISD::FROUND,
ISD::FROUNDEVEN, ISD::FCANONICALIZE};

if (Subtarget.hasStdExtP()) {
// Loads and stores of these types are already handled by selection patterns.
SmallVector<MVT, 2> VTs = {MVT::v2i16, MVT::v4i8};
if (Subtarget.is64Bit())
VTs.append({MVT::v2i32, MVT::v4i16, MVT::v8i8});
for (auto VT : VTs) {
setOperationAction(ISD::UADDSAT, VT, Legal);
setOperationAction(ISD::SADDSAT, VT, Legal);
setOperationAction(ISD::USUBSAT, VT, Legal);
setOperationAction(ISD::SSUBSAT, VT, Legal);
setOperationAction(ISD::SSHLSAT, VT, Legal);
setOperationAction(ISD::USHLSAT, VT, Legal);
setOperationAction(ISD::BITCAST, VT, Custom);
setOperationAction({ISD::AVGFLOORS, ISD::AVGFLOORU}, VT, Legal);
setOperationAction({ISD::ABDS, ISD::ABDU}, VT, Legal);
}
}

if (Subtarget.hasStdExtZfbfmin()) {
setOperationAction(ISD::BITCAST, MVT::i16, Custom);
setOperationAction(ISD::ConstantFP, MVT::bf16, Expand);
Expand Down
124 changes: 124 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfoP.td
Original file line number Diff line number Diff line change
Expand Up @@ -1455,3 +1455,127 @@ let Predicates = [HasStdExtP, IsRV32] in {
def PMAXU_DW : RVPPairBinaryExchanged_rr<0b1111, 0b01, "pmaxu.dw">;
def PMAXU_DB : RVPPairBinaryExchanged_rr<0b1111, 0b10, "pmaxu.db">;
} // Predicates = [HasStdExtP, IsRV32]

// Instruction-selection patterns for the 64-bit packed vector types
// (v4i16, v8i8, v2i32) under the P extension on RV64. Every pattern takes its
// operands from, and produces its result in, a GPR.
let Predicates = [HasStdExtP, IsRV64] in {
  // Basic arithmetic patterns for v4i16 (16-bit elements in 64-bit GPR)
  def : Pat<(v4i16 (add v4i16:$rs1, v4i16:$rs2)),
            (PADD_H GPR:$rs1, GPR:$rs2)>;
  def : Pat<(v4i16 (sub v4i16:$rs1, v4i16:$rs2)),
            (PSUB_H GPR:$rs1, GPR:$rs2)>;

  // Saturating add/sub patterns for v4i16
  def : Pat<(v4i16 (saddsat v4i16:$rs1, v4i16:$rs2)),
            (PSADD_H GPR:$rs1, GPR:$rs2)>;
  def : Pat<(v4i16 (uaddsat v4i16:$rs1, v4i16:$rs2)),
            (PSADDU_H GPR:$rs1, GPR:$rs2)>;
  def : Pat<(v4i16 (ssubsat v4i16:$rs1, v4i16:$rs2)),
            (PSSUB_H GPR:$rs1, GPR:$rs2)>;
  def : Pat<(v4i16 (usubsat v4i16:$rs1, v4i16:$rs2)),
            (PSSUBU_H GPR:$rs1, GPR:$rs2)>;

  // Averaging patterns for v4i16
  def : Pat<(v4i16 (avgfloors v4i16:$rs1, v4i16:$rs2)),
            (PAADD_H GPR:$rs1, GPR:$rs2)>;
  def : Pat<(v4i16 (avgflooru v4i16:$rs1, v4i16:$rs2)),
            (PAADDU_H GPR:$rs1, GPR:$rs2)>;

  // PASUB_H / PASUBU_H (averaging subtraction) are intentionally not selected
  // from (sra/srl (sub a, b), 1):
  //  - a one-operand build_vector cannot describe a v4i16 splat, because a
  //    BUILD_VECTOR node carries one operand per element, so such a pattern
  //    never fires; and
  //  - the element-wide `sub` wraps before the shift (e.g. i16
  //    0x7FFF - 0x8000 gives -1 instead of 32767), which is not an averaging
  //    subtraction.
  // NOTE(review): presumably these need a dedicated target node produced by a
  // DAG combine - confirm against the P extension specification.

  // Absolute difference patterns for v4i16
  def : Pat<(v4i16 (abds v4i16:$rs1, v4i16:$rs2)),
            (PDIF_H GPR:$rs1, GPR:$rs2)>;
  def : Pat<(v4i16 (abdu v4i16:$rs1, v4i16:$rs2)),
            (PDIFU_H GPR:$rs1, GPR:$rs2)>;

  // Basic arithmetic patterns for v8i8 (8-bit elements in 64-bit GPR)
  def : Pat<(v8i8 (add v8i8:$rs1, v8i8:$rs2)),
            (PADD_B GPR:$rs1, GPR:$rs2)>;
  def : Pat<(v8i8 (sub v8i8:$rs1, v8i8:$rs2)),
            (PSUB_B GPR:$rs1, GPR:$rs2)>;

  // Saturating add/sub patterns for v8i8
  def : Pat<(v8i8 (saddsat v8i8:$rs1, v8i8:$rs2)),
            (PSADD_B GPR:$rs1, GPR:$rs2)>;
  def : Pat<(v8i8 (uaddsat v8i8:$rs1, v8i8:$rs2)),
            (PSADDU_B GPR:$rs1, GPR:$rs2)>;
  def : Pat<(v8i8 (ssubsat v8i8:$rs1, v8i8:$rs2)),
            (PSSUB_B GPR:$rs1, GPR:$rs2)>;
  def : Pat<(v8i8 (usubsat v8i8:$rs1, v8i8:$rs2)),
            (PSSUBU_B GPR:$rs1, GPR:$rs2)>;

  // Averaging patterns for v8i8
  def : Pat<(v8i8 (avgfloors v8i8:$rs1, v8i8:$rs2)),
            (PAADD_B GPR:$rs1, GPR:$rs2)>;
  def : Pat<(v8i8 (avgflooru v8i8:$rs1, v8i8:$rs2)),
            (PAADDU_B GPR:$rs1, GPR:$rs2)>;

  // PASUB_B / PASUBU_B are not selected here for the same reasons given for
  // the v4i16 averaging subtraction above.

  // Absolute difference patterns for v8i8
  def : Pat<(v8i8 (abds v8i8:$rs1, v8i8:$rs2)),
            (PDIF_B GPR:$rs1, GPR:$rs2)>;
  def : Pat<(v8i8 (abdu v8i8:$rs1, v8i8:$rs2)),
            (PDIFU_B GPR:$rs1, GPR:$rs2)>;

  // Load/Store patterns for v4i16 and v8i8 (use the regular 64-bit GPR
  // load/store since the values live in GPRs)
  def : StPat<store, SD, GPR, v4i16>;
  def : LdPat<load, LD, v4i16>;
  def : StPat<store, SD, GPR, v8i8>;
  def : LdPat<load, LD, v8i8>;

  // Load/Store patterns for v2i32 (32-bit elements in 64-bit GPR).
  // Only load/store is provided for v2i32 in this block; no arithmetic
  // patterns are defined for it here.
  def : StPat<store, SD, GPR, v2i32>;
  def : LdPat<load, LD, v2i32>;
} // Predicates = [HasStdExtP, IsRV64]

// Instruction-selection patterns for the 32-bit packed vector types
// (v2i16, v4i8) under the P extension on RV32. Every pattern takes its
// operands from, and produces its result in, a GPR.
let Predicates = [HasStdExtP, IsRV32] in {
  // Basic arithmetic patterns for v2i16 (16-bit elements in 32-bit GPR)
  def : Pat<(v2i16 (add v2i16:$rs1, v2i16:$rs2)),
            (PADD_H GPR:$rs1, GPR:$rs2)>;
  def : Pat<(v2i16 (sub v2i16:$rs1, v2i16:$rs2)),
            (PSUB_H GPR:$rs1, GPR:$rs2)>;

  // Saturating add/sub patterns for v2i16
  def : Pat<(v2i16 (saddsat v2i16:$rs1, v2i16:$rs2)),
            (PSADD_H GPR:$rs1, GPR:$rs2)>;
  def : Pat<(v2i16 (uaddsat v2i16:$rs1, v2i16:$rs2)),
            (PSADDU_H GPR:$rs1, GPR:$rs2)>;
  def : Pat<(v2i16 (ssubsat v2i16:$rs1, v2i16:$rs2)),
            (PSSUB_H GPR:$rs1, GPR:$rs2)>;
  def : Pat<(v2i16 (usubsat v2i16:$rs1, v2i16:$rs2)),
            (PSSUBU_H GPR:$rs1, GPR:$rs2)>;

  // Averaging patterns for v2i16
  def : Pat<(v2i16 (avgfloors v2i16:$rs1, v2i16:$rs2)),
            (PAADD_H GPR:$rs1, GPR:$rs2)>;
  def : Pat<(v2i16 (avgflooru v2i16:$rs1, v2i16:$rs2)),
            (PAADDU_H GPR:$rs1, GPR:$rs2)>;

  // PASUB_H / PASUBU_H (averaging subtraction) are intentionally not selected
  // from (sra/srl (sub a, b), 1):
  //  - a one-operand build_vector cannot describe a v2i16 splat, because a
  //    BUILD_VECTOR node carries one operand per element, so such a pattern
  //    never fires; and
  //  - the element-wide `sub` wraps before the shift (e.g. i16
  //    0x7FFF - 0x8000 gives -1 instead of 32767), which is not an averaging
  //    subtraction.
  // NOTE(review): presumably these need a dedicated target node produced by a
  // DAG combine - confirm against the P extension specification.

  // Absolute difference patterns for v2i16
  def : Pat<(v2i16 (abds v2i16:$rs1, v2i16:$rs2)),
            (PDIF_H GPR:$rs1, GPR:$rs2)>;
  def : Pat<(v2i16 (abdu v2i16:$rs1, v2i16:$rs2)),
            (PDIFU_H GPR:$rs1, GPR:$rs2)>;

  // Basic arithmetic patterns for v4i8 (8-bit elements in 32-bit GPR)
  def : Pat<(v4i8 (add v4i8:$rs1, v4i8:$rs2)),
            (PADD_B GPR:$rs1, GPR:$rs2)>;
  def : Pat<(v4i8 (sub v4i8:$rs1, v4i8:$rs2)),
            (PSUB_B GPR:$rs1, GPR:$rs2)>;

  // Saturating add/sub patterns for v4i8
  def : Pat<(v4i8 (saddsat v4i8:$rs1, v4i8:$rs2)),
            (PSADD_B GPR:$rs1, GPR:$rs2)>;
  def : Pat<(v4i8 (uaddsat v4i8:$rs1, v4i8:$rs2)),
            (PSADDU_B GPR:$rs1, GPR:$rs2)>;
  def : Pat<(v4i8 (ssubsat v4i8:$rs1, v4i8:$rs2)),
            (PSSUB_B GPR:$rs1, GPR:$rs2)>;
  def : Pat<(v4i8 (usubsat v4i8:$rs1, v4i8:$rs2)),
            (PSSUBU_B GPR:$rs1, GPR:$rs2)>;

  // Averaging patterns for v4i8
  def : Pat<(v4i8 (avgfloors v4i8:$rs1, v4i8:$rs2)),
            (PAADD_B GPR:$rs1, GPR:$rs2)>;
  def : Pat<(v4i8 (avgflooru v4i8:$rs1, v4i8:$rs2)),
            (PAADDU_B GPR:$rs1, GPR:$rs2)>;

  // PASUB_B / PASUBU_B are not selected here for the same reasons given for
  // the v2i16 averaging subtraction above.

  // Absolute difference patterns for v4i8
  def : Pat<(v4i8 (abds v4i8:$rs1, v4i8:$rs2)),
            (PDIF_B GPR:$rs1, GPR:$rs2)>;
  def : Pat<(v4i8 (abdu v4i8:$rs1, v4i8:$rs2)),
            (PDIFU_B GPR:$rs1, GPR:$rs2)>;

  // Load/Store patterns for v2i16 and v4i8 (use the regular 32-bit GPR
  // load/store since the values live in GPRs)
  def : StPat<store, SW, GPR, v2i16>;
  def : LdPat<load, LW, v2i16>;
  def : StPat<store, SW, GPR, v4i8>;
  def : LdPat<load, LW, v4i8>;
} // Predicates = [HasStdExtP, IsRV32]
6 changes: 5 additions & 1 deletion llvm/lib/Target/RISCV/RISCVRegisterInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,11 @@ class RISCVRegisterClass<list<ValueType> regTypes, int align, dag regList>
}

class GPRRegisterClass<dag regList>
: RISCVRegisterClass<[XLenVT, XLenFVT], 32, regList> {
: RISCVRegisterClass<[XLenVT, XLenFVT,
// P extension packed vector types:
// RV32: v2i16, v4i8
// RV64: v2i32, v4i16, v8i8
v2i16, v4i8, v2i32, v4i16, v8i8], 32, regList> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These types need to be controlled by HwMode like XLenVT and XLenFVT.

let RegInfos = XLenRI;
}

Expand Down
21 changes: 21 additions & 0 deletions llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -969,6 +969,13 @@ InstructionCost RISCVTTIImpl::getScalarizationOverhead(
if (isa<ScalableVectorType>(Ty))
return InstructionCost::getInvalid();

// TODO: Add proper cost model for P extension fixed vectors (e.g., v4i16)
// For now, skip all fixed vector cost analysis when P extension is available
// to avoid crashes in getMinRVVVectorSizeInBits()
if (ST->hasStdExtP() && isa<FixedVectorType>(Ty)) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to check EnablePExtCodeGen.

return 1; // Treat as single instruction cost for now
}

// A build_vector (which is m1 sized or smaller) can be done in no
// worse than one vslide1down.vx per element in the type. We could
// in theory do an explode_vector in the inverse manner, but our
Expand Down Expand Up @@ -1625,6 +1632,13 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
if (!IsVectorType)
return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);

// TODO: Add proper cost model for P extension fixed vectors (e.g., v4i16)
// For now, skip all fixed vector cost analysis when P extension is available
// to avoid crashes in getMinRVVVectorSizeInBits()
if (ST->hasStdExtP() && (isa<FixedVectorType>(Dst) || isa<FixedVectorType>(Src))) {
return 1; // Treat as single instruction cost for now
}

// FIXME: Need to compute legalizing cost for illegal types. The current
// code handles only legal types and those which can be trivially
// promoted to legal.
Expand Down Expand Up @@ -2321,6 +2335,13 @@ InstructionCost RISCVTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
const Value *Op1) const {
assert(Val->isVectorTy() && "This must be a vector type");

// TODO: Add proper cost model for P extension fixed vectors (e.g., v4i16)
// For now, skip all fixed vector cost analysis when P extension is available
// to avoid crashes in getMinRVVVectorSizeInBits()
if (ST->hasStdExtP() && isa<FixedVectorType>(Val)) {
return 1; // Treat as single instruction cost for now
}

if (Opcode != Instruction::ExtractElement &&
Opcode != Instruction::InsertElement)
return BaseT::getVectorInstrCost(Opcode, Val, CostKind, Index, Op0, Op1);
Expand Down
Loading
Loading