Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,7 @@ struct RISCVOperand final : public MCParsedAsmOperand {

// Signed-immediate predicates used for operand matching. Each one simply
// forwards to the isSImm<N>() template with the bit width in its name.
bool isSImm5() const { return isSImm<5>(); }
bool isSImm6() const { return isSImm<6>(); }
bool isSImm8() const { return isSImm<8>(); }
bool isSImm10() const { return isSImm<10>(); }
bool isSImm11() const { return isSImm<11>(); }
bool isSImm12() const { return isSImm<12>(); }
Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2471,6 +2471,14 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
CurDAG->RemoveDeadNode(Node);
return;
}
if (Subtarget->hasStdExtP()) {
  // A bitcast between i64 and a packed 64-bit vector (v4i16 / v8i8) needs no
  // instruction: forward the operand to all users and delete the node.
  const bool IsScalarToPacked =
      (VT == MVT::v4i16 || VT == MVT::v8i8) && SrcVT == MVT::i64;
  const bool IsPackedToScalar =
      (SrcVT == MVT::v4i16 || SrcVT == MVT::v8i8) && VT == MVT::i64;
  if (IsScalarToPacked || IsPackedToScalar) {
    ReplaceUses(SDValue(Node, 0), Node->getOperand(0));
    CurDAG->RemoveDeadNode(Node);
    return;
  }
  // Any other bitcast must NOT return here: it has not been selected yet and
  // has to reach the `break` that follows so the normal path handles it.
}
break;
}
case ISD::INSERT_SUBVECTOR:
Expand Down
218 changes: 218 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ static cl::opt<bool>
"be combined with a shift"),
cl::init(true));

// Hidden, off-by-default command-line switch that gates the (still partial)
// P extension code generation paths in this file.
static cl::opt<bool> EnablePExtCodeGen(
    DEBUG_TYPE "-enable-p-ext-codegen", cl::Hidden,
    cl::desc("Turn on P Extension codegen (This is a temporary switch where "
             "only partial codegen is currently supported)."),
    cl::init(false));

RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
const RISCVSubtarget &STI)
: TargetLowering(TM), Subtarget(STI) {
Expand Down Expand Up @@ -279,6 +285,18 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
addRegisterClass(MVT::riscv_nxv32i8x2, &RISCV::VRN2M4RegClass);
}

// P extension packed operations keep fixed-length vectors in GPRs. Register
// the XLEN-wide packed types for the current target: 64-bit packed vectors on
// RV64, 32-bit packed vectors on RV32.
if (Subtarget.hasStdExtP() && EnablePExtCodeGen) {
  if (Subtarget.is64Bit()) {
    for (MVT PackedVT : {MVT::v2i32, MVT::v4i16, MVT::v8i8})
      addRegisterClass(PackedVT, &RISCV::GPRRegClass);
  } else {
    for (MVT PackedVT : {MVT::v2i16, MVT::v4i8})
      addRegisterClass(PackedVT, &RISCV::GPRRegClass);
  }
}

// Compute derived properties from the register classes.
computeRegisterProperties(STI.getRegisterInfo());

Expand Down Expand Up @@ -479,6 +497,37 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
ISD::FTRUNC, ISD::FRINT, ISD::FROUND,
ISD::FROUNDEVEN, ISD::FCANONICALIZE};

