Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 22 additions & 19 deletions llvm/include/llvm/CodeGen/SDPatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -1315,19 +1315,12 @@ template <typename... PatternTs> struct ReassociatableOpc_match {
if (Leaves.size() != NumPatterns)
return false;

// Matches[I][J] == true iff sd_context_match(Leaves[I], Ctx,
// std::get<J>(Patterns)) == true
std::array<SmallBitVector, NumPatterns> Matches;
for (size_t I = 0; I != NumPatterns; I++) {
std::apply(
[&](auto &...P) {
(Matches[I].push_back(sd_context_match(Leaves[I], Ctx, P)), ...);
},
Patterns);
}

SmallBitVector Used(NumPatterns);
return reassociatableMatchHelper(Matches, Used);
return std::apply(
[&](auto &...P) -> bool {
return reassociatableMatchHelper(Ctx, Leaves, Used, P...);
},
Patterns);
}

void collectLeaves(SDValue V, SmallVector<SDValue> &Leaves) {
Expand All @@ -1339,21 +1332,31 @@ template <typename... PatternTs> struct ReassociatableOpc_match {
}
}

// Searchs for a matching leaf for every sub-pattern.
template <typename MatchContext, typename PatternHd, typename... PatternTl>
[[nodiscard]] inline bool
reassociatableMatchHelper(ArrayRef<SmallBitVector> Matches,
SmallBitVector &Used, size_t Curr = 0) {
if (Curr == Matches.size())
return true;
for (size_t Match = 0, N = Matches[Curr].size(); Match < N; Match++) {
if (!Matches[Curr][Match] || Used[Match])
reassociatableMatchHelper(const MatchContext &Ctx,
SmallVector<SDValue> &Leaves, SmallBitVector &Used,
PatternHd &HeadPattern,
PatternTl &...TailPatterns) {
for (size_t Match = 0, N = Used.size(); Match < N; Match++) {
if (Used[Match] || !(sd_context_match(Leaves[Match], Ctx, HeadPattern)))
continue;
Used[Match] = true;
if (reassociatableMatchHelper(Matches, Used, Curr + 1))
if (reassociatableMatchHelper(Ctx, Leaves, Used, TailPatterns...))
return true;
Used[Match] = false;
}
return false;
}

template <typename MatchContext>
[[nodiscard]] inline bool
reassociatableMatchHelper(const MatchContext &Ctx,
SmallVector<SDValue> &Leaves,
SmallBitVector &Used) {
return true;
}
};

template <typename... PatternTs>
Expand Down
10 changes: 10 additions & 0 deletions llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -842,6 +842,16 @@ TEST_F(SelectionDAGPatternMatchTest, matchReassociatableOp) {
EXPECT_TRUE(sd_match(
MUL, m_ReassociatableMul(m_Value(), m_Value(), m_Value(), m_Value())));

// (Op0 + Op1) + Op0 binds correctly, allowing commutation
SDValue ADD010 = DAG->getNode(ISD::ADD, DL, Int32VT, ADD01, Op0);
SDValue A, B;
EXPECT_TRUE(sd_match(
ADD010, m_ReassociatableAdd(m_Value(A), m_Value(B), m_Deferred(A))));
EXPECT_TRUE(sd_match(
ADD010, m_ReassociatableAdd(m_Value(A), m_Value(B), m_Deferred(B))));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem I was often seeing wasn't that m_ReassociatableAdd etc. didn't return true - it was that the values of A + B weren't correct - please can you add additional checks that A and B are correctly initialized to the correct ADD01/Op0 pairs?

Copy link
Contributor Author

@bermondd bermondd Dec 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, alright. Just added tests for this in the newest commit.
Edit: second to newest now. Had to do a fixup to use clang-format.

EXPECT_FALSE(sd_match(
ADD010, m_ReassociatableAdd(m_Value(A), m_Deferred(A), m_Deferred(A))));

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you add some negative test cases? That is, showing that this pattern does not match when it should not.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just added a few more negative tests in response to another review, but they were all related to the binds done on the 2 positive test cases, e.g. to check m_Deferred(A) does not match with Op1 in the first case. I don't know if that meets what you asked for, though. Do you want me to add completely new negative test cases, alongside the current trivial one?

// Op0 * (Op1 * (Op2 * Op3))
SDValue MUL123 = DAG->getNode(ISD::MUL, DL, Int32VT, Op1, MUL23);
SDValue MUL0123 = DAG->getNode(ISD::MUL, DL, Int32VT, Op0, MUL123);
Expand Down