diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h index daafd3fc9d825..dda3b3827c7aa 100644 --- a/llvm/include/llvm/CodeGen/SDPatternMatch.h +++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h @@ -1315,19 +1315,12 @@ template struct ReassociatableOpc_match { if (Leaves.size() != NumPatterns) return false; - // Matches[I][J] == true iff sd_context_match(Leaves[I], Ctx, - // std::get(Patterns)) == true - std::array Matches; - for (size_t I = 0; I != NumPatterns; I++) { - std::apply( - [&](auto &...P) { - (Matches[I].push_back(sd_context_match(Leaves[I], Ctx, P)), ...); - }, - Patterns); - } - SmallBitVector Used(NumPatterns); - return reassociatableMatchHelper(Matches, Used); + return std::apply( + [&](auto &...P) -> bool { + return reassociatableMatchHelper(Ctx, Leaves, Used, P...); + }, + Patterns); } void collectLeaves(SDValue V, SmallVector &Leaves) { @@ -1339,21 +1332,29 @@ template struct ReassociatableOpc_match { } } + // Searchs for a matching leaf for every sub-pattern. + template [[nodiscard]] inline bool - reassociatableMatchHelper(ArrayRef Matches, - SmallBitVector &Used, size_t Curr = 0) { - if (Curr == Matches.size()) - return true; - for (size_t Match = 0, N = Matches[Curr].size(); Match < N; Match++) { - if (!Matches[Curr][Match] || Used[Match]) + reassociatableMatchHelper(const MatchContext &Ctx, ArrayRef Leaves, + SmallBitVector &Used, PatternHd &HeadPattern, + PatternTl &...TailPatterns) { + for (size_t Match = 0, N = Used.size(); Match < N; Match++) { + if (Used[Match] || !(sd_context_match(Leaves[Match], Ctx, HeadPattern))) continue; Used[Match] = true; - if (reassociatableMatchHelper(Matches, Used, Curr + 1)) + if (reassociatableMatchHelper(Ctx, Leaves, Used, TailPatterns...)) return true; Used[Match] = false; } return false; } + + template + [[nodiscard]] inline bool reassociatableMatchHelper(const MatchContext &Ctx, + ArrayRef Leaves, + SmallBitVector &Used) { + return true; + } }; template diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp index c32ceee73472d..4fcd3fcb8c5c7 100644 --- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp +++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp @@ -832,6 +832,38 @@ TEST_F(SelectionDAGPatternMatchTest, matchReassociatableOp) { EXPECT_FALSE(sd_match(ADDS0123, m_ReassociatableAdd(m_Value(), m_Value(), m_Value(), m_Value()))); + // (Op0 + Op1) + Op0 binds correctly, allowing commutation on leaf nodes + SDValue ADD010 = DAG->getNode(ISD::ADD, DL, Int32VT, ADD01, Op0); + SDValue A, B; + EXPECT_TRUE(sd_match( + ADD010, m_ReassociatableAdd(m_Value(A), m_Value(B), m_Deferred(A)))); + EXPECT_EQ(Op0, A); + EXPECT_EQ(Op1, B); + + A.setNode(nullptr); + B.setNode(nullptr); + EXPECT_TRUE(sd_match( + ADD010, m_ReassociatableAdd(m_Value(A), m_Value(B), m_Deferred(B)))); + EXPECT_EQ(Op0, B); + EXPECT_EQ(Op1, A); + + A.setNode(nullptr); + B.setNode(nullptr); + EXPECT_TRUE(sd_match( + ADD010, m_ReassociatableAdd(m_Value(A), m_Deferred(A), m_Value(B)))); + EXPECT_EQ(Op0, A); + EXPECT_EQ(Op1, B); + + A.setNode(nullptr); + B.setNode(nullptr); + EXPECT_FALSE(sd_match( + ADD010, m_ReassociatableAdd(m_Value(A), m_Deferred(A), m_Deferred(A)))); + + A.setNode(nullptr); + B.setNode(nullptr); + EXPECT_FALSE(sd_match( + ADD010, m_ReassociatableAdd(m_Value(A), m_Deferred(B), m_Value(B)))); + // (Op0 * Op1) * (Op2 * Op3) SDValue MUL01 = DAG->getNode(ISD::MUL, DL, Int32VT, Op0, Op1); SDValue MUL23 = DAG->getNode(ISD::MUL, DL, Int32VT, Op2, Op3);