@@ -68,7 +68,7 @@ def save_vars(executor,
68
68
main_program = None ,
69
69
vars = None ,
70
70
predicate = None ,
71
- save_file_name = None ):
71
+ filename = None ):
72
72
"""
73
73
Save variables to directory by executor.
74
74
@@ -80,8 +80,8 @@ def save_vars(executor,
80
80
as a bool. If it returns true, the corresponding input variable will be saved.
81
81
:param vars: variables need to be saved. If vars is specified, program & predicate
82
82
will be ignored
83
- :param save_file_name : The name of a single file that all vars are saved to.
84
- If it is None, save variables to separate files.
83
+ :param filename : The name of a single file that all vars are saved to.
84
+ If it is None, save variables to separate files.
85
85
86
86
:return: None
87
87
"""
@@ -95,15 +95,15 @@ def save_vars(executor,
95
95
executor ,
96
96
dirname = dirname ,
97
97
vars = filter (predicate , main_program .list_vars ()),
98
- save_file_name = save_file_name )
98
+ filename = filename )
99
99
else :
100
100
save_program = Program ()
101
101
save_block = save_program .global_block ()
102
102
103
103
save_var_map = {}
104
104
for each_var in vars :
105
105
new_var = _clone_var_in_block_ (save_block , each_var )
106
- if save_file_name is None :
106
+ if filename is None :
107
107
save_block .append_op (
108
108
type = 'save' ,
109
109
inputs = {'X' : [new_var ]},
@@ -112,7 +112,7 @@ def save_vars(executor,
112
112
else :
113
113
save_var_map [new_var .name ] = new_var
114
114
115
- if save_file_name is not None :
115
+ if filename is not None :
116
116
save_var_list = []
117
117
for name in sorted (save_var_map .keys ()):
118
118
save_var_list .append (save_var_map [name ])
@@ -121,12 +121,12 @@ def save_vars(executor,
121
121
type = 'save_combine' ,
122
122
inputs = {'X' : save_var_list },
123
123
outputs = {},
124
- attrs = {'file_path' : os .path .join (dirname , save_file_name )})
124
+ attrs = {'file_path' : os .path .join (dirname , filename )})
125
125
126
126
executor .run (save_program )
127
127
128
128
129
- def save_params (executor , dirname , main_program = None , save_file_name = None ):
129
+ def save_params (executor , dirname , main_program = None , filename = None ):
130
130
"""
131
131
Save all parameters to directory with executor.
132
132
"""
@@ -136,11 +136,10 @@ def save_params(executor, dirname, main_program=None, save_file_name=None):
136
136
main_program = main_program ,
137
137
vars = None ,
138
138
predicate = is_parameter ,
139
- save_file_name = save_file_name )
139
+ filename = filename )
140
140
141
141
142
- def save_persistables (executor , dirname , main_program = None ,
143
- save_file_name = None ):
142
+ def save_persistables (executor , dirname , main_program = None , filename = None ):
144
143
"""
145
144
Save all persistables to directory with executor.
146
145
"""
@@ -150,15 +149,15 @@ def save_persistables(executor, dirname, main_program=None,
150
149
main_program = main_program ,
151
150
vars = None ,
152
151
predicate = is_persistable ,
153
- save_file_name = save_file_name )
152
+ filename = filename )
154
153
155
154
156
155
def load_vars (executor ,
157
156
dirname ,
158
157
main_program = None ,
159
158
vars = None ,
160
159
predicate = None ,
161
- load_file_name = None ):
160
+ filename = None ):
162
161
"""
163
162
Load variables from directory by executor.
164
163
@@ -170,8 +169,8 @@ def load_vars(executor,
170
169
as a bool. If it returns true, the corresponding input variable will be loaded.
171
170
:param vars: variables need to be loaded. If vars is specified, program &
172
171
predicate will be ignored
173
- :param load_file_name : The name of the single file that all vars are loaded from.
174
- If it is None, load variables from separate files.
172
+ :param filename : The name of the single file that all vars are loaded from.
173
+ If it is None, load variables from separate files.
175
174
176
175
:return: None
177
176
"""
@@ -185,7 +184,7 @@ def load_vars(executor,
185
184
executor ,
186
185
dirname = dirname ,
187
186
vars = filter (predicate , main_program .list_vars ()),
188
- load_file_name = load_file_name )
187
+ filename = filename )
189
188
else :
190
189
load_prog = Program ()
191
190
load_block = load_prog .global_block ()
@@ -194,7 +193,7 @@ def load_vars(executor,
194
193
for each_var in vars :
195
194
assert isinstance (each_var , Variable )
196
195
new_var = _clone_var_in_block_ (load_block , each_var )
197
- if load_file_name is None :
196
+ if filename is None :
198
197
load_block .append_op (
199
198
type = 'load' ,
200
199
inputs = {},
@@ -203,7 +202,7 @@ def load_vars(executor,
203
202
else :
204
203
load_var_map [new_var .name ] = new_var
205
204
206
- if load_file_name is not None :
205
+ if filename is not None :
207
206
load_var_list = []
208
207
for name in sorted (load_var_map .keys ()):
209
208
load_var_list .append (load_var_map [name ])
@@ -212,12 +211,12 @@ def load_vars(executor,
212
211
type = 'load_combine' ,
213
212
inputs = {},
214
213
outputs = {"Out" : load_var_list },
215
- attrs = {'file_path' : os .path .join (dirname , load_file_name )})
214
+ attrs = {'file_path' : os .path .join (dirname , filename )})
216
215
217
216
executor .run (load_prog )
218
217
219
218
220
- def load_params (executor , dirname , main_program = None , load_file_name = None ):
219
+ def load_params (executor , dirname , main_program = None , filename = None ):
221
220
"""
222
221
load all parameters from directory by executor.
223
222
"""
@@ -226,11 +225,10 @@ def load_params(executor, dirname, main_program=None, load_file_name=None):
226
225
dirname = dirname ,
227
226
main_program = main_program ,
228
227
predicate = is_parameter ,
229
- load_file_name = load_file_name )
228
+ filename = filename )
230
229
231
230
232
- def load_persistables (executor , dirname , main_program = None ,
233
- load_file_name = None ):
231
+ def load_persistables (executor , dirname , main_program = None , filename = None ):
234
232
"""
235
233
load all persistables from directory by executor.
236
234
"""
@@ -239,7 +237,7 @@ def load_persistables(executor, dirname, main_program=None,
239
237
dirname = dirname ,
240
238
main_program = main_program ,
241
239
predicate = is_persistable ,
242
- load_file_name = load_file_name )
240
+ filename = filename )
243
241
244
242
245
243
def get_inference_program (target_vars , main_program = None ):
@@ -299,7 +297,8 @@ def save_inference_model(dirname,
299
297
target_vars ,
300
298
executor ,
301
299
main_program = None ,
302
- save_file_name = None ):
300
+ model_filename = None ,
301
+ params_filename = None ):
303
302
"""
304
303
Build a model especially for inference,
305
304
and save it to directory by the executor.
@@ -310,8 +309,11 @@ def save_inference_model(dirname,
310
309
:param executor: executor that save inference model
311
310
:param main_program: original program, which will be pruned to build the inference model.
312
311
Default default_main_program().
313
- :param save_file_name: The name of a single file that all parameters are saved to.
314
- If it is None, save parameters to separate files.
312
+ :param model_filename: The name of file to save inference program.
313
+ If not specified, default filename `__model__` will be used.
314
+ :param params_filename: The name of file to save parameters.
315
+ It is used for the case that all parameters are saved in a single binary file.
316
+ If not specified, parameters are considered saved in separate files.
315
317
316
318
:return: None
317
319
"""
@@ -342,15 +344,19 @@ def save_inference_model(dirname,
342
344
prepend_feed_ops (inference_program , feeded_var_names )
343
345
append_fetch_ops (inference_program , fetch_var_names )
344
346
345
- if save_file_name == None :
346
- model_file_name = dirname + "/__model__"
347
+ if model_filename is not None :
348
+ model_filename = os . path . basename ( model_filename )
347
349
else :
348
- model_file_name = dirname + "/__model_combined__"
350
+ model_filename = "__model__"
351
+ model_filename = os .path .join (dirname , model_filename )
349
352
350
- with open (model_file_name , "wb" ) as f :
353
+ if params_filename is not None :
354
+ params_filename = os .path .basename (params_filename )
355
+
356
+ with open (model_filename , "wb" ) as f :
351
357
f .write (inference_program .desc .serialize_to_string ())
352
358
353
- save_persistables (executor , dirname , inference_program , save_file_name )
359
+ save_persistables (executor , dirname , inference_program , params_filename )
354
360
355
361
356
362
def get_feed_targets_names (program ):
@@ -371,15 +377,21 @@ def get_fetch_targets_names(program):
371
377
return fetch_targets_names
372
378
373
379
374
- def load_inference_model (dirname , executor , load_file_name = None ):
380
+ def load_inference_model (dirname ,
381
+ executor ,
382
+ model_filename = None ,
383
+ params_filename = None ):
375
384
"""
376
385
Load inference model from a directory
377
386
378
387
:param dirname: directory path
379
388
:param executor: executor that load inference model
380
- :param load_file_name: The name of the single file that all parameters are loaded from.
381
- If it is None, load parameters from separate files.
382
-
389
+ :param model_filename: The name of file to load inference program.
390
+ If not specified, default filename `__model__` will be used.
391
+ :param params_filename: The name of file to load parameters.
392
+ It is used for the case that all parameters are saved in a single binary file.
393
+ If not specified, parameters are considered saved in separate files.
394
+
383
395
:return: [program, feed_target_names, fetch_targets]
384
396
program: program especially for inference.
385
397
feed_target_names: Names of variables that need to feed data
@@ -388,16 +400,20 @@ def load_inference_model(dirname, executor, load_file_name=None):
388
400
if not os .path .isdir (dirname ):
389
401
raise ValueError ("There is no directory named '%s'" , dirname )
390
402
391
- if load_file_name == None :
392
- model_file_name = dirname + "/__model__"
403
+ if model_filename is not None :
404
+ model_filename = os . path . basename ( model_filename )
393
405
else :
394
- model_file_name = dirname + "/__model_combined__"
406
+ model_filename = "__model__"
407
+ model_filename = os .path .join (dirname , model_filename )
408
+
409
+ if params_filename is not None :
410
+ params_filename = os .path .basename (params_filename )
395
411
396
- with open (model_file_name , "rb" ) as f :
412
+ with open (model_filename , "rb" ) as f :
397
413
program_desc_str = f .read ()
398
414
399
415
program = Program .parse_from_string (program_desc_str )
400
- load_persistables (executor , dirname , program , load_file_name )
416
+ load_persistables (executor , dirname , program , params_filename )
401
417
402
418
feed_target_names = get_feed_targets_names (program )
403
419
fetch_target_names = get_fetch_targets_names (program )
0 commit comments