Skip to content

Commit 3a81e03

Browse files
authored
[DAG] SDPatternMatch - Fix m_Reassociatable mismatching (#170061)
Fixes #169645 The issue was caused by a for-loop improperly overwriting SDValue binds when m_Reassociatable is given two or more patterns that (1) call m_Value with an SDValue parameter and (2) differ by that parameter. This fix comes with added unit tests relevant to SDValue bindings inside m_Reassociatable patterns. Essentially, the original implementation first tried to match every combination of leaf node and pattern possible and stored that in a matrix-like structure, and then did a recursive search on that matrix to check if it's possible to pair every leaf with a pattern. The problem is that m_Value has a side effect where it changes an SDValue, and the creation of this matrix was corrupting these values. Below is an example of this, following the order of execution in the original implementation and using the case brought by issue #169645, where this behavior was found. The example tries to match ((a >> 1) + (b >> 1) + (a & b & 1)), using uppercase letters for the SDValue variables themselves and lowercase for their values. The result is that the pattern matches the same value for A and B, which was the behavior observed in the issue: | Line | Leaf | Pattern | Match? | Effect | |--------|--------|--------|--------|--------| | 1 | a >> 1 | m_Srl(m_Value(A), m_One()) | Yes | A <- a | | 2 | a >> 1 | m_Srl(m_Value(B), m_One()) | Yes | B <- a | | 3 | a >> 1 | m_ReassociableAnd(m_Deferred(A), m_Deferred(B), m_One()) | No | -- | | 4 | b >> 1 | m_Srl(m_Value(A), m_One()) | Yes | A <- b | | 5 | b >> 1 | m_Srl(m_Value(B), m_One()) | Yes | B <- b | | 6 | b >> 1 | m_ReassociableAnd(m_Deferred(A), m_Deferred(B), m_One()) | No | -- | | 7 | a & b & 1 | m_Srl(m_Value(A), m_One()) | No | -- | | 8 | a & b & 1 | m_Srl(m_Value(B), m_One()) | No | -- | | 9 | a & b & 1 | m_ReassociableAnd(m_Deferred(A), m_Deferred(B), m_One()) | a == b | -- | To fix this, the function now matches the patterns during the recursive search itself, instead of preparing the matrix beforehand. Although this does fix the issue, it does mean that we're performing a best case of n and worst case of n! matching attempts, instead of the fixed nˆ2 in the original, where n is the number of patterns provided. Going back to the table above, using this fix the lines 2, 3, 4, 6, 7, and 8 do not happen, and so the only effects happening are A <- a and B <- b, which then will result in line 9 matching correctly.
1 parent 88273a0 commit 3a81e03

File tree

2 files changed

+52
-19
lines changed

2 files changed

+52
-19
lines changed

llvm/include/llvm/CodeGen/SDPatternMatch.h

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1315,19 +1315,12 @@ template <typename... PatternTs> struct ReassociatableOpc_match {
13151315
if (Leaves.size() != NumPatterns)
13161316
return false;
13171317

1318-
// Matches[I][J] == true iff sd_context_match(Leaves[I], Ctx,
1319-
// std::get<J>(Patterns)) == true
1320-
std::array<SmallBitVector, NumPatterns> Matches;
1321-
for (size_t I = 0; I != NumPatterns; I++) {
1322-
std::apply(
1323-
[&](auto &...P) {
1324-
(Matches[I].push_back(sd_context_match(Leaves[I], Ctx, P)), ...);
1325-
},
1326-
Patterns);
1327-
}
1328-
13291318
SmallBitVector Used(NumPatterns);
1330-
return reassociatableMatchHelper(Matches, Used);
1319+
return std::apply(
1320+
[&](auto &...P) -> bool {
1321+
return reassociatableMatchHelper(Ctx, Leaves, Used, P...);
1322+
},
1323+
Patterns);
13311324
}
13321325

13331326
void collectLeaves(SDValue V, SmallVector<SDValue> &Leaves) {
@@ -1339,21 +1332,29 @@ template <typename... PatternTs> struct ReassociatableOpc_match {
13391332
}
13401333
}
13411334

1335+
// Searchs for a matching leaf for every sub-pattern.
1336+
template <typename MatchContext, typename PatternHd, typename... PatternTl>
13421337
[[nodiscard]] inline bool
1343-
reassociatableMatchHelper(ArrayRef<SmallBitVector> Matches,
1344-
SmallBitVector &Used, size_t Curr = 0) {
1345-
if (Curr == Matches.size())
1346-
return true;
1347-
for (size_t Match = 0, N = Matches[Curr].size(); Match < N; Match++) {
1348-
if (!Matches[Curr][Match] || Used[Match])
1338+
reassociatableMatchHelper(const MatchContext &Ctx, ArrayRef<SDValue> Leaves,
1339+
SmallBitVector &Used, PatternHd &HeadPattern,
1340+
PatternTl &...TailPatterns) {
1341+
for (size_t Match = 0, N = Used.size(); Match < N; Match++) {
1342+
if (Used[Match] || !(sd_context_match(Leaves[Match], Ctx, HeadPattern)))
13491343
continue;
13501344
Used[Match] = true;
1351-
if (reassociatableMatchHelper(Matches, Used, Curr + 1))
1345+
if (reassociatableMatchHelper(Ctx, Leaves, Used, TailPatterns...))
13521346
return true;
13531347
Used[Match] = false;
13541348
}
13551349
return false;
13561350
}
1351+
1352+
template <typename MatchContext>
1353+
[[nodiscard]] inline bool reassociatableMatchHelper(const MatchContext &Ctx,
1354+
ArrayRef<SDValue> Leaves,
1355+
SmallBitVector &Used) {
1356+
return true;
1357+
}
13571358
};
13581359

13591360
template <typename... PatternTs>

llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -832,6 +832,38 @@ TEST_F(SelectionDAGPatternMatchTest, matchReassociatableOp) {
832832
EXPECT_FALSE(sd_match(ADDS0123, m_ReassociatableAdd(m_Value(), m_Value(),
833833
m_Value(), m_Value())));
834834

835+
// (Op0 + Op1) + Op0 binds correctly, allowing commutation on leaf nodes
836+
SDValue ADD010 = DAG->getNode(ISD::ADD, DL, Int32VT, ADD01, Op0);
837+
SDValue A, B;
838+
EXPECT_TRUE(sd_match(
839+
ADD010, m_ReassociatableAdd(m_Value(A), m_Value(B), m_Deferred(A))));
840+
EXPECT_EQ(Op0, A);
841+
EXPECT_EQ(Op1, B);
842+
843+
A.setNode(nullptr);
844+
B.setNode(nullptr);
845+
EXPECT_TRUE(sd_match(
846+
ADD010, m_ReassociatableAdd(m_Value(A), m_Value(B), m_Deferred(B))));
847+
EXPECT_EQ(Op0, B);
848+
EXPECT_EQ(Op1, A);
849+
850+
A.setNode(nullptr);
851+
B.setNode(nullptr);
852+
EXPECT_TRUE(sd_match(
853+
ADD010, m_ReassociatableAdd(m_Value(A), m_Deferred(A), m_Value(B))));
854+
EXPECT_EQ(Op0, A);
855+
EXPECT_EQ(Op1, B);
856+
857+
A.setNode(nullptr);
858+
B.setNode(nullptr);
859+
EXPECT_FALSE(sd_match(
860+
ADD010, m_ReassociatableAdd(m_Value(A), m_Deferred(A), m_Deferred(A))));
861+
862+
A.setNode(nullptr);
863+
B.setNode(nullptr);
864+
EXPECT_FALSE(sd_match(
865+
ADD010, m_ReassociatableAdd(m_Value(A), m_Deferred(B), m_Value(B))));
866+
835867
// (Op0 * Op1) * (Op2 * Op3)
836868
SDValue MUL01 = DAG->getNode(ISD::MUL, DL, Int32VT, Op0, Op1);
837869
SDValue MUL23 = DAG->getNode(ISD::MUL, DL, Int32VT, Op2, Op3);

0 commit comments

Comments
 (0)