File tree Expand file tree Collapse file tree 2 files changed +51
-33
lines changed Expand file tree Collapse file tree 2 files changed +51
-33
lines changed Original file line number Diff line number Diff line change @@ -16,6 +16,8 @@ limitations under the License. */
16
16
#include " paddle/fluid/framework/operator.h"
17
17
#include " paddle/fluid/framework/program_desc.h"
18
18
19
+ #include < queue>
20
+
19
21
namespace paddle {
20
22
namespace framework {
21
23
@@ -45,23 +47,33 @@ bool BlockDesc::HasVar(const std::string &name) const {
45
47
VarDesc *BlockDesc::FindVarRecursive (const std::string &name) const {
46
48
if (name == kEmptyVarName ) return nullptr ;
47
49
48
- auto it = vars_.find (name);
49
- if (it != vars_.end ()) {
50
- return it->second .get ();
51
- }
50
+ std::queue<const BlockDesc *> frontier;
51
+ std::unordered_set<const BlockDesc *> visited;
52
52
53
- BlockDesc *tmp = ParentBlock ( );
53
+ frontier. push ( this );
54
54
55
- if (tmp != nullptr ) {
56
- auto ptr = tmp->FindVarRecursive (name);
57
- if (ptr != nullptr ) {
58
- return ptr;
55
+ while (!frontier.empty ()) { // BFS
56
+ auto cur = frontier.front ();
57
+ frontier.pop ();
58
+ if (visited.count (cur) != 0 ) {
59
+ continue ;
60
+ }
61
+ auto var = cur->FindVar (name);
62
+ if (var != nullptr ) {
63
+ return var;
64
+ }
65
+
66
+ auto fwd = cur->ForwardBlock ();
67
+ auto parent = cur->ParentBlock ();
68
+
69
+ if (fwd != nullptr ) {
70
+ frontier.push (fwd);
71
+ }
72
+ if (parent != nullptr ) {
73
+ frontier.push (parent);
59
74
}
60
- }
61
75
62
- tmp = ForwardBlock ();
63
- if (tmp != nullptr ) {
64
- return tmp->FindVarRecursive (name);
76
+ visited.insert (cur);
65
77
}
66
78
67
79
return nullptr ;
Original file line number Diff line number Diff line change @@ -698,26 +698,32 @@ def var(self, name):
698
698
return v
699
699
700
700
def var_recursive (self , name ):
701
- if self .has_var (name ):
702
- return self .var (name )
703
- else :
704
- if self .idx == 0 :
705
- raise ValueError (
706
- "var {0} is not in block({1}) nor its parents." .format (
707
- name , self .idx ))
708
- else :
709
- # DFS
710
- try :
711
- parent_block = self .program .block (self .parent_idx )
712
- return parent_block .var_recursive (name )
713
- except ValueError :
714
- fwd_block = self .program .block (
715
- self .forward_block_idx
716
- ) if self .forward_block_idx != - 1 else None
717
- if fwd_block is not None :
718
- return fwd_block .var_recursive (name )
719
- else :
720
- raise
701
+ frontier = list ()
702
+ visited = set ()
703
+
704
+ frontier .append (self )
705
+
706
+ prog = self .program
707
+
708
+ while len (frontier ) != 0 : # BFS
709
+ cur = frontier [0 ]
710
+ frontier = frontier [1 :]
711
+
712
+ if id (cur ) in visited :
713
+ continue
714
+
715
+ if cur .has_var (name ):
716
+ return cur .var (name )
717
+
718
+ if cur .parent_idx != - 1 :
719
+ frontier .append (prog .block (cur .parent_idx ))
720
+
721
+ if cur .forward_block_idx != - 1 :
722
+ frontier .append (prog .block (cur .forward_block_idx ))
723
+
724
+ visited .add (id (cur ))
725
+
726
+ raise ValueError ("Var {0} is not found recursively" .format (name ))
721
727
722
728
def all_parameters (self ):
723
729
return list (self .iter_parameters ())
You can’t perform that action at this time.
0 commit comments