@@ -32,23 +32,11 @@ void ReadBinaryFile(const std::string& filename, std::string& contents) {
32
32
inputfs.close ();
33
33
}
34
34
35
- bool IsParameter (const framework::VarDesc* var,
36
- const framework::ProgramDesc& main_program) {
37
- if (var->Persistable ()) {
38
- // There are many unreachable variables in the program
39
- for (size_t i = 0 ; i < main_program.Size (); ++i) {
40
- const framework::BlockDesc& block = main_program.Block (i);
41
- for (auto * op : block.AllOps ()) {
42
- if (op->Type () == framework::kFeedOpType ) {
43
- continue ;
44
- }
45
- for (auto input_argument_name : op->InputArgumentNames ()) {
46
- if (input_argument_name == var->Name ()) {
47
- return true ;
48
- }
49
- }
50
- }
51
- }
35
+ bool IsPersistable (const framework::VarDesc* var) {
36
+ if (var->Persistable () &&
37
+ var->GetType () != framework::proto::VarDesc::FEED_MINIBATCH &&
38
+ var->GetType () != framework::proto::VarDesc::FETCH_LIST) {
39
+ return true ;
52
40
}
53
41
return false ;
54
42
}
@@ -65,8 +53,8 @@ void LoadPersistables(framework::Executor& executor,
65
53
std::vector<std::string> paramlist;
66
54
67
55
for (auto * var : global_block.AllVars ()) {
68
- if (IsParameter (var, main_program )) {
69
- VLOG (3 ) << " parameter 's name: " << var->Name ();
56
+ if (IsPersistable (var)) {
57
+ VLOG (3 ) << " persistable variable 's name: " << var->Name ();
70
58
71
59
framework::VarDesc* new_var = load_block->Var (var->Name ());
72
60
new_var->SetShape (var->GetShape ());
@@ -101,7 +89,6 @@ void LoadPersistables(framework::Executor& executor,
101
89
102
90
executor.Run (*load_program, &scope, 0 , true , true );
103
91
104
- VLOG (3 ) << " Ran loading successfully" ;
105
92
delete load_program;
106
93
}
107
94
0 commit comments