Skip to content

Commit 2d74b5f

Browse files
committed
Refine the Python API load/save_inference_model.
1 parent b44917d commit 2d74b5f

File tree

3 files changed

+83
-55
lines changed

3 files changed

+83
-55
lines changed

paddle/fluid/inference/tests/test_helper.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,8 @@ void TestInference(const std::string& dirname,
101101
if (IsCombined) {
102102
// All parameters are saved in a single file.
103103
// Hard-coding the file names of program and parameters in unittest.
104-
// Users are free to specify different filename
105-
// (provided: the filenames are changed in the python api as well: io.py)
104+
// The file names should be consistent with that used in Python API
105+
// `fluid.io.save_inference_model`.
106106
std::string prog_filename = "__model_combined__";
107107
std::string param_filename = "__params_combined__";
108108
inference_program = paddle::inference::Load(executor,

python/paddle/v2/fluid/io.py

Lines changed: 57 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def save_vars(executor,
6868
main_program=None,
6969
vars=None,
7070
predicate=None,
71-
save_file_name=None):
71+
filename=None):
7272
"""
7373
Save variables to directory by executor.
7474
@@ -80,8 +80,8 @@ def save_vars(executor,
8080
as a bool. If it returns true, the corresponding input variable will be saved.
8181
:param vars: variables need to be saved. If vars is specified, program & predicate
8282
will be ignored
83-
:param save_file_name: The name of a single file that all vars are saved to.
84-
If it is None, save variables to separate files.
83+
:param filename: The name of a single file that all vars are saved to.
84+
If it is None, save variables to separate files.
8585
8686
:return: None
8787
"""
@@ -95,15 +95,15 @@ def save_vars(executor,
9595
executor,
9696
dirname=dirname,
9797
vars=filter(predicate, main_program.list_vars()),
98-
save_file_name=save_file_name)
98+
filename=filename)
9999
else:
100100
save_program = Program()
101101
save_block = save_program.global_block()
102102

103103
save_var_map = {}
104104
for each_var in vars:
105105
new_var = _clone_var_in_block_(save_block, each_var)
106-
if save_file_name is None:
106+
if filename is None:
107107
save_block.append_op(
108108
type='save',
109109
inputs={'X': [new_var]},
@@ -112,7 +112,7 @@ def save_vars(executor,
112112
else:
113113
save_var_map[new_var.name] = new_var
114114

115-
if save_file_name is not None:
115+
if filename is not None:
116116
save_var_list = []
117117
for name in sorted(save_var_map.keys()):
118118
save_var_list.append(save_var_map[name])
@@ -121,12 +121,12 @@ def save_vars(executor,
121121
type='save_combine',
122122
inputs={'X': save_var_list},
123123
outputs={},
124-
attrs={'file_path': os.path.join(dirname, save_file_name)})
124+
attrs={'file_path': os.path.join(dirname, filename)})
125125

126126
executor.run(save_program)
127127

128128

129-
def save_params(executor, dirname, main_program=None, save_file_name=None):
129+
def save_params(executor, dirname, main_program=None, filename=None):
130130
"""
131131
Save all parameters to directory with executor.
132132
"""
@@ -136,11 +136,10 @@ def save_params(executor, dirname, main_program=None, save_file_name=None):
136136
main_program=main_program,
137137
vars=None,
138138
predicate=is_parameter,
139-
save_file_name=save_file_name)
139+
filename=filename)
140140

141141

142-
def save_persistables(executor, dirname, main_program=None,
143-
save_file_name=None):
142+
def save_persistables(executor, dirname, main_program=None, filename=None):
144143
"""
145144
Save all persistables to directory with executor.
146145
"""
@@ -150,15 +149,15 @@ def save_persistables(executor, dirname, main_program=None,
150149
main_program=main_program,
151150
vars=None,
152151
predicate=is_persistable,
153-
save_file_name=save_file_name)
152+
filename=filename)
154153

155154

