-
Notifications
You must be signed in to change notification settings - Fork 15k
[RISCV][llvm] Preliminary P extension codegen support #162668
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -87,6 +87,12 @@ static cl::opt<bool> | |
| "be combined with a shift"), | ||
| cl::init(true)); | ||
|
|
||
| static cl::opt<bool> EnablePExtCodeGen( | ||
| DEBUG_TYPE "-enable-p-ext-codegen", cl::Hidden, | ||
| cl::desc("Turn on P Extension codegen(This is a temporary switch where " | ||
| "only partial codegen is currently supported."), | ||
| cl::init(false)); | ||
|
|
||
| RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, | ||
| const RISCVSubtarget &STI) | ||
| : TargetLowering(TM), Subtarget(STI) { | ||
|
|
@@ -279,6 +285,18 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, | |
| addRegisterClass(MVT::riscv_nxv32i8x2, &RISCV::VRN2M4RegClass); | ||
| } | ||
|
|
||
| // fixed vector is stored in GPRs for P extension packed operations | ||
| if (Subtarget.hasStdExtP() && EnablePExtCodeGen) { | ||
| if (Subtarget.is64Bit()) { | ||
| addRegisterClass(MVT::v2i32, &RISCV::GPRRegClass); | ||
| addRegisterClass(MVT::v4i16, &RISCV::GPRRegClass); | ||
| addRegisterClass(MVT::v8i8, &RISCV::GPRRegClass); | ||
| } else { | ||
| addRegisterClass(MVT::v2i16, &RISCV::GPRRegClass); | ||
| addRegisterClass(MVT::v4i8, &RISCV::GPRRegClass); | ||
| } | ||
| } | ||
|
|
||
| // Compute derived properties from the register classes. | ||
| computeRegisterProperties(STI.getRegisterInfo()); | ||
|
|
||
|
|
@@ -479,6 +497,37 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, | |
| ISD::FTRUNC, ISD::FRINT, ISD::FROUND, | ||
| ISD::FROUNDEVEN, ISD::FCANONICALIZE}; | ||
|
|
||
| if (Subtarget.hasStdExtP() && EnablePExtCodeGen) { | ||
| setTargetDAGCombine(ISD::TRUNCATE); | ||
| setTruncStoreAction(MVT::v2i32, MVT::v2i16, Expand); | ||
| setTruncStoreAction(MVT::v4i16, MVT::v4i8, Expand); | ||
| SmallVector<MVT, 2> VTs; | ||
| if (Subtarget.is64Bit()) { | ||
| VTs.append({MVT::v2i32, MVT::v4i16, MVT::v8i8}); | ||
| setTruncStoreAction(MVT::v2i64, MVT::v2i32, Expand); | ||
| setTruncStoreAction(MVT::v4i32, MVT::v4i16, Expand); | ||
| setTruncStoreAction(MVT::v8i16, MVT::v8i8, Expand); | ||
| setTruncStoreAction(MVT::v2i32, MVT::v2i16, Expand); | ||
| setTruncStoreAction(MVT::v4i16, MVT::v4i8, Expand); | ||
| setOperationAction(ISD::LOAD, MVT::v2i16, Custom); | ||
| setOperationAction(ISD::LOAD, MVT::v4i8, Custom); | ||
| setOperationAction(ISD::STORE, MVT::v2i16, Custom); | ||
| setOperationAction(ISD::STORE, MVT::v4i8, Custom); | ||
| } else { | ||
| VTs.append({MVT::v2i16, MVT::v4i8}); | ||
| } | ||
| setOperationAction(ISD::UADDSAT, VTs, Legal); | ||
| setOperationAction(ISD::SADDSAT, VTs, Legal); | ||
| setOperationAction(ISD::USUBSAT, VTs, Legal); | ||
| setOperationAction(ISD::SSUBSAT, VTs, Legal); | ||
| setOperationAction(ISD::SSHLSAT, VTs, Legal); | ||
| setOperationAction(ISD::USHLSAT, VTs, Legal); | ||
| setOperationAction({ISD::AVGFLOORS, ISD::AVGFLOORU}, VTs, Legal); | ||
| setOperationAction({ISD::ABDS, ISD::ABDU}, VTs, Legal); | ||
| setOperationAction(ISD::BUILD_VECTOR, VTs, Custom); | ||
| setOperationAction(ISD::BITCAST, VTs, Custom); | ||
| } | ||
|
|
||
| if (Subtarget.hasStdExtZfbfmin()) { | ||
| setOperationAction(ISD::BITCAST, MVT::i16, Custom); | ||
| setOperationAction(ISD::ConstantFP, MVT::bf16, Expand); | ||
|
|
@@ -1696,6 +1745,15 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, | |
| MaxLoadsPerMemcmp = Subtarget.getMaxLoadsPerMemcmp(/*OptSize=*/false); | ||
| } | ||
|
|
||
| TargetLoweringBase::LegalizeTypeAction | ||
| RISCVTargetLowering::getPreferredVectorAction(MVT VT) const { | ||
| if (Subtarget.hasStdExtP() && Subtarget.is64Bit()) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Missing |
||
| if (VT == MVT::v2i16 || VT == MVT::v4i8) | ||
| return TypeWidenVector; | ||
|
|
||
| return TargetLoweringBase::getPreferredVectorAction(VT); | ||
| } | ||
|
|
||
| EVT RISCVTargetLowering::getSetCCResultType(const DataLayout &DL, | ||
| LLVMContext &Context, | ||
| EVT VT) const { | ||
|
|
@@ -4311,6 +4369,34 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG, | |
| MVT XLenVT = Subtarget.getXLenVT(); | ||
|
|
||
| SDLoc DL(Op); | ||
| // Handle P extension packed vector BUILD_VECTOR with PLI for splat constants | ||
| if (Subtarget.hasStdExtP() && EnablePExtCodeGen) { | ||
| bool IsPExtVector = | ||
| (VT == MVT::v2i16 || VT == MVT::v4i8) || | ||
| (Subtarget.is64Bit() && | ||
| (VT == MVT::v4i16 || VT == MVT::v8i8 || VT == MVT::v2i32)); | ||
| if (IsPExtVector) { | ||
| if (SDValue SplatValue = cast<BuildVectorSDNode>(Op)->getSplatValue()) { | ||
| if (auto *C = dyn_cast<ConstantSDNode>(SplatValue)) { | ||
| int64_t SplatImm = C->getSExtValue(); | ||
| bool IsValidImm = false; | ||
|
|
||
| // Check immediate range based on vector type | ||
| if (VT == MVT::v8i8 || VT == MVT::v4i8) | ||
| // PLI_B uses 8-bit unsigned immediate | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Whether the immediate signed or unsigned doesn't really matter it fills the whole element. So I think you can accept isInt<8> || isUInt<8> here. But the description of |
||
| IsValidImm = isUInt<8>(SplatImm); | ||
| else | ||
| // PLI_H and PLI_W use 10-bit signed immediate | ||
| IsValidImm = isInt<10>(SplatImm); | ||
|
|
||
| if (IsValidImm) { | ||
| SDValue Imm = DAG.getConstant(SplatImm, DL, XLenVT); | ||
| return DAG.getNode(RISCVISD::PLI, DL, VT, Imm); | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // Proper support for f16 requires Zvfh. bf16 always requires special | ||
| // handling. We need to cast the scalar to integer and create an integer | ||
|
|
@@ -7462,6 +7548,19 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, | |
| return DAG.getNode(RISCVISD::BuildPairF64, DL, MVT::f64, Lo, Hi); | ||
| } | ||
|
|
||
| if (Subtarget.hasStdExtP()) { | ||
| bool Is32BitCast = | ||
| (VT == MVT::i32 && (Op0VT == MVT::v4i8 || Op0VT == MVT::v2i16)) || | ||
| (Op0VT == MVT::i32 && (VT == MVT::v4i8 || VT == MVT::v2i16)); | ||
| bool Is64BitCast = | ||
| (VT == MVT::i64 && (Op0VT == MVT::v8i8 || Op0VT == MVT::v4i16 || | ||
| Op0VT == MVT::v2i32)) || | ||
| (Op0VT == MVT::i64 && | ||
| (VT == MVT::v8i8 || VT == MVT::v4i16 || VT == MVT::v2i32)); | ||
| if (Is32BitCast || Is64BitCast) | ||
| return Op; | ||
| } | ||
|
|
||
| // Consider other scalar<->scalar casts as legal if the types are legal. | ||
| // Otherwise expand them. | ||
| if (!VT.isVector() && !Op0VT.isVector()) { | ||
|
|
@@ -8134,6 +8233,17 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, | |
| auto *Store = cast<StoreSDNode>(Op); | ||
| SDValue StoredVal = Store->getValue(); | ||
| EVT VT = StoredVal.getValueType(); | ||
| if (Subtarget.hasStdExtP()) { | ||
| if (VT == MVT::v2i16 || VT == MVT::v4i8) { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure why you need to do this. Shouldn't the type legalizer do this? |
||
| SDValue DL(Op); | ||
| SDValue Cast = DAG.getBitcast(MVT::i32, StoredVal); | ||
| SDValue NewStore = | ||
| DAG.getStore(Store->getChain(), DL, Cast, Store->getBasePtr(), | ||
| Store->getPointerInfo(), Store->getBaseAlign(), | ||
| Store->getMemOperand()->getFlags()); | ||
| return NewStore; | ||
| } | ||
| } | ||
| if (VT == MVT::f64) { | ||
| assert(Subtarget.hasStdExtZdinx() && !Subtarget.hasStdExtZilsd() && | ||
| !Subtarget.is64Bit() && "Unexpected custom legalisation"); | ||
|
|
@@ -14561,6 +14671,19 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N, | |
| return; | ||
| } | ||
|
|
||
| if (Subtarget.hasStdExtP() && Subtarget.is64Bit()) { | ||
| SDLoc DL(N); | ||
| SDValue ExtLoad = | ||
| DAG.getExtLoad(ISD::SEXTLOAD, DL, MVT::i64, Ld->getChain(), | ||
| Ld->getBasePtr(), MVT::i32, Ld->getMemOperand()); | ||
| if (N->getValueType(0) == MVT::v2i16) | ||
| Results.push_back(DAG.getBitcast(MVT::v4i16, ExtLoad)); | ||
| else if (N->getValueType(0) == MVT::v4i8) | ||
| Results.push_back(DAG.getBitcast(MVT::v8i8, ExtLoad)); | ||
| Results.push_back(ExtLoad.getValue(1)); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can't push the chain if the type isn't v2i16 or v4i8. |
||
| return; | ||
| } | ||
|
|
||
| assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() && | ||
| "Unexpected custom legalisation"); | ||
|
|
||
|
|
@@ -14889,6 +15012,24 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N, | |
| Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, NewRes)); | ||
| break; | ||
| } | ||
| case RISCVISD::PASUB: | ||
| case RISCVISD::PASUBU: { | ||
| MVT VT = N->getSimpleValueType(0); | ||
| SDValue Op0 = N->getOperand(0); | ||
| SDValue Op1 = N->getOperand(1); | ||
| assert(VT == MVT::v2i16 || VT == MVT::v4i8); | ||
| MVT NewVT = MVT::v4i16; | ||
| if (VT == MVT::v4i8) | ||
| NewVT = MVT::v8i8; | ||
| Op0 = DAG.getBitcast(MVT::i32, Op0); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should be using CONCAT_VECTORS with ISD::UNDEF to widen the inputs. You not go through scalar. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Curious about why we can't go through scalar? isn't cast simply a no-op? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The CONCAT_VECTORS should get removed by the type legalizer when it widens the surrounding operations. Leaving just v8i8 or v4i16 vector operations except for loads/stores. If you go through a bitcast to scalar, the type legalizer can't delete them. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see, thanks for clarifying! |
||
| Op0 = DAG.getSExtOrTrunc(Op0, DL, MVT::i64); | ||
| Op0 = DAG.getBitcast(NewVT, Op0); | ||
| Op1 = DAG.getBitcast(MVT::i32, Op1); | ||
| Op1 = DAG.getSExtOrTrunc(Op1, DL, MVT::i64); | ||
| Op1 = DAG.getBitcast(NewVT, Op1); | ||
| Results.push_back(DAG.getNode(N->getOpcode(), DL, NewVT, {Op0, Op1})); | ||
| return; | ||
| } | ||
| case ISD::EXTRACT_VECTOR_ELT: { | ||
| // Custom-legalize an EXTRACT_VECTOR_ELT where XLEN<SEW, as the SEW element | ||
| // type is illegal (currently only vXi64 RV32). | ||
|
|
@@ -15996,11 +16137,88 @@ static SDValue combineTruncSelectToSMaxUSat(SDNode *N, SelectionDAG &DAG) { | |
| return DAG.getNode(ISD::TRUNCATE, DL, VT, Min); | ||
| } | ||
|
|
||
| // Handle P extension averaging subtraction pattern: | ||
| // (vXiY (trunc (srl (sub ([s|z]ext vXiY:$a), ([s|z]ext vXiY:$b)), 1))) | ||
| // -> PASUB/PASUBU | ||
| static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG, | ||
| const RISCVSubtarget &Subtarget) { | ||
| SDValue N0 = N->getOperand(0); | ||
| EVT VT = N->getValueType(0); | ||
| if (!Subtarget.hasStdExtP() || !VT.isFixedLengthVector()) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Didn't you already check hasStdExtP at the caller? |
||
| return SDValue(); | ||
|
|
||
| if (N0.getOpcode() != ISD::SRL) | ||
| return SDValue(); | ||
|
|
||
| // Check if shift amount is 1 | ||
| SDValue ShAmt = N0.getOperand(1); | ||
| if (ShAmt.getOpcode() != ISD::BUILD_VECTOR) | ||
| return SDValue(); | ||
|
|
||
| BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(ShAmt.getNode()); | ||
| if (!BV) | ||
| return SDValue(); | ||
| SDValue Splat = BV->getSplatValue(); | ||
| if (!Splat) | ||
| return SDValue(); | ||
| ConstantSDNode *C = dyn_cast<ConstantSDNode>(Splat); | ||
| if (!C) | ||
| return SDValue(); | ||
| if (C->getZExtValue() != 1) | ||
| return SDValue(); | ||
|
|
||
| // Check for SUB operation | ||
| SDValue Sub = N0.getOperand(0); | ||
| if (Sub.getOpcode() != ISD::SUB) | ||
| return SDValue(); | ||
|
|
||
| SDValue LHS = Sub.getOperand(0); | ||
| SDValue RHS = Sub.getOperand(1); | ||
|
|
||
| // Check if both operands are sign/zero extends from the target | ||
| // type | ||
| bool IsSignExt = LHS.getOpcode() == ISD::SIGN_EXTEND && | ||
| RHS.getOpcode() == ISD::SIGN_EXTEND; | ||
| bool IsZeroExt = LHS.getOpcode() == ISD::ZERO_EXTEND && | ||
| RHS.getOpcode() == ISD::ZERO_EXTEND; | ||
|
|
||
| if (!IsSignExt && !IsZeroExt) | ||
| return SDValue(); | ||
|
|
||
| SDValue A = LHS.getOperand(0); | ||
| SDValue B = RHS.getOperand(0); | ||
|
|
||
| // Check if the extends are from our target vector type | ||
| if (A.getValueType() != VT || B.getValueType() != VT) | ||
| return SDValue(); | ||
|
|
||
| // Determine the instruction based on type and signedness | ||
| unsigned Opc; | ||
| MVT VecVT = VT.getSimpleVT(); | ||
| if (VecVT == MVT::v4i16 || VecVT == MVT::v2i16 || VecVT == MVT::v8i8 || | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Check type at the beginning? |
||
| VecVT == MVT::v4i8 || VecVT == MVT::v2i32) { | ||
| if (IsSignExt) | ||
| Opc = RISCVISD::PASUB; | ||
| else if (IsZeroExt) | ||
| Opc = RISCVISD::PASUBU; | ||
| else | ||
| return SDValue(); | ||
| } else { | ||
| return SDValue(); | ||
| } | ||
|
|
||
| // Create the machine node directly | ||
| return DAG.getNode(Opc, SDLoc(N), VT, {A, B}); | ||
| } | ||
|
|
||
| static SDValue performTRUNCATECombine(SDNode *N, SelectionDAG &DAG, | ||
| const RISCVSubtarget &Subtarget) { | ||
| SDValue N0 = N->getOperand(0); | ||
| EVT VT = N->getValueType(0); | ||
|
|
||
| if (Subtarget.hasStdExtP() && VT.isFixedLengthVector() && EnablePExtCodeGen) | ||
| return combinePExtTruncate(N, DAG, Subtarget); | ||
|
|
||
| // Pre-promote (i1 (truncate (srl X, Y))) on RV64 with Zbs without zero | ||
| // extending X. This is safe since we only need the LSB after the shift and | ||
| // shift amounts larger than 31 would produce poison. If we wait until | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do you need this change?