// Operation actions for P extension packed-vector codegen (experimental,
// gated by EnablePExtCodeGen).
if (Subtarget.hasStdExtP() && EnablePExtCodeGen) {
  setTargetDAGCombine(ISD::TRUNCATE);

  // Truncating stores between packed types are expanded on both RV32 and RV64.
  setTruncStoreAction(MVT::v2i32, MVT::v2i16, Expand);
  setTruncStoreAction(MVT::v4i16, MVT::v4i8, Expand);

  // Packed types that fit one GPR on this target; RV64 appends three entries.
  SmallVector<MVT, 3> VTs;
  if (Subtarget.is64Bit()) {
    VTs.append({MVT::v2i32, MVT::v4i16, MVT::v8i8});
    setTruncStoreAction(MVT::v2i64, MVT::v2i32, Expand);
    setTruncStoreAction(MVT::v4i32, MVT::v4i16, Expand);
    setTruncStoreAction(MVT::v8i16, MVT::v8i8, Expand);
    // 32-bit packed vectors are not register-sized on RV64; their loads and
    // stores are custom-handled.
    setOperationAction(ISD::LOAD, MVT::v2i16, Custom);
    setOperationAction(ISD::LOAD, MVT::v4i8, Custom);
    setOperationAction(ISD::STORE, MVT::v2i16, Custom);
    setOperationAction(ISD::STORE, MVT::v4i8, Custom);
  } else {
    VTs.append({MVT::v2i16, MVT::v4i8});
  }

  // Saturating add/sub/shift, averaging and absolute-difference are marked
  // Legal for the packed types.
  setOperationAction({ISD::UADDSAT, ISD::SADDSAT, ISD::USUBSAT, ISD::SSUBSAT},
                     VTs, Legal);
  setOperationAction({ISD::SSHLSAT, ISD::USHLSAT}, VTs, Legal);
  setOperationAction({ISD::AVGFLOORS, ISD::AVGFLOORU}, VTs, Legal);
  setOperationAction({ISD::ABDS, ISD::ABDU}, VTs, Legal);

  // Splat constants and scalar<->packed bitcasts go through custom lowering.
  setOperationAction(ISD::BUILD_VECTOR, VTs, Custom);
  setOperationAction(ISD::BITCAST, VTs, Custom);
}

if (Subtarget.hasStdExtZfbfmin()) {
setOperationAction(ISD::BITCAST, MVT::i16, Custom);
setOperationAction(ISD::ConstantFP, MVT::bf16, Expand);
Expand Down Expand Up @@ -1696,6 +1745,15 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
MaxLoadsPerMemcmp = Subtarget.getMaxLoadsPerMemcmp(/*OptSize=*/false);
}

/// Choose how the type legalizer should treat an illegal vector type.
///
/// On RV64 with P extension codegen enabled, the 32-bit packed types v2i16 and
/// v4i8 are widened. Every other type, and every configuration where P
/// codegen is off, uses the TargetLoweringBase default.
///
/// The EnablePExtCodeGen check keeps this consistent with the other P
/// extension paths in this file, which only register and lower packed types
/// when that switch is on.
TargetLoweringBase::LegalizeTypeAction
RISCVTargetLowering::getPreferredVectorAction(MVT VT) const {
  const bool WidenPacked32 = Subtarget.hasStdExtP() && EnablePExtCodeGen &&
                             Subtarget.is64Bit() &&
                             (VT == MVT::v2i16 || VT == MVT::v4i8);
  if (WidenPacked32)
    return TypeWidenVector;

  return TargetLoweringBase::getPreferredVectorAction(VT);
}

