Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,7 @@ struct RISCVOperand final : public MCParsedAsmOperand {

bool isSImm5() const { return isSImm<5>(); }
bool isSImm6() const { return isSImm<6>(); }
bool isSImm8() const { return isSImm<8>(); }
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you need this change?

bool isSImm10() const { return isSImm<10>(); }
bool isSImm11() const { return isSImm<11>(); }
bool isSImm12() const { return isSImm<12>(); }
Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2471,6 +2471,14 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
CurDAG->RemoveDeadNode(Node);
return;
}
if (Subtarget->hasStdExtP()) {
if (((VT == MVT::v4i16 || VT == MVT::v8i8) && SrcVT == MVT::i64) ||
((SrcVT == MVT::v4i16 || SrcVT == MVT::v8i8) && VT == MVT::i64)) {
ReplaceUses(SDValue(Node, 0), Node->getOperand(0));
CurDAG->RemoveDeadNode(Node);
}
return;
}
break;
}
case ISD::INSERT_SUBVECTOR:
Expand Down
218 changes: 218 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ static cl::opt<bool>
"be combined with a shift"),
cl::init(true));

static cl::opt<bool> EnablePExtCodeGen(
DEBUG_TYPE "-enable-p-ext-codegen", cl::Hidden,
cl::desc("Turn on P Extension codegen(This is a temporary switch where "
"only partial codegen is currently supported."),
cl::init(false));

RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
const RISCVSubtarget &STI)
: TargetLowering(TM), Subtarget(STI) {
Expand Down Expand Up @@ -279,6 +285,18 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
addRegisterClass(MVT::riscv_nxv32i8x2, &RISCV::VRN2M4RegClass);
}

// fixed vector is stored in GPRs for P extension packed operations
if (Subtarget.hasStdExtP() && EnablePExtCodeGen) {
if (Subtarget.is64Bit()) {
addRegisterClass(MVT::v2i32, &RISCV::GPRRegClass);
addRegisterClass(MVT::v4i16, &RISCV::GPRRegClass);
addRegisterClass(MVT::v8i8, &RISCV::GPRRegClass);
} else {
addRegisterClass(MVT::v2i16, &RISCV::GPRRegClass);
addRegisterClass(MVT::v4i8, &RISCV::GPRRegClass);
}
}

// Compute derived properties from the register classes.
computeRegisterProperties(STI.getRegisterInfo());

Expand Down Expand Up @@ -479,6 +497,37 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
ISD::FTRUNC, ISD::FRINT, ISD::FROUND,
ISD::FROUNDEVEN, ISD::FCANONICALIZE};

if (Subtarget.hasStdExtP() && EnablePExtCodeGen) {
setTargetDAGCombine(ISD::TRUNCATE);
setTruncStoreAction(MVT::v2i32, MVT::v2i16, Expand);
setTruncStoreAction(MVT::v4i16, MVT::v4i8, Expand);
SmallVector<MVT, 2> VTs;
if (Subtarget.is64Bit()) {
VTs.append({MVT::v2i32, MVT::v4i16, MVT::v8i8});
setTruncStoreAction(MVT::v2i64, MVT::v2i32, Expand);
setTruncStoreAction(MVT::v4i32, MVT::v4i16, Expand);
setTruncStoreAction(MVT::v8i16, MVT::v8i8, Expand);
setTruncStoreAction(MVT::v2i32, MVT::v2i16, Expand);
setTruncStoreAction(MVT::v4i16, MVT::v4i8, Expand);
setOperationAction(ISD::LOAD, MVT::v2i16, Custom);
setOperationAction(ISD::LOAD, MVT::v4i8, Custom);
setOperationAction(ISD::STORE, MVT::v2i16, Custom);
setOperationAction(ISD::STORE, MVT::v4i8, Custom);
} else {
VTs.append({MVT::v2i16, MVT::v4i8});
}
setOperationAction(ISD::UADDSAT, VTs, Legal);
setOperationAction(ISD::SADDSAT, VTs, Legal);
setOperationAction(ISD::USUBSAT, VTs, Legal);
setOperationAction(ISD::SSUBSAT, VTs, Legal);
setOperationAction(ISD::SSHLSAT, VTs, Legal);
setOperationAction(ISD::USHLSAT, VTs, Legal);
setOperationAction({ISD::AVGFLOORS, ISD::AVGFLOORU}, VTs, Legal);
setOperationAction({ISD::ABDS, ISD::ABDU}, VTs, Legal);
setOperationAction(ISD::BUILD_VECTOR, VTs, Custom);
setOperationAction(ISD::BITCAST, VTs, Custom);
}

