Skip to content

Commit 6783dce

Browse files
authored
Python API for inference model saving/load (#5020)
* Add `dump_to_file()` for ProgrameDescBind in pybind * Update * Add utility.py * typo * Fix bugs * Move add_feed/fetch_components to untility.py * Compelete dump * Follow comments * Change output of Prune() from inference to pointer * Expose Prune() to Python * Compelete save/load API of inference model * Fix errors * Debuging * Compelete unit tests * follow comments
1 parent f3ac4d8 commit 6783dce

File tree

15 files changed

+268
-18
lines changed

15 files changed

+268
-18
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,4 @@ cmake_install.cmake
2828
paddle/.timestamp
2929
python/paddlepaddle.egg-info/
3030
paddle/pybind/pybind.h
31+
python/paddle/v2/framework/tests/tmp/*

paddle/framework/op_desc.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@ class OpDescBind {
107107

108108
void InferVarType(BlockDescBind *block) const;
109109

110+
void MarkAsTarget() { desc_.set_is_target(true); }
111+
110112
void Flush();
111113

112114
private:

paddle/framework/program_desc.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,13 @@ ProgramDescBind::ProgramDescBind(const ProgramDescBind &o) {
4949
}
5050
}
5151

52+
ProgramDescBind::ProgramDescBind(const ProgramDesc &desc) {
53+
desc_ = desc;
54+
for (auto &block_desc : *desc_.mutable_blocks()) {
55+
blocks_.emplace_back(new BlockDescBind(this, &block_desc));
56+
}
57+
}
58+
5259
ProgramDescBind::ProgramDescBind(const std::string &binary_str) {
5360
PADDLE_ENFORCE(desc_.ParseFromString(binary_str),
5461
"Fail to parse program_desc from binary string.");

paddle/framework/program_desc.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ class ProgramDescBind {
2929
public:
3030
ProgramDescBind();
3131

32+
explicit ProgramDescBind(const ProgramDesc &desc);
33+
3234
ProgramDescBind(const ProgramDescBind &o);
3335

3436
explicit ProgramDescBind(const std::string &binary_str);

paddle/framework/prune.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ bool IsTarget(const OpDesc& op_desc) {
4646
return false;
4747
}
4848

49-
void prune_impl(const ProgramDesc& input, ProgramDesc& output, int block_id) {
49+
void prune_impl(const ProgramDesc& input, ProgramDesc* output, int block_id) {
5050
// TODO(tonyyang-svail):
5151
// - will change to use multiple blocks for RNN op and Cond Op
5252

@@ -91,8 +91,8 @@ void prune_impl(const ProgramDesc& input, ProgramDesc& output, int block_id) {
9191
// we reverse the should_run vector
9292
std::reverse(should_run.begin(), should_run.end());
9393

94-
output = input;
95-
auto* op_field = output.mutable_blocks(block_id)->mutable_ops();
94+
*output = input;
95+
auto* op_field = output->mutable_blocks(block_id)->mutable_ops();
9696
op_field->Clear();
9797
for (size_t i = 0; i < should_run.size(); ++i) {
9898
if (should_run[i]) {
@@ -101,7 +101,8 @@ void prune_impl(const ProgramDesc& input, ProgramDesc& output, int block_id) {
101101
}
102102
}
103103

104-
void Prune(const ProgramDesc& input, ProgramDesc& output) {
104+
// TODO(fengjiayi): Prune() could be inplaced to avoid unnecessary copies
105+
void Prune(const ProgramDesc& input, ProgramDesc* output) {
105106
prune_impl(input, output, 0);
106107
}
107108

paddle/framework/prune.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ limitations under the License. */
2020
namespace paddle {
2121
namespace framework {
2222

23-
void Prune(const ProgramDesc& input, ProgramDesc& output);
23+
void Prune(const ProgramDesc& input, ProgramDesc* output);
2424

2525
} // namespace framework
2626
} // namespace paddle

paddle/framework/prune_test.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,11 @@ TEST(Prune, one_operator) {
5959
f::ProgramDesc *pdesc = program.Proto();
6060
f::ProgramDesc pruned;
6161

62-
Prune(*pdesc, pruned);
62+
Prune(*pdesc, &pruned);
6363
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 0);
6464

6565
pdesc->mutable_blocks(0)->mutable_ops(0)->set_is_target(true);
66-
Prune(*pdesc, pruned);
66+
Prune(*pdesc, &pruned);
6767
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 1);
6868
}
6969

