Skip to content

Commit 8e00540

Browse files
authored
Merge pull request #10058 from Xreki/core_fix_flush
Add flush of program desc to update the proto information.
2 parents cec4e6e + 7ffbcbc commit 8e00540

File tree

5 files changed

+13
-3
lines changed

5 files changed

+13
-3
lines changed

paddle/fluid/framework/block_desc.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ void BlockDesc::RemoveOp(size_t s, size_t e) {
146146
if (ops_.begin() + s == ops_.end() || ops_.begin() + e == ops_.end()) {
147147
return;
148148
}
149+
need_update_ = true;
149150
ops_.erase(ops_.begin() + s, ops_.begin() + e);
150151
}
151152

paddle/fluid/framework/program_desc.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,14 @@ BlockDesc *ProgramDesc::AppendBlock(const BlockDesc &parent) {
2727
return blocks_.back().get();
2828
}
2929

30-
proto::ProgramDesc *ProgramDesc::Proto() {
30+
void ProgramDesc::Flush() {
3131
for (auto &block : blocks_) {
3232
block->Flush();
3333
}
34+
}
35+
36+
proto::ProgramDesc *ProgramDesc::Proto() {
37+
Flush();
3438
return &desc_;
3539
}
3640

paddle/fluid/framework/program_desc.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ class ProgramDesc {
5151

5252
size_t Size() const { return blocks_.size(); }
5353

54+
void Flush();
55+
5456
proto::ProgramDesc *Proto();
5557

5658
// The output variable of feed_op is referenced as feed_target.

paddle/fluid/pybind/protobuf.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ void BindProgramDesc(pybind11::module *m) {
127127
.def("block", &pd::ProgramDesc::MutableBlock,
128128
pybind11::return_value_policy::reference)
129129
.def("num_blocks", &pd::ProgramDesc::Size)
130+
.def("flush", &pd::ProgramDesc::Flush)
130131
.def("get_feed_target_names", &pd::ProgramDesc::GetFeedTargetNames)
131132
.def("get_fetch_target_names", &pd::ProgramDesc::GetFetchTargetNames)
132133
.def("serialize_to_string", SerializeMessage<pd::ProgramDesc>)

python/paddle/fluid/io.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -336,18 +336,20 @@ def save_inference_model(dirname,
336336

337337
if main_program is None:
338338
main_program = default_main_program()
339+
copy_program = main_program
339340

340341
if not os.path.isdir(dirname):
341342
os.makedirs(dirname)
342343

343344
# Clear the is_target information and remove the existed feed and fetch op
344-
global_block = main_program.global_block()
345+
global_block = copy_program.global_block()
345346
for i, op in enumerate(global_block.ops):
346347
op.desc.set_is_target(False)
347348
if op.type == "feed" or op.type == "fetch":
348349
global_block.remove_op(i)
350+
copy_program.desc.flush()
349351

350-
pruned_program = main_program.prune(targets=target_vars)
352+
pruned_program = copy_program.prune(targets=target_vars)
351353
inference_program = pruned_program.inference_optimize()
352354
fetch_var_names = [v.name for v in target_vars]
353355

0 commit comments

Comments
 (0)