EVT RISCVTargetLowering::getSetCCResultType(const DataLayout &DL,
LLVMContext &Context,
EVT VT) const {
Expand Down Expand Up @@ -4311,6 +4369,34 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG,
MVT XLenVT = Subtarget.getXLenVT();

SDLoc DL(Op);
// Handle P extension packed vector BUILD_VECTOR with PLI for splat constants.
if (Subtarget.hasStdExtP() && EnablePExtCodeGen) {
  const bool IsByteVector = VT == MVT::v8i8 || VT == MVT::v4i8;
  const bool IsPExtVector =
      (VT == MVT::v2i16 || VT == MVT::v4i8) ||
      (Subtarget.is64Bit() &&
       (VT == MVT::v4i16 || VT == MVT::v8i8 || VT == MVT::v2i32));
  if (IsPExtVector) {
    if (SDValue SplatValue = cast<BuildVectorSDNode>(Op)->getSplatValue()) {
      if (auto *C = dyn_cast<ConstantSDNode>(SplatValue)) {
        int64_t SplatImm = C->getSExtValue();
        bool IsValidImm = false;

        // Check immediate range based on vector type.
        if (IsByteVector) {
          // A byte splat may arrive as a signed (-128..127) or an unsigned
          // (0..255) value; both describe the same 8-bit pattern.
          // Canonicalize to the signed form.
          // NOTE(review): assumes PLI_B's immediate is the simm8-unsigned
          // operand, which verifyInstruction range-checks with isInt<8> -
          // confirm against the instruction definition.
          IsValidImm = isInt<8>(SplatImm) || isUInt<8>(SplatImm);
          if (IsValidImm)
            SplatImm = SignExtend64<8>(SplatImm);
        } else {
          // PLI_H and PLI_W use 10-bit signed immediate.
          IsValidImm = isInt<10>(SplatImm);
        }

        if (IsValidImm) {
          SDValue Imm = DAG.getConstant(SplatImm, DL, XLenVT);
          return DAG.getNode(RISCVISD::PLI, DL, VT, Imm);
        }
      }
    }
  }
}

// Proper support for f16 requires Zvfh. bf16 always requires special
// handling. We need to cast the scalar to integer and create an integer
Expand Down Expand Up @@ -7462,6 +7548,19 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
return DAG.getNode(RISCVISD::BuildPairF64, DL, MVT::f64, Lo, Hi);
}

if (Subtarget.hasStdExtP()) {
  // Bitcasts between an XLEN-sized scalar and a packed vector of the same
  // total width are returned unchanged (treated as legal).
  auto IsPacked32 = [](EVT Ty) {
    return Ty == MVT::v4i8 || Ty == MVT::v2i16;
  };
  auto IsPacked64 = [](EVT Ty) {
    return Ty == MVT::v8i8 || Ty == MVT::v4i16 || Ty == MVT::v2i32;
  };

  const bool Involves32 = (VT == MVT::i32 && IsPacked32(Op0VT)) ||
                          (Op0VT == MVT::i32 && IsPacked32(VT));
  const bool Involves64 = (VT == MVT::i64 && IsPacked64(Op0VT)) ||
                          (Op0VT == MVT::i64 && IsPacked64(VT));
  if (Involves32 || Involves64)
    return Op;
}

// Consider other scalar<->scalar casts as legal if the types are legal.
// Otherwise expand them.
if (!VT.isVector() && !Op0VT.isVector()) {
Expand Down Expand Up @@ -8134,6 +8233,17 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
auto *Store = cast<StoreSDNode>(Op);
SDValue StoredVal = Store->getValue();
EVT VT = StoredVal.getValueType();
if (Subtarget.hasStdExtP()) {
  // A 32-bit packed vector (v2i16 / v4i8) is stored as a plain i32: bitcast
  // the value and emit an ordinary store that reuses the original chain,
  // base pointer, pointer info, alignment and memory-operand flags.
  if (VT == MVT::v2i16 || VT == MVT::v4i8) {
    // Use a real SDLoc; the previous `SDValue DL(Op)` only worked through an
    // implicit SDValue -> SDLoc conversion at the call site.
    SDLoc DL(Op);
    SDValue Cast = DAG.getBitcast(MVT::i32, StoredVal);
    return DAG.getStore(Store->getChain(), DL, Cast, Store->getBasePtr(),
                        Store->getPointerInfo(), Store->getBaseAlign(),
                        Store->getMemOperand()->getFlags());
  }
}
if (VT == MVT::f64) {
assert(Subtarget.hasStdExtZdinx() && !Subtarget.hasStdExtZilsd() &&
!Subtarget.is64Bit() && "Unexpected custom legalisation");
Expand Down Expand Up @@ -14561,6 +14671,19 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
return;
}

if (Subtarget.hasStdExtP() && Subtarget.is64Bit()) {
  // Only the 32-bit packed vector loads are handled here. Every other result
  // type (notably the i32 load handled right below) must fall through;
  // otherwise only the chain would be pushed and the value result would be
  // missing.
  EVT LdVT = N->getValueType(0);
  if (LdVT == MVT::v2i16 || LdVT == MVT::v4i8) {
    SDLoc DL(N);
    // Load the 32 bits as a sign-extending i64 load, then reinterpret the
    // i64 as the 64-bit packed type with the same element width.
    SDValue ExtLoad =
        DAG.getExtLoad(ISD::SEXTLOAD, DL, MVT::i64, Ld->getChain(),
                       Ld->getBasePtr(), MVT::i32, Ld->getMemOperand());
    MVT WideVT = LdVT == MVT::v2i16 ? MVT::v4i16 : MVT::v8i8;
    Results.push_back(DAG.getBitcast(WideVT, ExtLoad));
    Results.push_back(ExtLoad.getValue(1));
    return;
  }
}

assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
"Unexpected custom legalisation");