@@ -81,7 +81,7 @@ TEST(Prune, forward) {
8181
for (int i = 0; i < pdesc->blocks(0).ops_size(); ++i) {
8282
f::ProgramDesc pruned;
8383
pdesc->mutable_blocks(0)->mutable_ops(i)->set_is_target(true);
84-
Prune(*pdesc, pruned);
84+
Prune(*pdesc, &pruned);
8585
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), i + 1);
8686
}
8787
}
@@ -100,7 +100,7 @@ TEST(Prune, multi_input_op) {
100100
pdesc->mutable_blocks(0)->mutable_ops(3)->set_is_target(true);
101101

102102
f::ProgramDesc pruned;
103-
Prune(*pdesc, pruned);
103+
Prune(*pdesc, &pruned);
104104
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 4);
105105
}
106106

@@ -116,7 +116,7 @@ TEST(Prune, multi_output_op) {
116116
pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true);
117117

118118
f::ProgramDesc pruned;
119-
Prune(*pdesc, pruned);
119+
Prune(*pdesc, &pruned);
120120
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 2);
121121
}
122122

@@ -133,6 +133,6 @@ TEST(Prune, multi_target) {
133133
pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true);
134134

135135
f::ProgramDesc pruned;
136-
Prune(*pdesc, pruned);
136+
Prune(*pdesc, &pruned);
137137
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 3);
138138
}

paddle/pybind/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
if(WITH_PYTHON)
22
cc_library(paddle_pybind SHARED
33
SRCS pybind.cc exception.cc protobuf.cc
4-
DEPS pybind python backward proto_desc tensor_array paddle_memory executor
4+
DEPS pybind python backward proto_desc tensor_array paddle_memory executor prune
55
${GLOB_OP_LIB})
66
endif(WITH_PYTHON)
77

paddle/pybind/protobuf.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,13 @@ void BindProgramDesc(py::module &m) {
141141
desc->SerializeToString(&res),
142142
"Serialize ProgramDesc Error. This could be a bug of Paddle.");
143143
return res;
144+
})
145+
.def("parse_from_string",
146+
[](ProgramDescBind &program_desc, const std::string &data) {
147+
ProgramDesc *desc = program_desc.Proto();
148+
PADDLE_ENFORCE(desc->ParseFromString(data),
149+
"Fail to parse ProgramDesc from string. This could "
150+
"be a bug of Paddle.");
144151
});
145152
}
146153

paddle/pybind/pybind.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License. */
1919
#include "paddle/framework/feed_fetch_method.h"
2020
#include "paddle/framework/framework.pb.h"
2121
#include "paddle/framework/lod_tensor.h"
22+
#include "paddle/framework/prune.h"
2223
#include "paddle/framework/selected_rows.h"
2324
#include "paddle/framework/tensor_array.h"
2425
#include "paddle/operators/cond_op.h"
@@ -237,6 +238,16 @@ All parameter, weight, gradient are variables in Paddle.
237238
}
238239
return ret_values;
239240
});
241+
m.def("prune", [](const ProgramDescBind &origin,
242+
const std::vector<std::array<size_t, 2>> &targets) {
243+
ProgramDescBind prog_with_targets(origin);
244+
for (const auto &t : targets) {
245+
prog_with_targets.Block(t[0])->Op(t[1])->MarkAsTarget();
246+
}
247+
ProgramDesc pruned_desc;
248+
Prune(*prog_with_targets.Proto(), &pruned_desc);
249+
return new ProgramDescBind(pruned_desc);
250+
});
240251
m.def_submodule(
241252
"var_names",
242253
"The module will return special predefined variable name in Paddle")

0 commit comments

Comments
 (0)