Skip to content

Commit ef7900d

Browse files
committed
[RISCV][llvm] Preliminary P extension codegen support
This is the initial support for P extension codegen. It only covers a small subset of the instructions: PADD_H, PADD_B, PSADD_H, PSADD_B, PAADD_H, PAADD_B, PSADDU_H, PSADDU_B, PAADDU_H, PAADDU_B, PSUB_H, PSUB_B, PDIF_H, PDIF_B, PSSUB_H, PSSUB_B, PASUB_H, PASUB_B, PDIFU_H, PDIFU_B, PSSUBU_H, PSSUBU_B, PASUBU_H, PASUBU_B.
1 parent 45757b9 commit ef7900d

File tree

6 files changed

+1035
-1
lines changed

6 files changed

+1035
-1
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,17 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
279279
addRegisterClass(MVT::riscv_nxv32i8x2, &RISCV::VRN2M4RegClass);
280280
}
281281

282+
// Fixed-length vectors are stored in GPRs for P extension packed operations.
283+
if (Subtarget.hasStdExtP()) {
284+
addRegisterClass(MVT::v2i16, &RISCV::GPRRegClass);
285+
addRegisterClass(MVT::v4i8, &RISCV::GPRRegClass);
286+
if (Subtarget.is64Bit()) {
287+
addRegisterClass(MVT::v2i32, &RISCV::GPRRegClass);
288+
addRegisterClass(MVT::v4i16, &RISCV::GPRRegClass);
289+
addRegisterClass(MVT::v8i8, &RISCV::GPRRegClass);
290+
}
291+
}
292+
282293
// Compute derived properties from the register classes.
283294
computeRegisterProperties(STI.getRegisterInfo());
284295

@@ -479,6 +490,24 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
479490
ISD::FTRUNC, ISD::FRINT, ISD::FROUND,
480491
ISD::FROUNDEVEN, ISD::FCANONICALIZE};
481492

