diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h index d21cc962da46c..51b5a8782409a 100644 --- a/llvm/include/llvm/CodeGen/SDPatternMatch.h +++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h @@ -514,6 +514,12 @@ m_InsertElt(const T0_P &Vec, const T1_P &Val, const T2_P &Idx) { Idx); } +template +inline TernaryOpc_match +m_InsertSubvector(const LHS &Base, const RHS &Sub, const IDX &Idx) { + return TernaryOpc_match(ISD::INSERT_SUBVECTOR, Base, Sub, Idx); +} + // === Binary operations === template @@ -802,6 +808,12 @@ inline BinaryOpc_match m_ExtractElt(const LHS &Vec, const RHS &Idx) { return BinaryOpc_match(ISD::EXTRACT_VECTOR_ELT, Vec, Idx); } +template +inline BinaryOpc_match m_ExtractSubvector(const LHS &Vec, + const RHS &Idx) { + return BinaryOpc_match(ISD::EXTRACT_SUBVECTOR, Vec, Idx); +} + // === Unary operations === template struct UnaryOpc_match { unsigned Opcode; diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp index 259bdad0ab272..31d7414d3e310 100644 --- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp +++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp @@ -138,9 +138,15 @@ TEST_F(SelectionDAGPatternMatchTest, matchTernaryOp) { SDValue Select = DAG->getSelect(DL, MVT::i1, Cond, T, F); auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4); + auto SmallVInt32VT = EVT::getVectorVT(Context, Int32VT, 2); + auto Idx0 = DAG->getVectorIdxConstant(0, DL); + auto Idx3 = DAG->getVectorIdxConstant(3, DL); SDValue V1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 6, VInt32VT); SDValue V2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 7, VInt32VT); + SDValue V3 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 8, SmallVInt32VT); SDValue VSelect = DAG->getNode(ISD::VSELECT, DL, VInt32VT, Cond, V1, V2); + SDValue InsertSubvector = + DAG->getNode(ISD::INSERT_SUBVECTOR, DL, VInt32VT, V2, V3, Idx0); SDValue ExtractELT = DAG->getNode(ISD::EXTRACT_VECTOR_ELT, DL, Int32VT, V1, Op3); @@ -182,15 +188,33 @@ TEST_F(SelectionDAGPatternMatchTest, matchTernaryOp) { EXPECT_TRUE(sd_match(ExtractELT, m_ExtractElt(m_Value(), m_Value()))); EXPECT_TRUE(sd_match(ExtractELT, m_ExtractElt(m_Value(), m_ConstInt()))); EXPECT_TRUE(sd_match(ExtractELT, m_ExtractElt(m_Value(), m_SpecificInt(1)))); + + EXPECT_TRUE(sd_match(InsertSubvector, + m_InsertSubvector(m_Value(), m_Value(), m_Value()))); + EXPECT_TRUE(sd_match( + InsertSubvector, + m_InsertSubvector(m_Specific(V2), m_Specific(V3), m_Specific(Idx0)))); + EXPECT_TRUE(sd_match( + InsertSubvector, + m_InsertSubvector(m_Specific(V2), m_Specific(V3), m_SpecificInt(0)))); + EXPECT_FALSE(sd_match( + InsertSubvector, + m_InsertSubvector(m_Specific(V2), m_Specific(V3), m_Specific(Idx3)))); + EXPECT_FALSE(sd_match( + InsertSubvector, + m_InsertSubvector(m_Specific(V2), m_Specific(V3), m_SpecificInt(3)))); } TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) { SDLoc DL; auto Int32VT = EVT::getIntegerVT(Context, 32); auto Float32VT = EVT::getFloatingPointVT(32); + auto BigVInt32VT = EVT::getVectorVT(Context, Int32VT, 8); auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4); SDValue V1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 6, VInt32VT); + auto Idx0 = DAG->getVectorIdxConstant(0, DL); + auto Idx1 = DAG->getVectorIdxConstant(1, DL); SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT); SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT); @@ -233,6 +257,10 @@ TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) { SDValue SFAdd = DAG->getNode(ISD::STRICT_FADD, DL, {Float32VT, MVT::Other}, {DAG->getEntryNode(), Op2, Op2}); + SDValue Vec = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 9, BigVInt32VT); + SDValue SubVec = + DAG->getNode(ISD::EXTRACT_SUBVECTOR, DL, VInt32VT, Vec, Idx0); + SDValue InsertELT = DAG->getNode(ISD::INSERT_VECTOR_ELT, DL, VInt32VT, V1, Op0, Op4); @@ -293,6 +321,16 @@ TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) { EXPECT_FALSE(sd_match(SFAdd, m_ChainedBinOp(ISD::STRICT_FADD, m_OtherVT(), m_SpecificVT(Float32VT)))); + EXPECT_TRUE(sd_match(SubVec, m_ExtractSubvector(m_Value(), m_Value()))); + EXPECT_TRUE( + sd_match(SubVec, m_ExtractSubvector(m_Specific(Vec), m_Specific(Idx0)))); + EXPECT_TRUE( + sd_match(SubVec, m_ExtractSubvector(m_Specific(Vec), m_SpecificInt(0)))); + EXPECT_FALSE( + sd_match(SubVec, m_ExtractSubvector(m_Specific(Vec), m_Specific(Idx1)))); + EXPECT_FALSE( + sd_match(SubVec, m_ExtractSubvector(m_Specific(Vec), m_SpecificInt(1)))); + EXPECT_TRUE( sd_match(InsertELT, m_InsertElt(m_Value(), m_Value(), m_Value()))); EXPECT_TRUE(