if (Subtarget.hasStdExtZfbfmin()) {
setOperationAction(ISD::BITCAST, MVT::i16, Custom);
setOperationAction(ISD::ConstantFP, MVT::bf16, Expand);
Expand Down Expand Up @@ -1696,6 +1745,15 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
MaxLoadsPerMemcmp = Subtarget.getMaxLoadsPerMemcmp(/*OptSize=*/false);
}

TargetLoweringBase::LegalizeTypeAction
RISCVTargetLowering::getPreferredVectorAction(MVT VT) const {
if (Subtarget.hasStdExtP() && Subtarget.is64Bit())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing EnablePExtCodeGen

if (VT == MVT::v2i16 || VT == MVT::v4i8)
return TypeWidenVector;

return TargetLoweringBase::getPreferredVectorAction(VT);
}

EVT RISCVTargetLowering::getSetCCResultType(const DataLayout &DL,
LLVMContext &Context,
EVT VT) const {
Expand Down Expand Up @@ -4311,6 +4369,34 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG,
MVT XLenVT = Subtarget.getXLenVT();

SDLoc DL(Op);
// Handle P extension packed vector BUILD_VECTOR with PLI for splat constants
if (Subtarget.hasStdExtP() && EnablePExtCodeGen) {
bool IsPExtVector =
(VT == MVT::v2i16 || VT == MVT::v4i8) ||
(Subtarget.is64Bit() &&
(VT == MVT::v4i16 || VT == MVT::v8i8 || VT == MVT::v2i32));
if (IsPExtVector) {
if (SDValue SplatValue = cast<BuildVectorSDNode>(Op)->getSplatValue()) {
if (auto *C = dyn_cast<ConstantSDNode>(SplatValue)) {
int64_t SplatImm = C->getSExtValue();
bool IsValidImm = false;

// Check immediate range based on vector type
if (VT == MVT::v8i8 || VT == MVT::v4i8)
// PLI_B uses 8-bit unsigned immediate
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whether the immediate signed or unsigned doesn't really matter it fills the whole element. So I think you can accept isInt<8> || isUInt<8> here.

But the description of simm8_unsigned says that the canonical form is [-128,127] so you need to convert UInt8 to Int8.

IsValidImm = isUInt<8>(SplatImm);
else
// PLI_H and PLI_W use 10-bit signed immediate
IsValidImm = isInt<10>(SplatImm);

if (IsValidImm) {
SDValue Imm = DAG.getConstant(SplatImm, DL, XLenVT);
return DAG.getNode(RISCVISD::PLI, DL, VT, Imm);
}
}
}
}
}

// Proper support for f16 requires Zvfh. bf16 always requires special
// handling. We need to cast the scalar to integer and create an integer
Expand Down Expand Up @@ -7462,6 +7548,19 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
return DAG.getNode(RISCVISD::BuildPairF64, DL, MVT::f64, Lo, Hi);
}

if (Subtarget.hasStdExtP()) {
bool Is32BitCast =
(VT == MVT::i32 && (Op0VT == MVT::v4i8 || Op0VT == MVT::v2i16)) ||
(Op0VT == MVT::i32 && (VT == MVT::v4i8 || VT == MVT::v2i16));
bool Is64BitCast =
(VT == MVT::i64 && (Op0VT == MVT::v8i8 || Op0VT == MVT::v4i16 ||
Op0VT == MVT::v2i32)) ||
(Op0VT == MVT::i64 &&
(VT == MVT::v8i8 || VT == MVT::v4i16 || VT == MVT::v2i32));
if (Is32BitCast || Is64BitCast)
return Op;
}