156155
def load_vars(executor,
157156
dirname,
158157
main_program=None,
159158
vars=None,
160159
predicate=None,
161-
load_file_name=None):
160+
filename=None):
162161
"""
163162
Load variables from directory by executor.
164163
@@ -170,8 +169,8 @@ def load_vars(executor,
170169
as a bool. If it returns true, the corresponding input variable will be loaded.
171170
:param vars: variables need to be loaded. If vars is specified, program &
172171
predicate will be ignored
173-
:param load_file_name: The name of the single file that all vars are loaded from.
174-
If it is None, load variables from separate files.
172+
:param filename: The name of the single file that all vars are loaded from.
173+
If it is None, load variables from separate files.
175174
176175
:return: None
177176
"""
@@ -185,7 +184,7 @@ def load_vars(executor,
185184
executor,
186185
dirname=dirname,
187186
vars=filter(predicate, main_program.list_vars()),
188-
load_file_name=load_file_name)
187+
filename=filename)
189188
else:
190189
load_prog = Program()
191190
load_block = load_prog.global_block()
@@ -194,7 +193,7 @@ def load_vars(executor,
194193
for each_var in vars:
195194
assert isinstance(each_var, Variable)
196195
new_var = _clone_var_in_block_(load_block, each_var)
197-
if load_file_name is None:
196+
if filename is None:
198197
load_block.append_op(
199198
type='load',
200199
inputs={},
@@ -203,7 +202,7 @@ def load_vars(executor,
203202
else:
204203
load_var_map[new_var.name] = new_var
205204

206-
if load_file_name is not None:
205+
if filename is not None:
207206
load_var_list = []
208207
for name in sorted(load_var_map.keys()):
209208
load_var_list.append(load_var_map[name])
@@ -212,12 +211,12 @@ def load_vars(executor,
212211
type='load_combine',
213212
inputs={},
214213
outputs={"Out": load_var_list},
215-
attrs={'file_path': os.path.join(dirname, load_file_name)})
214+
attrs={'file_path': os.path.join(dirname, filename)})
216215

217216
executor.run(load_prog)
218217

219218

220-
def load_params(executor, dirname, main_program=None, load_file_name=None):
219+
def load_params(executor, dirname, main_program=None, filename=None):
221220
"""
222221
load all parameters from directory by executor.
223222
"""
@@ -226,11 +225,10 @@ def load_params(executor, dirname, main_program=None, load_file_name=None):
226225
dirname=dirname,
227226
main_program=main_program,
228227
predicate=is_parameter,
229-
load_file_name=load_file_name)
228+
filename=filename)
230229

231230

232-
def load_persistables(executor, dirname, main_program=None,
233-
load_file_name=None):
231+
def load_persistables(executor, dirname, main_program=None, filename=None):
234232
"""
235233
load all persistables from directory by executor.
236234
"""
@@ -239,7 +237,7 @@ def load_persistables(executor, dirname, main_program=None,
239237
dirname=dirname,
240238
main_program=main_program,
241239
predicate=is_persistable,
242-
load_file_name=load_file_name)
240+
filename=filename)
243241

244242