493+
if (Subtarget.hasStdExtP()) {
494+
// Loads and stores are already handled by pattern matching.
495+
SmallVector<MVT, 2> VTs = {MVT::v2i16, MVT::v4i8};
496+
if (Subtarget.is64Bit())
497+
VTs.append({MVT::v2i32, MVT::v4i16, MVT::v8i8});
498+
for (auto VT : VTs) {
499+
setOperationAction(ISD::UADDSAT, VT, Legal);
500+
setOperationAction(ISD::SADDSAT, VT, Legal);
501+
setOperationAction(ISD::USUBSAT, VT, Legal);
502+
setOperationAction(ISD::SSUBSAT, VT, Legal);
503+
setOperationAction(ISD::SSHLSAT, VT, Legal);
504+
setOperationAction(ISD::USHLSAT, VT, Legal);
505+
setOperationAction(ISD::BITCAST, VT, Custom);
506+
setOperationAction({ISD::AVGFLOORS, ISD::AVGFLOORU}, VT, Legal);
507+
setOperationAction({ISD::ABDS, ISD::ABDU}, VT, Legal);
508+
}
509+
}
510+
482511
if (Subtarget.hasStdExtZfbfmin()) {
483512
setOperationAction(ISD::BITCAST, MVT::i16, Custom);
484513
setOperationAction(ISD::ConstantFP, MVT::bf16, Expand);

llvm/lib/Target/RISCV/RISCVInstrInfoP.td

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1455,3 +1455,127 @@ let Predicates = [HasStdExtP, IsRV32] in {
14551455
def PMAXU_DW : RVPPairBinaryExchanged_rr<0b1111, 0b01, "pmaxu.dw">;
14561456
def PMAXU_DB : RVPPairBinaryExchanged_rr<0b1111, 0b10, "pmaxu.db">;
14571457
} // Predicates = [HasStdExtP, IsRV32]
1458+
1459+
let Predicates = [HasStdExtP, IsRV64] in {
1460+
// Basic arithmetic patterns for v4i16 (16-bit elements in 64-bit GPR)
1461+
def: Pat<(v4i16 (add v4i16:$rs1, v4i16:$rs2)), (!cast<Instruction>("PADD_H") GPR:$rs1, GPR:$rs2)>;
1462+
def: Pat<(v4i16 (sub v4i16:$rs1, v4i16:$rs2)), (!cast<Instruction>("PSUB_H") GPR:$rs1, GPR:$rs2)>;
1463+
1464+
// Saturating add/sub patterns for v4i16
1465+
def: Pat<(v4i16 (saddsat v4i16:$rs1, v4i16:$rs2)), (!cast<Instruction>("PSADD_H") GPR:$rs1, GPR:$rs2)>;
1466+
def: Pat<(v4i16 (uaddsat v4i16:$rs1, v4i16:$rs2)), (!cast<Instruction>("PSADDU_H") GPR:$rs1, GPR:$rs2)>;
1467+
def: Pat<(v4i16 (ssubsat v4i16:$rs1, v4i16:$rs2)), (!cast<Instruction>("PSSUB_H") GPR:$rs1, GPR:$rs2)>;
1468+
def: Pat<(v4i16 (usubsat v4i16:$rs1, v4i16:$rs2)), (!cast<Instruction>("PSSUBU_H") GPR:$rs1, GPR:$rs2)>;
1469+
1470+
// Averaging patterns for v4i16
1471+
def: Pat<(v4i16 (avgfloors v4i16:$rs1, v4i16:$rs2)), (!cast<Instruction>("PAADD_H") GPR:$rs1, GPR:$rs2)>;
1472+
def: Pat<(v4i16 (avgflooru v4i16:$rs1, v4i16:$rs2)), (!cast<Instruction>("PAADDU_H") GPR:$rs1, GPR:$rs2)>;
1473+
1474+
// Averaging subtraction patterns for v4i16
1475+
// PASUB_H: signed (a - b) >> 1
1476+
def: Pat<(v4i16 (sra (sub v4i16:$rs1, v4i16:$rs2), (v4i16 (build_vector (XLenVT 1))))),
1477+
(!cast<Instruction>("PASUB_H") GPR:$rs1, GPR:$rs2)>;
1478+
// PASUBU_H: unsigned (a - b) >> 1
1479+
def: Pat<(v4i16 (srl (sub v4i16:$rs1, v4i16:$rs2), (v4i16 (build_vector (XLenVT 1))))),
1480+
(!cast<Instruction>("PASUBU_H") GPR:$rs1, GPR:$rs2)>;
1481+
1482+
// Absolute difference patterns for v4i16
1483+
def: Pat<(v4i16 (abds v4i16:$rs1, v4i16:$rs2)), (!cast<Instruction>("PDIF_H") GPR:$rs1, GPR:$rs2)>;
1484+
def: Pat<(v4i16 (abdu v4i16:$rs1, v4i16:$rs2)), (!cast<Instruction>("PDIFU_H") GPR:$rs1, GPR:$rs2)>;
1485+
1486+
// Basic arithmetic patterns for v8i8 (8-bit elements in 64-bit GPR)
1487+
def: Pat<(v8i8 (add v8i8:$rs1, v8i8:$rs2)), (!cast<Instruction>("PADD_B") GPR:$rs1, GPR:$rs2)>;
1488+
def: Pat<(v8i8 (sub v8i8:$rs1, v8i8:$rs2)), (!cast<Instruction>("PSUB_B") GPR:$rs1, GPR:$rs2)>;
1489+
1490+
// Saturating add/sub patterns for v8i8
1491+
def: Pat<(v8i8 (saddsat v8i8:$rs1, v8i8:$rs2)), (!cast<Instruction>("PSADD_B") GPR:$rs1, GPR:$rs2)>;
1492+
def: Pat<(v8i8 (uaddsat v8i8:$rs1, v8i8:$rs2)), (!cast<Instruction>("PSADDU_B") GPR:$rs1, GPR:$rs2)>;
1493+
def: Pat<(v8i8 (ssubsat v8i8:$rs1, v8i8:$rs2)), (!cast<Instruction>("PSSUB_B") GPR:$rs1, GPR:$rs2)>;
1494+
def: Pat<(v8i8 (usubsat v8i8:$rs1, v8i8:$rs2)), (!cast<Instruction>("PSSUBU_B") GPR:$rs1, GPR:$rs2)>;
1495+
1496+
// Averaging patterns for v8i8
1497+
def: Pat<(v8i8 (avgfloors v8i8:$rs1, v8i8:$rs2)), (!cast<Instruction>("PAADD_B") GPR:$rs1, GPR:$rs2)>;
1498+
def: Pat<(v8i8 (avgflooru v8i8:$rs1, v8i8:$rs2)), (!cast<Instruction>("PAADDU_B") GPR:$rs1, GPR:$rs2)>;
1499+
1500+
// Averaging subtraction patterns for v8i8
1501+
// PASUB_B: signed (a - b) >> 1
1502+
def: Pat<(v8i8 (sra (sub v8i8:$rs1, v8i8:$rs2), (v8i8 (build_vector (XLenVT 1))))),
1503+
(!cast<Instruction>("PASUB_B") GPR:$rs1, GPR:$rs2)>;
1504+
// PASUBU_B: unsigned (a - b) >> 1
1505+
def: Pat<(v8i8 (srl (sub v8i8:$rs1, v8i8:$rs2), (v8i8 (build_vector (XLenVT 1))))),
1506+
(!cast<Instruction>("PASUBU_B") GPR:$rs1, GPR:$rs2)>;
1507+
1508+
// Absolute difference patterns for v8i8
1509+
def: Pat<(v8i8 (abds v8i8:$rs1, v8i8:$rs2)), (!cast<Instruction>("PDIF_B") GPR:$rs1, GPR:$rs2)>;
1510+
def: Pat<(v8i8 (abdu v8i8:$rs1, v8i8:$rs2)), (!cast<Instruction>("PDIFU_B") GPR:$rs1, GPR:$rs2)>;
1511+
1512+
// Load/Store patterns for v4i16 and v8i8 (use regular GPR load/store since they're in GPRs)
1513+
def : StPat<store, SD, GPR, v4i16>;
1514+
def : LdPat<load, LD, v4i16>;
1515+
def : StPat<store, SD, GPR, v8i8>;
1516+
def : LdPat<load, LD, v8i8>;
1517+
1518+
// Load/Store patterns for v2i32 (32-bit elements in 64-bit GPR)
1519+
def : StPat<store, SD, GPR, v2i32>;
1520+
def : LdPat<load, LD, v2i32>;
1521+
} // Predicates = [HasStdExtP, IsRV64]
1522+
1523+
let Predicates = [HasStdExtP, IsRV32] in {
1524+
// Basic arithmetic patterns for v2i16 (16-bit elements in 32-bit GPR)
1525+
def: Pat<(v2i16 (add v2i16:$rs1, v2i16:$rs2)), (!cast<Instruction>("PADD_H") GPR:$rs1, GPR:$rs2)>;
1526+
def: Pat<(v2i16 (sub v2i16:$rs1, v2i16:$rs2)), (!cast<Instruction>("PSUB_H") GPR:$rs1, GPR:$rs2)>;
1527+
1528+
// Saturating add/sub patterns for v2i16
1529+
def: Pat<(v2i16 (saddsat v2i16:$rs1, v2i16:$rs2)), (!cast<Instruction>("PSADD_H") GPR:$rs1, GPR:$rs2)>;
1530+
def: Pat<(v2i16 (uaddsat v2i16:$rs1, v2i16:$rs2)), (!cast<Instruction>("PSADDU_H") GPR:$rs1, GPR:$rs2)>;
1531+
def: Pat<(v2i16 (ssubsat v2i16:$rs1, v2i16:$rs2)), (!cast<Instruction>("PSSUB_H") GPR:$rs1, GPR:$rs2)>;
1532+
def: Pat<(v2i16 (usubsat v2i16:$rs1, v2i16:$rs2)), (!cast<Instruction>("PSSUBU_H") GPR:$rs1, GPR:$rs2)>;
1533+
1534+
// Averaging patterns for v2i16
1535+
def: Pat<(v2i16 (avgfloors v2i16:$rs1, v2i16:$rs2)), (!cast<Instruction>("PAADD_H") GPR:$rs1, GPR:$rs2)>;
1536+
def: Pat<(v2i16 (avgflooru v2i16:$rs1, v2i16:$rs2)), (!cast<Instruction>("PAADDU_H") GPR:$rs1, GPR:$rs2)>;
1537+
1538+
// Averaging subtraction patterns for v2i16
1539+
// PASUB_H: signed (a - b) >> 1
1540+
def: Pat<(v2i16 (sra (sub v2i16:$rs1, v2i16:$rs2), (v2i16 (build_vector (XLenVT 1))))),
1541+
(!cast<Instruction>("PASUB_H") GPR:$rs1, GPR:$rs2)>;
1542+
// PASUBU_H: unsigned (a - b) >> 1
1543+
def: Pat<(v2i16 (srl (sub v2i16:$rs1, v2i16:$rs2), (v2i16 (build_vector (XLenVT 1))))),
1544+
(!cast<Instruction>("PASUBU_H") GPR:$rs1, GPR:$rs2)>;
1545+
1546+
// Absolute difference patterns for v2i16
1547+
def: Pat<(v2i16 (abds v2i16:$rs1, v2i16:$rs2)), (!cast<Instruction>("PDIF_H") GPR:$rs1, GPR:$rs2)>;
1548+
def: Pat<(v2i16 (abdu v2i16:$rs1, v2i16:$rs2)), (!cast<Instruction>("PDIFU_H") GPR:$rs1, GPR:$rs2)>;
1549+
1550+
// Basic arithmetic patterns for v4i8 (8-bit elements in 32-bit GPR)
1551+
def: Pat<(v4i8 (add v4i8:$rs1, v4i8:$rs2)), (!cast<Instruction>("PADD_B") GPR:$rs1, GPR:$rs2)>;
1552+
def: Pat<(v4i8 (sub v4i8:$rs1, v4i8:$rs2)), (!cast<Instruction>("PSUB_B") GPR:$rs1, GPR:$rs2)>;
1553+
1554+
// Saturating add/sub patterns for v4i8
1555+
def: Pat<(v4i8 (saddsat v4i8:$rs1, v4i8:$rs2)), (!cast<Instruction>("PSADD_B") GPR:$rs1, GPR:$rs2)>;
1556+
def: Pat<(v4i8 (uaddsat v4i8:$rs1, v4i8:$rs2)), (!cast<Instruction>("PSADDU_B") GPR:$rs1, GPR:$rs2)>;
1557+
def: Pat<(v4i8 (ssubsat v4i8:$rs1, v4i8:$rs2)), (!cast<Instruction>("PSSUB_B") GPR:$rs1, GPR:$rs2)>;
1558+
def: Pat<(v4i8 (usubsat v4i8:$rs1, v4i8:$rs2)), (!cast<Instruction>("PSSUBU_B") GPR:$rs1, GPR:$rs2)>;
1559+
1560+
// Averaging patterns for v4i8
1561+
def: Pat<(v4i8 (avgfloors v4i8:$rs1, v4i8:$rs2)), (!cast<Instruction>("PAADD_B") GPR:$rs1, GPR:$rs2)>;
1562+
def: Pat<(v4i8 (avgflooru v4i8:$rs1, v4i8:$rs2)), (!cast<Instruction>("PAADDU_B") GPR:$rs1, GPR:$rs2)>;
1563+
1564+
// Averaging subtraction patterns for v4i8
1565+
// PASUB_B: signed (a - b) >> 1
1566+
def: Pat<(v4i8 (sra (sub v4i8:$rs1, v4i8:$rs2), (v4i8 (build_vector (XLenVT 1))))),
1567+
(!cast<Instruction>("PASUB_B") GPR:$rs1, GPR:$rs2)>;
1568+
// PASUBU_B: unsigned (a - b) >> 1
1569+
def: Pat<(v4i8 (srl (sub v4i8:$rs1, v4i8:$rs2), (v4i8 (build_vector (XLenVT 1))))),
1570+
(!cast<Instruction>("PASUBU_B") GPR:$rs1, GPR:$rs2)>;
1571+
1572+
// Absolute difference patterns for v4i8
1573+
def: Pat<(v4i8 (abds v4i8:$rs1, v4i8:$rs2)), (!cast<Instruction>("PDIF_B") GPR:$rs1, GPR:$rs2)>;
1574+
def: Pat<(v4i8 (abdu v4i8:$rs1, v4i8:$rs2)), (!cast<Instruction>("PDIFU_B") GPR:$rs1, GPR:$rs2)>;
1575+
1576+
// Load/Store patterns for v2i16 and v4i8 (use regular GPR load/store since they're in GPRs)
1577+
def : StPat<store, SW, GPR, v2i16>;
1578+
def : LdPat<load, LW, v2i16>;
1579+
def : StPat<store, SW, GPR, v4i8>;
1580+
def : LdPat<load, LW, v4i8>;
1581+
} // Predicates = [HasStdExtP, IsRV32]

llvm/lib/Target/RISCV/RISCVRegisterInfo.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,11 @@ class RISCVRegisterClass<list<ValueType> regTypes, int align, dag regList>
238238
}
239239