Expand Down Expand Up @@ -14889,6 +15012,24 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, NewRes));
break;
}
case RISCVISD::PASUB:
case RISCVISD::PASUBU: {
  // Re-emit a 32-bit packed averaging subtract (v2i16 / v4i8) on the 64-bit
  // packed type with the same element width (v4i16 / v8i8).
  MVT VT = N->getSimpleValueType(0);
  assert(VT == MVT::v2i16 || VT == MVT::v4i8);
  MVT WideVT = (VT == MVT::v4i8) ? MVT::v8i8 : MVT::v4i16;

  // Reinterpret an operand as i32, extend it to i64, then view that i64 as
  // the wide packed vector.
  auto WidenOperand = [&](SDValue V) {
    SDValue AsI32 = DAG.getBitcast(MVT::i32, V);
    SDValue AsI64 = DAG.getSExtOrTrunc(AsI32, DL, MVT::i64);
    return DAG.getBitcast(WideVT, AsI64);
  };

  SDValue Op0 = WidenOperand(N->getOperand(0));
  SDValue Op1 = WidenOperand(N->getOperand(1));
  Results.push_back(DAG.getNode(N->getOpcode(), DL, WideVT, {Op0, Op1}));
  return;
}
case ISD::EXTRACT_VECTOR_ELT: {
// Custom-legalize an EXTRACT_VECTOR_ELT where XLEN<SEW, as the SEW element
// type is illegal (currently only vXi64 RV32).
Expand Down Expand Up @@ -15996,11 +16137,88 @@ static SDValue combineTruncSelectToSMaxUSat(SDNode *N, SelectionDAG &DAG) {
return DAG.getNode(ISD::TRUNCATE, DL, VT, Min);
}

// Handle P extension averaging subtraction pattern:
//   (vXiY (trunc (srl (sub ([s|z]ext vXiY:$a), ([s|z]ext vXiY:$b)), 1)))
//   -> PASUB (sign-extended operands) / PASUBU (zero-extended operands)
//
// N is the TRUNCATE node being combined. Returns the replacement node, or an
// empty SDValue when the pattern does not match (no nodes are created in
// that case).
static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG,
                                   const RISCVSubtarget &Subtarget) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);
  if (!Subtarget.hasStdExtP() || !VT.isFixedLengthVector())
    return SDValue();

  // Only the packed vector types are candidates. Check isSimple() first:
  // getSimpleVT() is not valid on extended (non-MVT) value types.
  if (!VT.isSimple())
    return SDValue();
  MVT VecVT = VT.getSimpleVT();
  if (VecVT != MVT::v4i16 && VecVT != MVT::v2i16 && VecVT != MVT::v8i8 &&
      VecVT != MVT::v4i8 && VecVT != MVT::v2i32)
    return SDValue();

  if (N0.getOpcode() != ISD::SRL)
    return SDValue();

  // The shift amount must be a BUILD_VECTOR splat of the constant 1.
  auto *BV = dyn_cast<BuildVectorSDNode>(N0.getOperand(1).getNode());
  if (!BV)
    return SDValue();
  SDValue Splat = BV->getSplatValue();
  if (!Splat)
    return SDValue();
  auto *C = dyn_cast<ConstantSDNode>(Splat);
  if (!C || C->getZExtValue() != 1)
    return SDValue();

  // The shifted value must be a subtraction.
  SDValue Sub = N0.getOperand(0);
  if (Sub.getOpcode() != ISD::SUB)
    return SDValue();

  SDValue LHS = Sub.getOperand(0);
  SDValue RHS = Sub.getOperand(1);

  // Both operands must be extended the same way (both sext or both zext).
  const bool IsSignExt = LHS.getOpcode() == ISD::SIGN_EXTEND &&
                         RHS.getOpcode() == ISD::SIGN_EXTEND;
  const bool IsZeroExt = LHS.getOpcode() == ISD::ZERO_EXTEND &&
                         RHS.getOpcode() == ISD::ZERO_EXTEND;
  if (!IsSignExt && !IsZeroExt)
    return SDValue();

  // The extends must originate from the truncate's result type.
  SDValue A = LHS.getOperand(0);
  SDValue B = RHS.getOperand(0);
  if (A.getValueType() != VT || B.getValueType() != VT)
    return SDValue();

  // Exactly one of IsSignExt / IsZeroExt is set at this point.
  const unsigned Opc = IsSignExt ? RISCVISD::PASUB : RISCVISD::PASUBU;
  return DAG.getNode(Opc, SDLoc(N), VT, {A, B});
}