245243
def get_inference_program(target_vars, main_program=None):
@@ -299,7 +297,8 @@ def save_inference_model(dirname,
299297
target_vars,
300298
executor,
301299
main_program=None,
302-
save_file_name=None):
300+
model_filename=None,
301+
params_filename=None):
303302
"""
304303
Build a model especially for inference,
305304
and save it to directory by the executor.
@@ -310,8 +309,11 @@ def save_inference_model(dirname,
310309
:param executor: executor that save inference model
311310
:param main_program: original program, which will be pruned to build the inference model.
312311
Defaults to default_main_program().
313-
:param save_file_name: The name of a single file that all parameters are saved to.
314-
If it is None, save parameters to separate files.
312+
:param model_filename: The name of file to save inference program.
313+
If not specified, default filename `__model__` will be used.
314+
:param params_filename: The name of file to save parameters.
315+
It is used for the case that all parameters are saved in a single binary file.
316+
If not specified, parameters are considered saved in separate files.
315317
316318
:return: None
317319
"""
@@ -342,15 +344,19 @@ def save_inference_model(dirname,
342344
prepend_feed_ops(inference_program, feeded_var_names)
343345
append_fetch_ops(inference_program, fetch_var_names)
344346

345-
if save_file_name == None:
346-
model_file_name = dirname + "/__model__"
347+
if model_filename is not None:
348+
model_filename = os.path.basename(model_filename)
347349
else:
348-
model_file_name = dirname + "/__model_combined__"
350+
model_filename = "__model__"
351+
model_filename = os.path.join(dirname, model_filename)
349352

350-
with open(model_file_name, "wb") as f:
353+
if params_filename is not None:
354+
params_filename = os.path.basename(params_filename)
355+
356+
with open(model_filename, "wb") as f:
351357
f.write(inference_program.desc.serialize_to_string())
352358

353-
save_persistables(executor, dirname, inference_program, save_file_name)
359+
save_persistables(executor, dirname, inference_program, params_filename)
354360

355361

356362
def get_feed_targets_names(program):
@@ -371,15 +377,21 @@ def get_fetch_targets_names(program):
371377
return fetch_targets_names
372378

373379

374-
def load_inference_model(dirname, executor, load_file_name=None):
380+
def load_inference_model(dirname,
381+
executor,
382+
model_filename=None,
383+
params_filename=None):
375384
"""
376385
Load inference model from a directory
377386
378387
:param dirname: directory path
379388
:param executor: executor that load inference model
380-
:param load_file_name: The name of the single file that all parameters are loaded from.
381-
If it is None, load parameters from separate files.
382-
389+
:param model_filename: The name of file to load inference program.
390+
If not specified, default filename `__model__` will be used.
391+
:param params_filename: The name of file to load parameters.
392+
It is used for the case that all parameters are saved in a single binary file.
393+
If not specified, parameters are considered saved in separate files.
394+
383395
:return: [program, feed_target_names, fetch_targets]
384396
program: program especially for inference.
385397
feed_target_names: Names of variables that need to feed data
@@ -388,16 +400,20 @@ def load_inference_model(dirname, executor, load_file_name=None):
388400
if not os.path.isdir(dirname):
389401
raise ValueError("There is no directory named '%s'", dirname)
390402

391-
if load_file_name == None:
392-
model_file_name = dirname + "/__model__"
403+
if model_filename is not None:
404+
model_filename = os.path.basename(model_filename)
393405
else:
394-
model_file_name = dirname + "/__model_combined__"
406+
model_filename = "__model__"
407+
model_filename = os.path.join(dirname, model_filename)
408+
409+
if params_filename is not None:
410+
params_filename = os.path.basename(params_filename)
395411

396-
with open(model_file_name, "rb") as f:
412+
with open(model_filename, "rb") as f:
397413
program_desc_str = f.read()
398414

399415
program = Program.parse_from_string(program_desc_str)
400-
load_persistables(executor, dirname, program, load_file_name)
416+
load_persistables(executor, dirname, program, params_filename)
401417

402418
feed_target_names = get_feed_targets_names(program)
403419
fetch_target_names = get_fetch_targets_names(program)

python/paddle/v2/fluid/tests/book/test_recognize_digits.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,12 @@ def conv_net(img, label):
7878
return loss_net(conv_pool_2, label)
7979

8080

81-
def train(nn_type, use_cuda, parallel, save_dirname, save_param_filename):
81+
def train(nn_type,
82+
use_cuda,
83+
parallel,
84+
save_dirname=None,
85+
model_filename=None,
86+
params_filename=None):
8287
if use_cuda and not fluid.core.is_compiled_with_cuda():
8388
return
8489
img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
@@ -146,7 +151,8 @@ def train(nn_type, use_cuda, parallel, save_dirname, save_param_filename):
146151
fluid.io.save_inference_model(
147152
save_dirname, ["img"], [prediction],
148153
exe,
149-
save_file_name=save_param_filename)
154+
model_filename=model_filename,
155+
params_filename=params_filename)
150156
return
151157
else:
152158
print(
@@ -158,7 +164,10 @@ def train(nn_type, use_cuda, parallel, save_dirname, save_param_filename):
158164
raise AssertionError("Loss of recognize digits is too large")
159165

160166

161-
def infer(use_cuda, save_dirname=None, param_filename=None):
167+
def infer(use_cuda,
168+
save_dirname=None,
169+
model_filename=None,
170+
params_filename=None):
162171
if save_dirname is None:
163172
return
164173

@@ -171,8 +180,9 @@ def infer(use_cuda, save_dirname=None, param_filename=None):
171180
# the feed_target_names (the names of variables that will be fed
172181
# data using feed operators), and the fetch_targets (variables that
173182
# we want to obtain data from using fetch operators).
174-
[inference_program, feed_target_names, fetch_targets
175-
] = fluid.io.load_inference_model(save_dirname, exe, param_filename)
183+
[inference_program, feed_target_names,
184+
fetch_targets] = fluid.io.load_inference_model(
185+
save_dirname, exe, model_filename, params_filename)
176186

177187
# The input's dimension of conv should be 4-D or 5-D.
178188
# Use normalized image pixels as input data, which should be in the range [-1.0, 1.0].
@@ -189,25 +199,27 @@ def infer(use_cuda, save_dirname=None, param_filename=None):
189199

190200

191201
def main(use_cuda, parallel, nn_type, combine):
202+
save_dirname = None
203+
model_filename = None
204+
params_filename = None
192205
if not use_cuda and not parallel:
193206
save_dirname = "recognize_digits_" + nn_type + ".inference.model"
194-
save_filename = None
195207
if combine == True:
196-
save_filename = "__params_combined__"
197-
else:
198-
save_dirname = None
199-
save_filename = None
208+
model_filename = "__model_combined__"
209+
params_filename = "__params_combined__"
200210

201211
train(
202212
nn_type=nn_type,
203213
use_cuda=use_cuda,
204214
parallel=parallel,
205215
save_dirname=save_dirname,
206-
save_param_filename=save_filename)
216+
model_filename=model_filename,
217+
params_filename=params_filename)
207218
infer(
208219
use_cuda=use_cuda,
209220
save_dirname=save_dirname,
210-
param_filename=save_filename)
221+
model_filename=model_filename,
222+
params_filename=params_filename)
211223

212224

213225
class TestRecognizeDigits(unittest.TestCase):

0 commit comments

Comments
 (0)