Skip to content

Commit 6c32945

Browse files
authored
Merge pull request #14372 from luotao1/speedup_analysis
speedup DetectPatterns
2 parents 4a55fb5 + 668ae52 commit 6c32945

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

paddle/fluid/framework/ir/graph_pattern_detector.cc

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -167,10 +167,12 @@ struct HitGroup {
167167

168168
bool Match(Node *node, PDNode *pat) {
169169
if (nodes_.count(node)) {
170-
if (!roles.count(pat)) return false;
171-
return roles[pat] == node;
170+
if (roles.count(pat) && roles[pat] == node) return true;
171+
return false;
172+
} else {
173+
if (roles.count(pat) && roles[pat] != node) return false;
174+
return true;
172175
}
173-
return !roles.count(pat) || roles.at(pat) == node;
174176
}
175177

176178
void Register(Node *node, PDNode *pat) {
@@ -198,7 +200,6 @@ GraphPatternDetector::DetectPatterns() {
198200
std::vector<GraphPatternDetector::subgraph_t> result;
199201
std::vector<HitGroup> init_groups;
200202
std::array<std::vector<HitGroup>, 2> bi_records;
201-
// PADDLE_ENFORCE(!pattern_.edges().empty(), "At least one edge is needed");
202203
auto *first_pnode = pattern_.edges().empty() ? pattern().nodes().front().get()
203204
: pattern_.edges().front().first;
204205
if (!pdnodes2nodes_.count(first_pnode)) return result;
@@ -228,11 +229,12 @@ GraphPatternDetector::DetectPatterns() {
228229
VLOG(80) << "check " << source->id() << " -- " << target->id();
229230
// TODO(Superjomn) add some prune strategies.
230231
for (const auto &group : pre_groups) {
231-
HitGroup new_group = group;
232-
if (IsNodesLink(source, target) &&
233-
new_group.Match(source, edge.first)) {
234-
new_group.Register(source, edge.first);
235-
if (new_group.Match(target, edge.second)) {
232+
if (IsNodesLink(source, target)) {
233+
HitGroup new_group = group;
234+
bool flag = new_group.Match(source, edge.first) &&
235+
new_group.Match(target, edge.second);
236+
if (flag) {
237+
new_group.Register(source, edge.first);
236238
new_group.Register(target, edge.second);
237239
cur_groups.push_back(new_group);
238240
// TODO(Superjomn) need to unique

0 commit comments

Comments
 (0)