240240
class GPRRegisterClass<dag regList>
241-
: RISCVRegisterClass<[XLenVT, XLenFVT], 32, regList> {
241+
: RISCVRegisterClass<[XLenVT, XLenFVT,
242+
// P extension packed vector types:
243+
// RV32: v2i16, v4i8
244+
// RV64: v2i32, v4i16, v8i8
245+
v2i16, v4i8, v2i32, v4i16, v8i8], 32, regList> {
242246
let RegInfos = XLenRI;
243247
}
244248

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -969,6 +969,13 @@ InstructionCost RISCVTTIImpl::getScalarizationOverhead(
969969
if (isa<ScalableVectorType>(Ty))
970970
return InstructionCost::getInvalid();
971971

972+
// TODO: Add proper cost model for P extension fixed vectors (e.g., v4i16)
973+
// For now, skip all fixed vector cost analysis when P extension is available
974+
// to avoid crashes in getMinRVVVectorSizeInBits()
975+
if (ST->hasStdExtP() && isa<FixedVectorType>(Ty)) {
976+
return 1; // Treat as single instruction cost for now
977+
}
978+
972979
// A build_vector (which is m1 sized or smaller) can be done in no
973980
// worse than one vslide1down.vx per element in the type. We could
974981
// in theory do an explode_vector in the inverse manner, but our
@@ -1625,6 +1632,13 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
16251632
if (!IsVectorType)
16261633
return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
16271634

1635+
// TODO: Add proper cost model for P extension fixed vectors (e.g., v4i16)
1636+
// For now, skip all fixed vector cost analysis when P extension is available
1637+
// to avoid crashes in getMinRVVVectorSizeInBits()
1638+
if (ST->hasStdExtP() && (isa<FixedVectorType>(Dst) || isa<FixedVectorType>(Src))) {
1639+
return 1; // Treat as single instruction cost for now
1640+
}
1641+
16281642
// FIXME: Need to compute legalizing cost for illegal types. The current
16291643
// code handles only legal types and those which can be trivially
16301644
// promoted to legal.
@@ -2321,6 +2335,13 @@ InstructionCost RISCVTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
23212335
const Value *Op1) const {
23222336
assert(Val->isVectorTy() && "This must be a vector type");
23232337

2338+
// TODO: Add proper cost model for P extension fixed vectors (e.g., v4i16)
2339+
// For now, skip all fixed vector cost analysis when P extension is available
2340+
// to avoid crashes in getMinRVVVectorSizeInBits()
2341+
if (ST->hasStdExtP() && isa<FixedVectorType>(Val)) {
2342+
return 1; // Treat as single instruction cost for now
2343+
}
2344+
23242345
if (Opcode != Instruction::ExtractElement &&
23252346
Opcode != Instruction::InsertElement)
23262347
return BaseT::getVectorInstrCost(Opcode, Val, CostKind, Index, Op0, Op1);

0 commit comments

Comments
 (0)