// Consider other scalar<->scalar casts as legal if the types are legal.
// Otherwise expand them.
if (!VT.isVector() && !Op0VT.isVector()) {
Expand Down Expand Up @@ -8134,6 +8233,17 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
auto *Store = cast<StoreSDNode>(Op);
SDValue StoredVal = Store->getValue();
EVT VT = StoredVal.getValueType();
if (Subtarget.hasStdExtP()) {
if (VT == MVT::v2i16 || VT == MVT::v4i8) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure why you need to do this. Shouldn't the type legalizer do this?

SDValue DL(Op);
SDValue Cast = DAG.getBitcast(MVT::i32, StoredVal);
SDValue NewStore =
DAG.getStore(Store->getChain(), DL, Cast, Store->getBasePtr(),
Store->getPointerInfo(), Store->getBaseAlign(),
Store->getMemOperand()->getFlags());
return NewStore;
}
}
if (VT == MVT::f64) {
assert(Subtarget.hasStdExtZdinx() && !Subtarget.hasStdExtZilsd() &&
!Subtarget.is64Bit() && "Unexpected custom legalisation");
Expand Down Expand Up @@ -14561,6 +14671,19 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
return;
}

if (Subtarget.hasStdExtP() && Subtarget.is64Bit()) {
SDLoc DL(N);
SDValue ExtLoad =
DAG.getExtLoad(ISD::SEXTLOAD, DL, MVT::i64, Ld->getChain(),
Ld->getBasePtr(), MVT::i32, Ld->getMemOperand());
if (N->getValueType(0) == MVT::v2i16)
Results.push_back(DAG.getBitcast(MVT::v4i16, ExtLoad));
else if (N->getValueType(0) == MVT::v4i8)
Results.push_back(DAG.getBitcast(MVT::v8i8, ExtLoad));
Results.push_back(ExtLoad.getValue(1));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can't push the chain if the type isn't v2i16 or v4i8.

return;
}

assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
"Unexpected custom legalisation");

Expand Down Expand Up @@ -14889,6 +15012,24 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, NewRes));
break;
}
case RISCVISD::PASUB:
case RISCVISD::PASUBU: {
MVT VT = N->getSimpleValueType(0);
SDValue Op0 = N->getOperand(0);
SDValue Op1 = N->getOperand(1);
assert(VT == MVT::v2i16 || VT == MVT::v4i8);
MVT NewVT = MVT::v4i16;
if (VT == MVT::v4i8)
NewVT = MVT::v8i8;
Op0 = DAG.getBitcast(MVT::i32, Op0);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be using CONCAT_VECTORS with ISD::UNDEF to widen the inputs. You not go through scalar.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious about why we can't go through scalar? isn't cast simply a no-op?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The CONCAT_VECTORS should get removed by the type legalizer when it widens the surrounding operations. Leaving just v8i8 or v4i16 vector operations except for loads/stores. If you go through a bitcast to scalar, the type legalizer can't delete them.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, thanks for clarifying!

Op0 = DAG.getSExtOrTrunc(Op0, DL, MVT::i64);
Op0 = DAG.getBitcast(NewVT, Op0);
Op1 = DAG.getBitcast(MVT::i32, Op1);
Op1 = DAG.getSExtOrTrunc(Op1, DL, MVT::i64);
Op1 = DAG.getBitcast(NewVT, Op1);
Results.push_back(DAG.getNode(N->getOpcode(), DL, NewVT, {Op0, Op1}));
return;
}
case ISD::EXTRACT_VECTOR_ELT: {
// Custom-legalize an EXTRACT_VECTOR_ELT where XLEN<SEW, as the SEW element
// type is illegal (currently only vXi64 RV32).
Expand Down Expand Up @@ -15996,11 +16137,88 @@ static SDValue combineTruncSelectToSMaxUSat(SDNode *N, SelectionDAG &DAG) {
return DAG.getNode(ISD::TRUNCATE, DL, VT, Min);
}

// Handle P extension averaging subtraction pattern:
// (vXiY (trunc (srl (sub ([s|z]ext vXiY:$a), ([s|z]ext vXiY:$b)), 1)))
// -> PASUB/PASUBU
static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
SDValue N0 = N->getOperand(0);
EVT VT = N->getValueType(0);
if (!Subtarget.hasStdExtP() || !VT.isFixedLengthVector())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Didn't you already check hasStdExtP at the caller?

return SDValue();

if (N0.getOpcode() != ISD::SRL)
return SDValue();

// Check if shift amount is 1
SDValue ShAmt = N0.getOperand(1);
if (ShAmt.getOpcode() != ISD::BUILD_VECTOR)
return SDValue();

BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(ShAmt.getNode());
if (!BV)
return SDValue();
SDValue Splat = BV->getSplatValue();
if (!Splat)
return SDValue();
ConstantSDNode *C = dyn_cast<ConstantSDNode>(Splat);
if (!C)
return SDValue();
if (C->getZExtValue() != 1)
return SDValue();

// Check for SUB operation
SDValue Sub = N0.getOperand(0);
if (Sub.getOpcode() != ISD::SUB)
return SDValue();

SDValue LHS = Sub.getOperand(0);
SDValue RHS = Sub.getOperand(1);

// Check if both operands are sign/zero extends from the target
// type
bool IsSignExt = LHS.getOpcode() == ISD::SIGN_EXTEND &&
RHS.getOpcode() == ISD::SIGN_EXTEND;
bool IsZeroExt = LHS.getOpcode() == ISD::ZERO_EXTEND &&
RHS.getOpcode() == ISD::ZERO_EXTEND;

if (!IsSignExt && !IsZeroExt)
return SDValue();

SDValue A = LHS.getOperand(0);
SDValue B = RHS.getOperand(0);

// Check if the extends are from our target vector type
if (A.getValueType() != VT || B.getValueType() != VT)
return SDValue();

// Determine the instruction based on type and signedness
unsigned Opc;
MVT VecVT = VT.getSimpleVT();
if (VecVT == MVT::v4i16 || VecVT == MVT::v2i16 || VecVT == MVT::v8i8 ||
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check type at the beginning?

VecVT == MVT::v4i8 || VecVT == MVT::v2i32) {
if (IsSignExt)
Opc = RISCVISD::PASUB;
else if (IsZeroExt)
Opc = RISCVISD::PASUBU;
else
return SDValue();
} else {
return SDValue();
}

// Create the machine node directly
return DAG.getNode(Opc, SDLoc(N), VT, {A, B});
}