static SDValue performTRUNCATECombine(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
SDValue N0 = N->getOperand(0);
EVT VT = N->getValueType(0);

if (Subtarget.hasStdExtP() && VT.isFixedLengthVector() && EnablePExtCodeGen)
return combinePExtTruncate(N, DAG, Subtarget);

// Pre-promote (i1 (truncate (srl X, Y))) on RV64 with Zbs without zero
// extending X. This is safe since we only need the LSB after the shift and
// shift amounts larger than 31 would produce poison. If we wait until
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ class RISCVTargetLowering : public TargetLowering {

bool preferScalarizeSplat(SDNode *N) const override;

/// Customize the preferred legalization strategy for certain vector types.
/// Types without a target-specific preference use the TargetLoweringBase
/// default.
LegalizeTypeAction getPreferredVectorAction(MVT VT) const override;

bool softPromoteHalfType() const override { return true; }

/// Return the register type for a given MVT, ensuring vectors are treated
Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2895,6 +2895,12 @@ bool RISCVInstrInfo::verifyInstruction(const MachineInstr &MI,
case RISCVOp::OPERAND_UIMM9_LSB000:
Ok = isShiftedUInt<6, 3>(Imm);
break;
// Despite the *_UNSIGNED suffix, these operand kinds are range-checked here
// as signed 8-bit / 10-bit values.
case RISCVOp::OPERAND_SIMM8_UNSIGNED:
Ok = isInt<8>(Imm);
break;
case RISCVOp::OPERAND_SIMM10_UNSIGNED:
Ok = isInt<10>(Imm);
break;
case RISCVOp::OPERAND_SIMM10_LSB0000_NONZERO:
Ok = isShiftedInt<6, 4>(Imm) && (Imm != 0);
break;
Expand All @@ -2916,6 +2922,8 @@ bool RISCVInstrInfo::verifyInstruction(const MachineInstr &MI,
// clang-format off
CASE_OPERAND_SIMM(5)
CASE_OPERAND_SIMM(6)
CASE_OPERAND_SIMM(8)
CASE_OPERAND_SIMM(10)
CASE_OPERAND_SIMM(11)
CASE_OPERAND_SIMM(12)
CASE_OPERAND_SIMM(26)
Expand Down
Loading
Loading