Skip to content

Commit 598035f

Browse files
Xrekikexinzhao
authored andcommitted
Fix a bug in save_inference_model and prune when the program is initailized by load_inference_model (#10011)
* Fix bug in save_inference_model and prune when the program is initialized by load_inference_program. * Save the transpiled program instead.
1 parent 9ca578d commit 598035f

File tree

6 files changed

+24
-22
lines changed

6 files changed

+24
-22
lines changed

paddle/fluid/framework/op_desc.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ class OpDesc {
119119

120120
void InferVarType(BlockDesc *block) const;
121121

122-
void MarkAsTarget() { desc_.set_is_target(true); }
122+
void SetIsTarget(bool is_target) { desc_.set_is_target(is_target); }
123123

124124
void Flush();
125125

paddle/fluid/pybind/protobuf.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,8 @@ void BindProgramDesc(pybind11::module *m) {
127127
.def("block", &pd::ProgramDesc::MutableBlock,
128128
pybind11::return_value_policy::reference)
129129
.def("num_blocks", &pd::ProgramDesc::Size)
130+
.def("get_feed_target_names", &pd::ProgramDesc::GetFeedTargetNames)
131+
.def("get_fetch_target_names", &pd::ProgramDesc::GetFetchTargetNames)
130132
.def("serialize_to_string", SerializeMessage<pd::ProgramDesc>)
131133
.def("parse_from_string",
132134
[](pd::ProgramDesc &program_desc, const std::string &data) {
@@ -299,6 +301,7 @@ void BindOpDesc(pybind11::module *m) {
299301
.def("check_attrs", &pd::OpDesc::CheckAttrs)
300302
.def("infer_shape", &pd::OpDesc::InferShape)
301303
.def("infer_var_type", &pd::OpDesc::InferVarType)
304+
.def("set_is_target", &pd::OpDesc::SetIsTarget)
302305
.def("serialize_to_string", SerializeMessage<pd::OpDesc>)
303306
.def("block", &pd::OpDesc::Block,
304307
pybind11::return_value_policy::reference);

paddle/fluid/pybind/pybind.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ All parameter, weight, gradient are variables in Paddle.
294294
const std::vector<std::array<size_t, 2>> &targets) {
295295
ProgramDesc prog_with_targets(origin);
296296
for (const auto &t : targets) {
297-
prog_with_targets.MutableBlock(t[0])->Op(t[1])->MarkAsTarget();
297+
prog_with_targets.MutableBlock(t[0])->Op(t[1])->SetIsTarget(true);
298298
}
299299
proto::ProgramDesc pruned_desc;
300300
Prune(*prog_with_targets.Proto(), &pruned_desc);

python/paddle/fluid/framework.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,6 +1070,12 @@ def prune(self, targets):
10701070
for t in targets:
10711071
if not isinstance(t, Operator):
10721072
if isinstance(t, Variable):
1073+
if t.op is None:
1074+
global_block = self.global_block()
1075+
for op in global_block.ops:
1076+
if t.name in op.output_arg_names:
1077+
t.op = op
1078+
break
10731079
t = t.op
10741080
else:
10751081
raise ValueError(("All targets of prune() can only be "

python/paddle/fluid/io.py

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,13 @@ def save_inference_model(dirname,
340340
if not os.path.isdir(dirname):
341341
os.makedirs(dirname)
342342

343+
# Clear the is_target information and remove the existed feed and fetch op
344+
global_block = main_program.global_block()
345+
for i, op in enumerate(global_block.ops):
346+
op.desc.set_is_target(False)
347+
if op.type == "feed" or op.type == "fetch":
348+
global_block.remove_op(i)
349+
343350
pruned_program = main_program.prune(targets=target_vars)
344351
inference_program = pruned_program.inference_optimize()
345352
fetch_var_names = [v.name for v in target_vars]
@@ -362,24 +369,6 @@ def save_inference_model(dirname,
362369
save_persistables(executor, dirname, inference_program, params_filename)
363370

364371

365-
def get_feed_targets_names(program):
366-
feed_targets_names = []
367-
global_block = program.global_block()
368-
for op in global_block.ops:
369-
if op.desc.type() == 'feed':
370-
feed_targets_names.insert(0, op.desc.output('Out')[0])
371-
return feed_targets_names
372-
373-
374-
def get_fetch_targets_names(program):
375-
fetch_targets_names = []
376-
global_block = program.global_block()
377-
for op in global_block.ops:
378-
if op.desc.type() == 'fetch':
379-
fetch_targets_names.append(op.desc.input('X')[0])
380-
return fetch_targets_names
381-
382-
383372
def load_inference_model(dirname,
384373
executor,
385374
model_filename=None,
@@ -418,8 +407,8 @@ def load_inference_model(dirname,
418407
program = Program.parse_from_string(program_desc_str)
419408
load_persistables(executor, dirname, program, params_filename)
420409

421-
feed_target_names = get_feed_targets_names(program)
422-
fetch_target_names = get_fetch_targets_names(program)
410+
feed_target_names = program.desc.get_feed_target_names()
411+
fetch_target_names = program.desc.get_fetch_target_names()
423412
fetch_targets = [
424413
program.global_block().var(name) for name in fetch_target_names
425414
]

python/paddle/fluid/tests/book/test_image_classification.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,10 @@ def infer(use_cuda, save_dirname=None):
248248

249249
print("infer results: ", results[0])
250250

251+
fluid.io.save_inference_model(save_dirname, feed_target_names,
252+
fetch_targets, exe,
253+
inference_transpiler_program)
254+
251255

252256
def main(net_type, use_cuda, is_local=True):
253257
if use_cuda and not fluid.core.is_compiled_with_cuda():

0 commit comments

Comments
 (0)