Commit d77e6a6

Merge pull request #7636 from kexinzhao/save_inference_model
Add feed and fetch op to ProgramDesc before saving for inference
2 parents: 7905e36 + 856f650

File tree

5 files changed: +64 -47 lines changed

paddle/inference/CMakeLists.txt
paddle/inference/example.cc
paddle/inference/inference.cc
paddle/inference/inference.h
python/paddle/v2/fluid/io.py

paddle/inference/CMakeLists.txt

Lines changed: 0 additions & 21 deletions

@@ -8,27 +8,6 @@ cc_library(paddle_fluid_api
 # Merge all modules into a simgle static library
 cc_library(paddle_fluid DEPS paddle_fluid_api ${FLUID_CORE_MODULES})
 
-# ptools
-# just for testing, we may need to change the storing format for inference_model
-# and move the dependent of pickle.
-# download from http://www.picklingtools.com/
-# build in the C++ sub-directory, using command
-# make -f Makefile.Linux libptools.so
-set(PTOOLS_LIB)
-set(PTOOLS_ROOT $ENV{PTOOLS_ROOT} CACHE PATH "Folder contains PicklingTools")
-find_path(PTOOLS_INC_DIR chooseser.h PATHS ${PTOOLS_ROOT}/C++)
-find_library(PTOOLS_SHARED_LIB NAMES ptools PATHS ${PTOOLS_ROOT}/C++)
-if(PTOOLS_INC_DIR AND PTOOLS_SHARED_LIB)
-  add_definitions(-DPADDLE_USE_PTOOLS)
-  set(PTOOLS_LIB ptools)
-  message(STATUS "Found PicklingTools: ${PTOOLS_SHARED_LIB}")
-  add_library(${PTOOLS_LIB} SHARED IMPORTED GLOBAL)
-  set_property(TARGET ${PTOOLS_LIB} PROPERTY IMPORTED_LOCATION ${PTOOLS_SHARED_LIB})
-  include_directories(${PTOOLS_ROOT}/C++)
-  include_directories(${PTOOLS_ROOT}/C++/opencontainers_1_8_5/include)
-  add_definitions(-DOC_NEW_STYLE_INCLUDES) # used in ptools
-endif()
-
 add_executable(example example.cc)
 if(APPLE)
   set(OPTIONAL_LINK_FLAGS)

paddle/inference/example.cc

Lines changed: 3 additions & 15 deletions

@@ -18,33 +18,21 @@ limitations under the License. */
 #include "paddle/inference/inference.h"
 
 DEFINE_string(dirname, "", "Directory of the inference model.");
-DEFINE_string(feed_var_names, "", "Names of feeding variables");
-DEFINE_string(fetch_var_names, "", "Names of fetching variables");
 
 int main(int argc, char** argv) {
   google::ParseCommandLineFlags(&argc, &argv, true);
-  if (FLAGS_dirname.empty() || FLAGS_feed_var_names.empty() ||
-      FLAGS_fetch_var_names.empty()) {
+  if (FLAGS_dirname.empty()) {
     // Example:
     // ./example --dirname=recognize_digits_mlp.inference.model
-    //           --feed_var_names="x"
-    //           --fetch_var_names="fc_2.tmp_2"
-    std::cout << "Usage: ./example --dirname=path/to/your/model "
-                 "--feed_var_names=x --fetch_var_names=y"
-              << std::endl;
+    std::cout << "Usage: ./example --dirname=path/to/your/model" << std::endl;
     exit(1);
   }
 
   std::cout << "FLAGS_dirname: " << FLAGS_dirname << std::endl;
-  std::cout << "FLAGS_feed_var_names: " << FLAGS_feed_var_names << std::endl;
-  std::cout << "FLAGS_fetch_var_names: " << FLAGS_fetch_var_names << std::endl;
-
   std::string dirname = FLAGS_dirname;
-  std::vector<std::string> feed_var_names = {FLAGS_feed_var_names};
-  std::vector<std::string> fetch_var_names = {FLAGS_fetch_var_names};
 
   paddle::InferenceEngine* engine = new paddle::InferenceEngine();
-  engine->LoadInferenceModel(dirname, feed_var_names, fetch_var_names);
+  engine->LoadInferenceModel(dirname);
 
   paddle::framework::LoDTensor input;
   srand(time(0));
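
The directory consumed by ./example is produced on the Python side. A minimal
save-side sketch follows, assuming a toy network with input 'x' and a softmax
output; the layer sizes, variable names, and executor setup are illustrative
assumptions, not part of this commit. After this change the feed/fetch names
are recorded in the saved ProgramDesc, so the C++ example needs only --dirname.

    # Save-side sketch (assumed toy network; names 'x'/y_predict are examples)
    import paddle.v2.fluid as fluid

    x = fluid.layers.data(name='x', shape=[784], dtype='float32')
    y_predict = fluid.layers.fc(input=x, size=10, act='softmax')

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())

    # The feed ('x') and fetch (y_predict) information is now embedded in
    # the saved program, replacing the old --feed_var_names and
    # --fetch_var_names command-line flags of the C++ example.
    fluid.io.save_inference_model('recognize_digits_mlp.inference.model',
                                  ['x'], [y_predict], exe)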

paddle/inference/inference.cc

Lines changed: 29 additions & 11 deletions

@@ -25,19 +25,37 @@ limitations under the License. */
 
 namespace paddle {
 
+void InferenceEngine::LoadInferenceModel(const std::string& dirname) {
+  std::string model_filename = dirname + "/__model__.dat";
+  LOG(INFO) << "loading model from " << model_filename;
+  std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary);
+  std::string program_desc_str;
+  inputfs.seekg(0, std::ios::end);
+  program_desc_str.resize(inputfs.tellg());
+  inputfs.seekg(0, std::ios::beg);
+  LOG(INFO) << "program_desc_str's size: " << program_desc_str.size();
+  inputfs.read(&program_desc_str[0], program_desc_str.size());
+  inputfs.close();
+
+  program_ = new framework::ProgramDesc(program_desc_str);
+  GenerateLoadProgram(dirname);
+
+  framework::BlockDesc* global_block = program_->MutableBlock(0);
+  feed_var_names_.clear();
+  fetch_var_names_.clear();
+  for (auto* op : global_block->AllOps()) {
+    if (op->Type() == "feed") {
+      feed_var_names_.insert(feed_var_names_.begin(), op->Output("Out")[0]);
+    } else if (op->Type() == "fetch") {
+      fetch_var_names_.push_back(op->Input("X")[0]);
+    }
+  }
+}
+
 void InferenceEngine::LoadInferenceModel(
     const std::string& dirname,
     const std::vector<std::string>& feed_var_names,
     const std::vector<std::string>& fetch_var_names) {
-#ifdef PADDLE_USE_PTOOLS
-  std::string model_filename = dirname + "/__model__";
-  LOG(INFO) << "Using PicklingTools, loading model from " << model_filename;
-  Val v;
-  LoadValFromFile(model_filename.c_str(), v, SERIALIZE_P0);
-  std::string program_desc_str = v["program_desc_str"];
-  LOG(INFO) << "program_desc_str's size: " << program_desc_str.size();
-  // PicklingTools cannot parse the vector of strings correctly.
-#else
   std::string model_filename = dirname + "/__model__.dat";
   LOG(INFO) << "loading model from " << model_filename;
   std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary);

@@ -48,7 +66,7 @@ void InferenceEngine::LoadInferenceModel(
   LOG(INFO) << "program_desc_str's size: " << program_desc_str.size();
   inputfs.read(&program_desc_str[0], program_desc_str.size());
   inputfs.close();
-#endif
+
   program_ = new framework::ProgramDesc(program_desc_str);
   GenerateLoadProgram(dirname);
 

@@ -62,7 +80,7 @@ void InferenceEngine::LoadInferenceModel(
 }
 
 bool InferenceEngine::IsParameter(const framework::VarDesc* var) {
-  if (var->Persistable()) {
+  if (var->Persistable() && var->Name() != "feed" && var->Name() != "fetch") {
     // There are many unreachable variables in the program
     for (size_t i = 0; i < program_->Size(); ++i) {
       const framework::BlockDesc& block = program_->Block(i);
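
The new zero-argument LoadInferenceModel recovers the feed/fetch variable
names by scanning the global block's ops. Below is a minimal Python sketch of
that scan, using plain dicts as stand-ins for the op descriptors; the dict
structure and the op list are hypothetical, for illustration only.

    def recover_feed_fetch_names(ops):
        feed_var_names = []
        fetch_var_names = []
        for op in ops:
            if op['type'] == 'feed':
                # Feed ops were prepended one at a time, so the block holds
                # them in reverse column order; inserting each name at the
                # front restores the original feeding order.
                feed_var_names.insert(0, op['outputs']['Out'][0])
            elif op['type'] == 'fetch':
                fetch_var_names.append(op['inputs']['X'][0])
        return feed_var_names, fetch_var_names

    # Hypothetical global block: feed -> mul -> fetch
    ops = [
        {'type': 'feed', 'inputs': {'X': ['feed']}, 'outputs': {'Out': ['x']}},
        {'type': 'mul', 'inputs': {'X': ['x'], 'Y': ['w']}, 'outputs': {'Out': ['y']}},
        {'type': 'fetch', 'inputs': {'X': ['y']}, 'outputs': {'Out': ['fetch']}},
    ]
    print(recover_feed_fetch_names(ops))  # (['x'], ['y'])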

paddle/inference/inference.h

Lines changed: 1 addition & 0 deletions

@@ -28,6 +28,7 @@ class InferenceEngine {
     delete load_program_;
   }
 
+  void LoadInferenceModel(const std::string& dirname);
   void LoadInferenceModel(const std::string& dirname,
                           const std::vector<std::string>& feed_var_names,
                           const std::vector<std::string>& fetch_var_names);

python/paddle/v2/fluid/io.py

Lines changed: 31 additions & 0 deletions

@@ -15,6 +15,7 @@
 import cPickle as pickle
 
 from paddle.v2.fluid.framework import Program, Parameter, default_main_program, Variable
+from . import core
 
 __all__ = [
     'save_vars',

@@ -191,6 +192,33 @@ def get_inference_program(target_vars, main_program=None):
     return inference_program
 
 
+def prepend_feed_ops(inference_program, feeded_var_names):
+    global_block = inference_program.global_block()
+    feed_var = global_block.create_var(
+        name='feed', type=core.VarDesc.VarType.FEED_MINIBATCH, persistable=True)
+
+    for i, name in enumerate(feeded_var_names):
+        out = global_block.var(name)
+        global_block.prepend_op(
+            type='feed',
+            inputs={'X': [feed_var]},
+            outputs={'Out': [out]},
+            attrs={'col': i})
+
+
+def append_fetch_ops(inference_program, fetch_var_names):
+    global_block = inference_program.global_block()
+    fetch_var = global_block.create_var(
+        name='fetch', type=core.VarDesc.VarType.FETCH_LIST, persistable=True)
+
+    for i, name in enumerate(fetch_var_names):
+        global_block.append_op(
+            type='fetch',
+            inputs={'X': [name]},
+            outputs={'Out': [fetch_var]},
+            attrs={'col': i})
+
+
 def save_inference_model(dirname,
                          feeded_var_names,
                          target_vars,

@@ -241,6 +269,9 @@ def save_inference_model(dirname,
             "fetch_var_names": fetch_var_names
         }, f, -1)
 
+    prepend_feed_ops(inference_program, feeded_var_names)
+    append_fetch_ops(inference_program, fetch_var_names)
+
     # Save only programDesc of inference_program in binary format
     # in another file: __model__.dat
     with open(model_file_name + ".dat", "wb") as fp:
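
After save_inference_model runs these two helpers, the saved program's global
block starts with one feed op per input (column-indexed) and ends with one
fetch op per output. A hedged inspection sketch follows, assuming a toy
program built with fluid's layer API; the names, shapes, and the direct import
of the non-exported helpers are illustrative assumptions.

    import paddle.v2.fluid as fluid
    from paddle.v2.fluid.io import prepend_feed_ops, append_fetch_ops

    x = fluid.layers.data(name='x', shape=[4], dtype='float32')
    y = fluid.layers.fc(input=x, size=2)

    program = fluid.default_main_program()
    prepend_feed_ops(program, ['x'])
    append_fetch_ops(program, [y.name])

    # Expected layout: a 'feed' op first, the original ops in the middle,
    # and a 'fetch' op last.
    print([op.type for op in program.global_block().ops])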