static SDValue performTRUNCATECombine(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
SDValue N0 = N->getOperand(0);
EVT VT = N->getValueType(0);

if (Subtarget.hasStdExtP() && VT.isFixedLengthVector() && EnablePExtCodeGen)
return combinePExtTruncate(N, DAG, Subtarget);

// Pre-promote (i1 (truncate (srl X, Y))) on RV64 with Zbs without zero
// extending X. This is safe since we only need the LSB after the shift and
// shift amounts larger than 31 would produce poison. If we wait until
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ class RISCVTargetLowering : public TargetLowering {

bool preferScalarizeSplat(SDNode *N) const override;

/// Customize the preferred legalization strategy for certain types.
LegalizeTypeAction getPreferredVectorAction(MVT VT) const override;

bool softPromoteHalfType() const override { return true; }

/// Return the register type for a given MVT, ensuring vectors are treated
Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2895,6 +2895,12 @@ bool RISCVInstrInfo::verifyInstruction(const MachineInstr &MI,
case RISCVOp::OPERAND_UIMM9_LSB000:
Ok = isShiftedUInt<6, 3>(Imm);
break;
case RISCVOp::OPERAND_SIMM8_UNSIGNED:
Ok = isInt<8>(Imm);
break;
case RISCVOp::OPERAND_SIMM10_UNSIGNED:
Ok = isInt<10>(Imm);
break;
case RISCVOp::OPERAND_SIMM10_LSB0000_NONZERO:
Ok = isShiftedInt<6, 4>(Imm) && (Imm != 0);
break;
Expand All @@ -2916,6 +2922,8 @@ bool RISCVInstrInfo::verifyInstruction(const MachineInstr &MI,
// clang-format off
CASE_OPERAND_SIMM(5)
CASE_OPERAND_SIMM(6)
CASE_OPERAND_SIMM(8)
CASE_OPERAND_SIMM(10)
CASE_OPERAND_SIMM(11)
CASE_OPERAND_SIMM(12)
CASE_OPERAND_SIMM(26)
Expand Down
Loading
Loading