@@ -167,10 +167,12 @@ struct HitGroup {
167
167
168
168
bool Match (Node *node, PDNode *pat) {
169
169
if (nodes_.count (node)) {
170
- if (!roles.count (pat)) return false ;
171
- return roles[pat] == node;
170
+ if (roles.count (pat) && roles[pat] == node) return true ;
171
+ return false ;
172
+ } else {
173
+ if (roles.count (pat) && roles[pat] != node) return false ;
174
+ return true ;
172
175
}
173
- return !roles.count (pat) || roles.at (pat) == node;
174
176
}
175
177
176
178
void Register (Node *node, PDNode *pat) {
@@ -198,7 +200,6 @@ GraphPatternDetector::DetectPatterns() {
198
200
std::vector<GraphPatternDetector::subgraph_t > result;
199
201
std::vector<HitGroup> init_groups;
200
202
std::array<std::vector<HitGroup>, 2 > bi_records;
201
- // PADDLE_ENFORCE(!pattern_.edges().empty(), "At least one edge is needed");
202
203
auto *first_pnode = pattern_.edges ().empty () ? pattern ().nodes ().front ().get ()
203
204
: pattern_.edges ().front ().first ;
204
205
if (!pdnodes2nodes_.count (first_pnode)) return result;
@@ -228,11 +229,12 @@ GraphPatternDetector::DetectPatterns() {
228
229
VLOG (80 ) << " check " << source->id () << " -- " << target->id ();
229
230
// TODO(Superjomn) add some prune strategies.
230
231
for (const auto &group : pre_groups) {
231
- HitGroup new_group = group;
232
- if (IsNodesLink (source, target) &&
233
- new_group.Match (source, edge.first )) {
234
- new_group.Register (source, edge.first );
235
- if (new_group.Match (target, edge.second )) {
232
+ if (IsNodesLink (source, target)) {
233
+ HitGroup new_group = group;
234
+ bool flag = new_group.Match (source, edge.first ) &&
235
+ new_group.Match (target, edge.second );
236
+ if (flag) {
237
+ new_group.Register (source, edge.first );
236
238
new_group.Register (target, edge.second );
237
239
cur_groups.push_back (new_group);
238
240
// TODO(Superjomn) need to unique
0 commit comments