Skip to content

Commit 8d928e4

Browse files
committed
add reassociatable matchers
1 parent 8b02d80 commit 8d928e4

File tree

2 files changed

+176
-0
lines changed

2 files changed

+176
-0
lines changed

llvm/include/llvm/CodeGen/SDPatternMatch.h

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1072,6 +1072,97 @@ inline BinaryOpc_match<ValTy, AllOnes_match, true> m_Not(const ValTy &V) {
10721072
return m_Xor(V, m_AllOnes());
10731073
}
10741074

1075+
template <typename... PatternTs> struct ReassociatableOpc_match {
1076+
unsigned Opcode;
1077+
std::tuple<PatternTs...> Patterns;
1078+
1079+
ReassociatableOpc_match(unsigned Opcode, const PatternTs &...Patterns)
1080+
: Opcode(Opcode), Patterns(Patterns...) {}
1081+
1082+
template <typename MatchContext>
1083+
bool match(const MatchContext &Ctx, SDValue N) {
1084+
SmallVector<SDValue> Leaves;
1085+
collectLeaves(N, Leaves);
1086+
if (Leaves.size() != std::tuple_size_v<std::tuple<PatternTs...>>) {
1087+
return false;
1088+
}
1089+
1090+
// J in Matches[I] iff sd_context_match(Leaves[I], Ctx,
1091+
// std::get<J>(Patterns)) == true
1092+
SmallVector<SmallVector<size_t>> Matches(Leaves.size());
1093+
for (size_t I = 0; I < Leaves.size(); I += 1) {
1094+
SmallVector<bool> MatchResults;
1095+
std::apply(
1096+
[&](auto &...P) {
1097+
(MatchResults.emplace_back(sd_context_match(Leaves[I], Ctx, P)),
1098+
...);
1099+
},
1100+
Patterns);
1101+
for (size_t J = 0; J < MatchResults.size(); J += 1) {
1102+
if (MatchResults[J]) {
1103+
Matches[I].emplace_back(J);
1104+
}
1105+
}
1106+
}
1107+
1108+
SmallVector<bool> Used(std::tuple_size_v<std::tuple<PatternTs...>>, false);
1109+
return reassociatableMatchHelper(Matches, Used);
1110+
}
1111+
1112+
void collectLeaves(SDValue V, SmallVector<SDValue> &Leaves) {
1113+
if (V->getOpcode() == Opcode) {
1114+
for (size_t I = 0; I < V->getNumOperands(); I += 1) {
1115+
collectLeaves(V->getOperand(I), Leaves);
1116+
}
1117+
} else {
1118+
Leaves.emplace_back(V);
1119+
}
1120+
}
1121+
1122+
[[nodiscard]] inline bool
1123+
reassociatableMatchHelper(const SmallVector<SmallVector<size_t>> &Matches,
1124+
SmallVector<bool> &Used, size_t Curr = 0) {
1125+
if (Curr == Matches.size()) {
1126+
return true;
1127+
}
1128+
for (auto Match : Matches[Curr]) {
1129+
if (Used[Match]) {
1130+
continue;
1131+
}
1132+
Used[Match] = true;
1133+
if (reassociatableMatchHelper(Matches, Used, Curr + 1)) {
1134+
return true;
1135+
}
1136+
Used[Match] = false;
1137+
}
1138+
return false;
1139+
}
1140+
};
1141+
1142+
template <typename... PatternTs>
1143+
inline ReassociatableOpc_match<PatternTs...>
1144+
m_ReassociatableAdd(const PatternTs &...Patterns) {
1145+
return ReassociatableOpc_match<PatternTs...>(ISD::ADD, Patterns...);
1146+
}
1147+
1148+
template <typename... PatternTs>
1149+
inline ReassociatableOpc_match<PatternTs...>
1150+
m_ReassociatableOr(const PatternTs &...Patterns) {
1151+
return ReassociatableOpc_match<PatternTs...>(ISD::OR, Patterns...);
1152+
}
1153+
1154+
template <typename... PatternTs>
1155+
inline ReassociatableOpc_match<PatternTs...>
1156+
m_ReassociatableAnd(const PatternTs &...Patterns) {
1157+
return ReassociatableOpc_match<PatternTs...>(ISD::AND, Patterns...);
1158+
}
1159+
1160+
template <typename... PatternTs>
1161+
inline ReassociatableOpc_match<PatternTs...>
1162+
m_ReassociatableMul(const PatternTs &...Patterns) {
1163+
return ReassociatableOpc_match<PatternTs...>(ISD::MUL, Patterns...);
1164+
}
1165+
10751166
} // namespace SDPatternMatch
10761167
} // namespace llvm
10771168
#endif

llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,3 +576,88 @@ TEST_F(SelectionDAGPatternMatchTest, matchAdvancedProperties) {
576576
EXPECT_TRUE(sd_match(Add, DAG.get(),
577577
m_LegalOp(m_IntegerVT(m_Add(m_Value(), m_Value())))));
578578
}
579+
580+
TEST_F(SelectionDAGPatternMatchTest, matchReassociatableOp) {
581+
using namespace SDPatternMatch;
582+
583+
SDLoc DL;
584+
auto Int32VT = EVT::getIntegerVT(Context, 32);
585+
auto Float32VT = EVT::getFloatingPointVT(32);
586+
587+
SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
588+
SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT);
589+
SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Float32VT);
590+
SDValue Op3 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 8, Int32VT);
591+
592+
// (Op0 + Op1) + (Op2 + Op3)
593+
SDValue ADD01 = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, Op1);
594+
SDValue ADD23 = DAG->getNode(ISD::ADD, DL, Int32VT, Op2, Op3);
595+
SDValue ADD = DAG->getNode(ISD::ADD, DL, Int32VT, ADD01, ADD23);
596+
597+
EXPECT_TRUE(sd_match(ADD01, m_ReassociatableAdd(m_Value(), m_Value())));
598+
EXPECT_TRUE(sd_match(ADD23, m_ReassociatableAdd(m_Value(), m_Value())));
599+
EXPECT_TRUE(sd_match(
600+
ADD, m_ReassociatableAdd(m_Value(), m_Value(), m_Value(), m_Value())));
601+
602+
// Op0 + (Op1 + (Op2 + Op3))
603+
SDValue ADD123 = DAG->getNode(ISD::ADD, DL, Int32VT, Op2, ADD23);
604+
SDValue ADD0123 = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, ADD123);
605+
EXPECT_TRUE(
606+
sd_match(ADD123, m_ReassociatableAdd(m_Value(), m_Value(), m_Value())));
607+
EXPECT_TRUE(sd_match(ADD0123, m_ReassociatableAdd(m_Value(), m_Value(),
608+
m_Value(), m_Value())));
609+
610+
// (Op0 * Op1) * (Op2 * Op3)
611+
SDValue MUL01 = DAG->getNode(ISD::MUL, DL, Int32VT, Op0, Op1);
612+
SDValue MUL23 = DAG->getNode(ISD::MUL, DL, Int32VT, Op2, Op3);
613+
SDValue MUL = DAG->getNode(ISD::MUL, DL, Int32VT, MUL01, MUL23);
614+
615+
EXPECT_TRUE(sd_match(MUL01, m_ReassociatableMul(m_Value(), m_Value())));
616+
EXPECT_TRUE(sd_match(MUL23, m_ReassociatableMul(m_Value(), m_Value())));
617+
EXPECT_TRUE(sd_match(
618+
MUL, m_ReassociatableMul(m_Value(), m_Value(), m_Value(), m_Value())));
619+
620+
// Op0 * (Op1 * (Op2 * Op3))
621+
SDValue MUL123 = DAG->getNode(ISD::MUL, DL, Int32VT, Op2, MUL23);
622+
SDValue MUL0123 = DAG->getNode(ISD::MUL, DL, Int32VT, Op0, MUL123);
623+
EXPECT_TRUE(
624+
sd_match(MUL123, m_ReassociatableMul(m_Value(), m_Value(), m_Value())));
625+
EXPECT_TRUE(sd_match(MUL0123, m_ReassociatableMul(m_Value(), m_Value(),
626+
m_Value(), m_Value())));
627+
628+
// (Op0 && Op1) && (Op2 && Op3)
629+
SDValue AND01 = DAG->getNode(ISD::AND, DL, Int32VT, Op0, Op1);
630+
SDValue AND23 = DAG->getNode(ISD::AND, DL, Int32VT, Op2, Op3);
631+
SDValue AND = DAG->getNode(ISD::AND, DL, Int32VT, AND01, AND23);
632+
633+
EXPECT_TRUE(sd_match(AND01, m_ReassociatableAnd(m_Value(), m_Value())));
634+
EXPECT_TRUE(sd_match(AND23, m_ReassociatableAnd(m_Value(), m_Value())));
635+
EXPECT_TRUE(sd_match(
636+
AND, m_ReassociatableAnd(m_Value(), m_Value(), m_Value(), m_Value())));
637+
638+
// Op0 && (Op1 && (Op2 && Op3))
639+
SDValue AND123 = DAG->getNode(ISD::AND, DL, Int32VT, Op2, AND23);
640+
SDValue AND0123 = DAG->getNode(ISD::AND, DL, Int32VT, Op0, AND123);
641+
EXPECT_TRUE(
642+
sd_match(AND123, m_ReassociatableAnd(m_Value(), m_Value(), m_Value())));
643+
EXPECT_TRUE(sd_match(AND0123, m_ReassociatableAnd(m_Value(), m_Value(),
644+
m_Value(), m_Value())));
645+
646+
// (Op0 || Op1) || (Op2 || Op3)
647+
SDValue OR01 = DAG->getNode(ISD::OR, DL, Int32VT, Op0, Op1);
648+
SDValue OR23 = DAG->getNode(ISD::OR, DL, Int32VT, Op2, Op3);
649+
SDValue OR = DAG->getNode(ISD::OR, DL, Int32VT, OR01, OR23);
650+
651+
EXPECT_TRUE(sd_match(OR01, m_ReassociatableOr(m_Value(), m_Value())));
652+
EXPECT_TRUE(sd_match(OR23, m_ReassociatableOr(m_Value(), m_Value())));
653+
EXPECT_TRUE(sd_match(
654+
OR, m_ReassociatableOr(m_Value(), m_Value(), m_Value(), m_Value())));
655+
656+
// Op0 || (Op1 || (Op2 || Op3))
657+
SDValue OR123 = DAG->getNode(ISD::OR, DL, Int32VT, Op2, OR23);
658+
SDValue OR0123 = DAG->getNode(ISD::OR, DL, Int32VT, Op0, OR123);
659+
EXPECT_TRUE(
660+
sd_match(OR123, m_ReassociatableOr(m_Value(), m_Value(), m_Value())));
661+
EXPECT_TRUE(sd_match(
662+
OR0123, m_ReassociatableOr(m_Value(), m_Value(), m_Value(), m_Value())));
663+
}

0 commit comments

Comments
 (0)