Skip to content

Commit b44917d

Browse files
committed
Implement IsPersistable() in c++.
1 parent 865dfbe commit b44917d

File tree

1 file changed

+7
-20
lines changed
  • paddle/fluid/inference

1 file changed

+7
-20
lines changed

paddle/fluid/inference/io.cc

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -32,23 +32,11 @@ void ReadBinaryFile(const std::string& filename, std::string& contents) {
3232
inputfs.close();
3333
}
3434

35-
bool IsParameter(const framework::VarDesc* var,
36-
const framework::ProgramDesc& main_program) {
37-
if (var->Persistable()) {
38-
// There are many unreachable variables in the program
39-
for (size_t i = 0; i < main_program.Size(); ++i) {
40-
const framework::BlockDesc& block = main_program.Block(i);
41-
for (auto* op : block.AllOps()) {
42-
if (op->Type() == framework::kFeedOpType) {
43-
continue;
44-
}
45-
for (auto input_argument_name : op->InputArgumentNames()) {
46-
if (input_argument_name == var->Name()) {
47-
return true;
48-
}
49-
}
50-
}
51-
}
35+
bool IsPersistable(const framework::VarDesc* var) {
36+
if (var->Persistable() &&
37+
var->GetType() != framework::proto::VarDesc::FEED_MINIBATCH &&
38+
var->GetType() != framework::proto::VarDesc::FETCH_LIST) {
39+
return true;
5240
}
5341
return false;
5442
}
@@ -65,8 +53,8 @@ void LoadPersistables(framework::Executor& executor,
6553
std::vector<std::string> paramlist;
6654

6755
for (auto* var : global_block.AllVars()) {
68-
if (IsParameter(var, main_program)) {
69-
VLOG(3) << "parameter's name: " << var->Name();
56+
if (IsPersistable(var)) {
57+
VLOG(3) << "persistable variable's name: " << var->Name();
7058

7159
framework::VarDesc* new_var = load_block->Var(var->Name());
7260
new_var->SetShape(var->GetShape());
@@ -101,7 +89,6 @@ void LoadPersistables(framework::Executor& executor,
10189

10290
executor.Run(*load_program, &scope, 0, true, true);
10391

104-
VLOG(3) << "Ran loading successfully";
10592
delete load_program;
10693
}
10794

0 commit comments

Comments
 (0)