-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[RISCV][llvm] Select splat_vector(constant) with PLI #168204
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
The default DAG combiner combines a BUILD_VECTOR whose elements are all the same into a SPLAT_VECTOR, so we can simply map a constant splat to PLI where possible.
|
@llvm/pr-subscribers-backend-risc-v Author: Brandon Wu (4vtomat) Changes: The default DAG combiner combines a BUILD_VECTOR whose elements are all the same into a SPLAT_VECTOR, so we can simply map a constant splat to PLI where possible. Full diff: https://github.com/llvm/llvm-project/pull/168204.diff 3 Files Affected:
diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index 1024e55f912c7..5025122db3681 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -51,6 +51,8 @@ void RISCVDAGToDAGISel::PreprocessISelDAG() {
SDValue Result;
switch (N->getOpcode()) {
case ISD::SPLAT_VECTOR: {
+ if (Subtarget->enablePExtCodeGen())
+ break;
// Convert integer SPLAT_VECTOR to VMV_V_X_VL and floating-point
// SPLAT_VECTOR to VFMV_V_F_VL to reduce isel burden.
MVT VT = N->getSimpleValueType(0);
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 38cce26e44af4..37c8d7c045443 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -525,7 +525,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::SSUBSAT, VTs, Legal);
setOperationAction({ISD::AVGFLOORS, ISD::AVGFLOORU}, VTs, Legal);
setOperationAction({ISD::ABDS, ISD::ABDU}, VTs, Legal);
- setOperationAction(ISD::BUILD_VECTOR, VTs, Custom);
+ setOperationAction(ISD::SPLAT_VECTOR, VTs, Legal);
setOperationAction(ISD::BITCAST, VTs, Custom);
setOperationAction(ISD::EXTRACT_VECTOR_ELT, VTs, Custom);
}
@@ -4437,37 +4437,6 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG,
MVT XLenVT = Subtarget.getXLenVT();
SDLoc DL(Op);
- // Handle P extension packed vector BUILD_VECTOR with PLI for splat constants
- if (Subtarget.enablePExtCodeGen()) {
- bool IsPExtVector =
- (VT == MVT::v2i16 || VT == MVT::v4i8) ||
- (Subtarget.is64Bit() &&
- (VT == MVT::v4i16 || VT == MVT::v8i8 || VT == MVT::v2i32));
- if (IsPExtVector) {
- if (SDValue SplatValue = cast<BuildVectorSDNode>(Op)->getSplatValue()) {
- if (auto *C = dyn_cast<ConstantSDNode>(SplatValue)) {
- int64_t SplatImm = C->getSExtValue();
- bool IsValidImm = false;
-
- // Check immediate range based on vector type
- if (VT == MVT::v8i8 || VT == MVT::v4i8) {
- // PLI_B uses 8-bit unsigned or unsigned immediate
- IsValidImm = isUInt<8>(SplatImm) || isInt<8>(SplatImm);
- if (isUInt<8>(SplatImm))
- SplatImm = (int8_t)SplatImm;
- } else {
- // PLI_H and PLI_W use 10-bit signed immediate
- IsValidImm = isInt<10>(SplatImm);
- }
-
- if (IsValidImm) {
- SDValue Imm = DAG.getSignedTargetConstant(SplatImm, DL, XLenVT);
- return DAG.getNode(RISCVISD::PLI, DL, VT, Imm);
- }
- }
- }
- }
- }
// Proper support for f16 requires Zvfh. bf16 always requires special
// handling. We need to cast the scalar to integer and create an integer
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
index 126a39996c741..2f289f89e8859 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
@@ -18,7 +18,7 @@
// Operand and SDNode transformation definitions.
//===----------------------------------------------------------------------===//
-def simm10 : RISCVSImmOp<10>, TImmLeaf<XLenVT, "return isInt<10>(Imm);">;
+def simm10 : RISCVSImmOp<10>, ImmLeaf<XLenVT, "return isInt<10>(Imm);">;
def SImm8UnsignedAsmOperand : SImmAsmOperand<8, "Unsigned"> {
let RenderMethod = "addSImm8UnsignedOperands";
@@ -26,7 +26,7 @@ def SImm8UnsignedAsmOperand : SImmAsmOperand<8, "Unsigned"> {
// A 8-bit signed immediate allowing range [-128, 255]
// but represented as [-128, 127].
-def simm8_unsigned : RISCVOp, TImmLeaf<XLenVT, "return isInt<8>(Imm);"> {
+def simm8_unsigned : RISCVOp, ImmLeaf<XLenVT, "return isInt<8>(Imm);"> {
let ParserMatchClass = SImm8UnsignedAsmOperand;
let EncoderMethod = "getImmOpValue";
let DecoderMethod = "decodeSImmOperand<8>";
@@ -1463,10 +1463,6 @@ let Predicates = [HasStdExtP, IsRV32] in {
def riscv_absw : RVSDNode<"ABSW", SDTIntUnaryOp>;
-def SDT_RISCVPLI : SDTypeProfile<1, 1, [SDTCisVec<0>,
- SDTCisInt<0>,
- SDTCisInt<1>]>;
-def riscv_pli : RVSDNode<"PLI", SDT_RISCVPLI>;
def SDT_RISCVPASUB : SDTypeProfile<1, 2, [SDTCisVec<0>,
SDTCisInt<0>,
SDTCisSameAs<0, 1>,
@@ -1519,9 +1515,9 @@ let Predicates = [HasStdExtP] in {
// 8-bit PLI SD node pattern
- def: Pat<(XLenVecI8VT (riscv_pli simm8_unsigned:$imm8)), (PLI_B simm8_unsigned:$imm8)>;
+ def: Pat<(XLenVecI8VT (splat_vector simm8_unsigned:$imm8)), (PLI_B simm8_unsigned:$imm8)>;
// 16-bit PLI SD node pattern
- def: Pat<(XLenVecI16VT (riscv_pli simm10:$imm10)), (PLI_H simm10:$imm10)>;
+ def: Pat<(XLenVecI16VT (splat_vector simm10:$imm10)), (PLI_H simm10:$imm10)>;
} // Predicates = [HasStdExtP]
@@ -1537,7 +1533,7 @@ let Predicates = [HasStdExtP, IsRV64] in {
def : PatGpr<riscv_absw, ABSW>;
// 32-bit PLI SD node pattern
- def: Pat<(v2i32 (riscv_pli simm10:$imm10)), (PLI_W simm10:$imm10)>;
+ def: Pat<(v2i32 (splat_vector simm10:$imm10)), (PLI_W simm10:$imm10)>;
// Basic 32-bit arithmetic patterns
def: Pat<(v2i32 (add GPR:$rs1, GPR:$rs2)), (PADD_W GPR:$rs1, GPR:$rs2)>;
|
| setOperationAction({ISD::AVGFLOORS, ISD::AVGFLOORU}, VTs, Legal); | ||
| setOperationAction({ISD::ABDS, ISD::ABDU}, VTs, Legal); | ||
| setOperationAction(ISD::BUILD_VECTOR, VTs, Custom); | ||
| setOperationAction(ISD::SPLAT_VECTOR, VTs, Legal); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there an instruction that can splat a scalar that isn't a constant? If you make an operation Legal, we must have isel patterns for all possible inputs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
that makes sense!
Does that mean that when we set an operation to Custom, we also have to handle legalization for every case (here, SPLAT_VECTOR with a non-constant operand)? If so, could that also make us miss opportunities for better instruction selection during pattern matching?
For example, if we legalize v2i16 SPLAT_VECTOR(i16 a) to something like v2i16 cast(or(a, shl(a, 16))) and we also have a pattern that matches v2i16 shl SPLAT_VECTOR(i16 a) -> PSLL_H a, that pattern will fail to match, unless we skip legalization for the non-constant case and treat it as legal.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can't we use PADD.BS, PADD.HS, PADD.WS with X0 for splat_vector?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does that mean that, in theory, we should handle both the constant and non-constant cases when we set the operation to Custom?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It means we can make it legal and add patterns to use PADD.BS, PADD.HS, PADD.WS with X0 when the operand isn't a constant.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh I see
topperc
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
The default DAG combiner combines a BUILD_VECTOR whose elements are all the same
into a SPLAT_VECTOR, so we can simply map a constant splat to PLI where possible.