Skip to content

Commit 65058cf

Browse files
committed
Change DFS to BFS
1 parent 14f8370 commit 65058cf

File tree

2 files changed

+51
-33
lines changed

2 files changed

+51
-33
lines changed

paddle/fluid/framework/block_desc.cc

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ limitations under the License. */
1616
#include "paddle/fluid/framework/operator.h"
1717
#include "paddle/fluid/framework/program_desc.h"
1818

19+
#include <queue>
20+
1921
namespace paddle {
2022
namespace framework {
2123

@@ -45,23 +47,33 @@ bool BlockDesc::HasVar(const std::string &name) const {
4547
VarDesc *BlockDesc::FindVarRecursive(const std::string &name) const {
4648
if (name == kEmptyVarName) return nullptr;
4749

48-
auto it = vars_.find(name);
49-
if (it != vars_.end()) {
50-
return it->second.get();
51-
}
50+
std::queue<const BlockDesc *> frontier;
51+
std::unordered_set<const BlockDesc *> visited;
5252

53-
BlockDesc *tmp = ParentBlock();
53+
frontier.push(this);
5454

55-
if (tmp != nullptr) {
56-
auto ptr = tmp->FindVarRecursive(name);
57-
if (ptr != nullptr) {
58-
return ptr;
55+
while (!frontier.empty()) { // BFS
56+
auto cur = frontier.front();
57+
frontier.pop();
58+
if (visited.count(cur) != 0) {
59+
continue;
60+
}
61+
auto var = cur->FindVar(name);
62+
if (var != nullptr) {
63+
return var;
64+
}
65+
66+
auto fwd = cur->ForwardBlock();
67+
auto parent = cur->ParentBlock();
68+
69+
if (fwd != nullptr) {
70+
frontier.push(fwd);
71+
}
72+
if (parent != nullptr) {
73+
frontier.push(parent);
5974
}
60-
}
6175

62-
tmp = ForwardBlock();
63-
if (tmp != nullptr) {
64-
return tmp->FindVarRecursive(name);
76+
visited.insert(cur);
6577
}
6678

6779
return nullptr;

python/paddle/v2/fluid/framework.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -698,26 +698,32 @@ def var(self, name):
698698
return v
699699

700700
def var_recursive(self, name):
701-
if self.has_var(name):
702-
return self.var(name)
703-
else:
704-
if self.idx == 0:
705-
raise ValueError(
706-
"var {0} is not in block({1}) nor its parents.".format(
707-
name, self.idx))
708-
else:
709-
# DFS
710-
try:
711-
parent_block = self.program.block(self.parent_idx)
712-
return parent_block.var_recursive(name)
713-
except ValueError:
714-
fwd_block = self.program.block(
715-
self.forward_block_idx
716-
) if self.forward_block_idx != -1 else None
717-
if fwd_block is not None:
718-
return fwd_block.var_recursive(name)
719-
else:
720-
raise
701+
frontier = list()
702+
visited = set()
703+
704+
frontier.append(self)
705+
706+
prog = self.program
707+
708+
while len(frontier) != 0: # BFS
709+
cur = frontier[0]
710+
frontier = frontier[1:]
711+
712+
if id(cur) in visited:
713+
continue
714+
715+
if cur.has_var(name):
716+
return cur.var(name)
717+
718+
if cur.parent_idx != -1:
719+
frontier.append(prog.block(cur.parent_idx))
720+
721+
if cur.forward_block_idx != -1:
722+
frontier.append(prog.block(cur.forward_block_idx))
723+
724+
visited.add(id(cur))
725+
726+
raise ValueError("Var {0} is not found recursively".format(name))
721727

722728
def all_parameters(self):
723729
return list(self.iter_parameters())

0 commit comments

Comments
 (0)