import time
import logging

+ from functools import reduce  # reduce is not a builtin in Python 3; needed below
+ import paddle
from paddle.fluid import core
from paddle.fluid import io
from paddle.fluid import Program
@@ -84,8 +85,9 @@ def convert_dist_to_sparse_program(program):
    When we train a model with a distributed lookup table but want to do local inference, we can use
    this function to convert the train program with a distributed lookup table into one with a sparse lookup table.

-     :param program(Program): the program must be the trainer program, which will be get by the distribute transpiler.
-     :return:
+     Args:
+         program(Program): the program must be the trainer program, which is obtained from the distribute transpiler.
+     Returns:
+         program(Program): the converted program, in which the distributed lookup table is replaced by a sparse lookup table.
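+
+     Examples:
+         A minimal usage sketch, assuming the trainer program comes from a
+         DistributeTranspiler configured with a distributed lookup table
+         (endpoints and table setup are illustrative, not part of this patch):
+
+         .. code-block:: python
+
+             import paddle.fluid as fluid
+
+             t = fluid.DistributeTranspiler()
+             t.transpile(trainer_id=0, pservers="127.0.0.1:6174", trainers=1)
+             trainer_program = t.get_trainer_program()
+             local_program = convert_dist_to_sparse_program(trainer_program)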
"""
91
93
if not program ._distributed_lookup_table :
@@ -128,68 +130,92 @@ def convert_dist_to_sparse_program(program):
    return program


- def _load_persistable_vars(executor, dirname, program, lookup_table_vars):
-     def _is_checkpoint_var(exclude_fluid_vars=None):
-         """
-         the checkpoint will not save or load all the variables.
-         var type is FEED_MINIBATCH/FETCH_LIST/RAW or var name ends with @GRAD are discarded.
-
-         : param var(Variable)
-         """
-
-         if exclude_fluid_vars is None:
-             exclude_fluid_vars = []
-
-         def is_valid(var):
-             if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
-                     var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
-                     var.desc.type() == core.VarDesc.VarType.RAW:
-                 return False
-             # @GRAD are named for gradient variables, checkpoint will not save it.
-             if "@GRAD" in var.name:
-                 return False
-             # .trainer_ are named for distribute train variables, checkpoint will not save it.
-             if ".trainer_" in var.name:
-                 return False
-
-             # .block is named for distribute train variables, checkpoint will not save it.
-             if ".block" in var.name:
-                 return False
-
-             if "tmp_" in var.name:
-                 return False
-
-             if var.name in exclude_fluid_vars:
-                 return False
-
-             return var.persistable
-
-         return is_valid
-
-     io.load_vars(
-         executor,
-         dirname=dirname,
-         main_program=program,
-         predicate=_is_checkpoint_var(lookup_table_vars),
-         filename=None)
-
-
def load_persistables_for_increment(dirname, executor, program,
                                    lookup_table_var, lookup_table_var_path):
    """
    WARNING: this function will only be used for distributed training with distributed lookup table.
    for incremental training, the pserver will not only load dense variables,
-     but also load the suitable lookup table var. Because of slice lookup table
-     var with HASH, we must load the correct slice var.
+     but also load the suitable lookup table var. Because the lookup table
+     var is sliced with HASH, we must load the correct sliced var.
+
+     Args:
+         dirname(str): The directory path.
+         executor(Executor): The executor to run for loading the model.
+         program(Program): The parameter server program, which will run on the pserver.
+         lookup_table_var(str): the distributed lookup table var name.
+         lookup_table_var_path(str): the location of the distributed lookup table var.
+
+     Returns:
+         None
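+
+     Examples:
+         A minimal sketch, assuming `pserver_program` was produced by the
+         distribute transpiler and a lookup table slice was saved under
+         `lookup_table_var_path` (all names are illustrative):
+
+         .. code-block:: python
+
+             import paddle.fluid as fluid
+
+             exe = fluid.Executor(fluid.CPUPlace())
+             load_persistables_for_increment(
+                 "./ckpt", exe, pserver_program,
+                 "emb_table", "./ckpt/emb_table")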
+     """

+     def _load_persistable_vars(executor, dirname, need_load_vars):
+         load_prog = Program()
+         load_block = load_prog.global_block()
+         need_delete_vars = []
+
+         for param in need_load_vars:
+             origin_var = param.origin
+             slice_var = param.slice
+             is_slice = param.is_slice
+             offset = param.offset
+
+             if is_slice:
+                 # load the full origin var from disk first, under a temporary name
+                 origin = load_block.create_var(
+                     name="{}.load".format(origin_var.name),
+                     type=origin_var.type,
+                     shape=origin_var.shape,
+                     dtype=origin_var.dtype,
+                     persistable=True)
+
+                 load_block.append_op(
+                     type='load',
+                     inputs={},
+                     outputs={'Out': [origin]},
+                     attrs={
+                         'file_path': os.path.join(dirname, origin_var.name)
+                     })
+
+                 slice = load_block.create_var(
+                     name=slice_var.name,
+                     type=slice_var.type,
+                     shape=slice_var.shape,
+                     dtype=slice_var.dtype,
+                     persistable=True)
+
+                 # offset is counted in elements, so divide by the flattened size
+                 # of one row to get this pserver's first row; e.g. an origin of
+                 # shape (10, 8) with offset 40 gives start = 40 / 8 = 5
+                 dim1_flatten = reduce(lambda x, y: x * y, slice.shape[1:])
+                 start = int(offset / dim1_flatten)
+                 end = int(offset / dim1_flatten + slice.shape[0])
+
+                 # cut rows [start, end) out of the full var as this pserver's slice
+                 load_block.append_op(
+                     type="slice",
+                     inputs={'Input': origin},
+                     outputs={'Out': slice},
+                     attrs={'axes': [0],
+                            'starts': [start],
+                            'ends': [end]})
+
+                 need_delete_vars.append(origin)
+             else:
+                 origin = load_block.create_var(
+                     name="{}".format(origin_var.name),
+                     type=origin_var.type,
+                     shape=origin_var.shape,
+                     dtype=origin_var.dtype,
+                     persistable=True)
+                 load_block.append_op(
+                     type='load',
+                     inputs={},
+                     outputs={'Out': [origin]},
+                     attrs={
+                         'file_path': os.path.join(dirname, origin_var.name)
+                     })

-     :param dirname(str): The directory path
-     :param executor(Executor): The executor to run for loading inference model.
-     :param program(Program): The parameter server program, which will run on Pserver.
-     :param lookup_table_var: the distributed lookup tables var name.
-     :param lookup_table_var_path: the the distributed lookup tables var location.
-     :return: None
-     """
+         # drop the temporary full vars once their slices have been cut out
+         load_block.append_op(
+             type='delete_var',
+             inputs={'X': need_delete_vars}, )
+
+         executor.run(load_prog)

    def __load_lookup_table_vars(executor, main_program, lookup_table_var,
                                 lookup_table_var_path):
@@ -217,7 +243,9 @@ def __load_lookup_table_vars(executor, main_program, lookup_table_var,
                 "Distributed Lookup Table Vars from {}, time = {}".format(
                     dirname, time.ctime()))

-     _load_persistable_vars(executor, dirname, program, [lookup_table_var])
+     # load only the vars assigned to this pserver endpoint
+     need_load_vars = program._parameters_on_pservers.get_distributed_vars_by_ep(
+         program._ps_endpoint)
+     _load_persistable_vars(executor, dirname, need_load_vars)
    __load_lookup_table_vars(executor, program, lookup_table_var,
                             lookup_table_var_path)
@@ -233,15 +261,62 @@ def load_persistables_for_inference(dirname, executor, program,
    Inference with a distributed lookup table is a little tricky: this function loads the distributed
    lookup table vars into a sparse var, so the model can be used in local inference mode.

-     :param dirname(str): The directory path
-     :param executor(Executor): The executor to run for loading inference model.
-     :param program(Program): The parameter server program, which will run on Pserver.
-     :param lookup_table_var_name: the distributed lookup tables var name.
-     :return: None
+     Args:
+         dirname(str): The directory path.
+         executor(Executor): The executor to run for loading the inference model.
+         program(Program): The parameter server program, which will run on the pserver.
+         lookup_table_var_name(str): the distributed lookup table var name.
+
+     Returns:
+         None
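+
+     Examples:
+         A minimal sketch, assuming `pserver_program` holds the distributed
+         lookup table named "emb_table" (names are illustrative):
+
+         .. code-block:: python
+
+             import paddle.fluid as fluid
+
+             exe = fluid.Executor(fluid.CPUPlace())
+             infer_program = load_persistables_for_inference(
+                 "./ckpt", exe, pserver_program, "emb_table")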
    """

-     def __load_lookup_table_vars(executor, dirname, main_program,
-                                  lookup_table_vars):
+     def _load_persistable_vars(executor, dirname, program, lookup_table_vars):
+         def _is_checkpoint_var(exclude_fluid_vars=None):
+             """
+             the checkpoint will not save or load all the variables.
+             vars of type FEED_MINIBATCH/FETCH_LIST/RAW, or whose name contains @GRAD, are discarded.
+
+             :param var(Variable)
+             """
+
+             if exclude_fluid_vars is None:
+                 exclude_fluid_vars = []
+
+             def is_valid(var):
+                 if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
+                         var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
+                         var.desc.type() == core.VarDesc.VarType.RAW:
+                     return False
+                 # @GRAD vars are gradients; the checkpoint will not save them.
+                 if "@GRAD" in var.name:
+                     return False
+                 # .trainer_ vars belong to distributed training; the checkpoint will not save them.
+                 if ".trainer_" in var.name:
+                     return False
+
+                 # .block vars belong to distributed training; the checkpoint will not save them.
+                 if ".block" in var.name:
+                     return False
+
+                 if "tmp_" in var.name:
+                     return False
+
+                 if var.name in exclude_fluid_vars:
+                     return False
+
+                 return var.persistable
+
+             return is_valid
+
+         io.load_vars(
+             executor,
+             dirname=dirname,
+             main_program=program,
+             predicate=_is_checkpoint_var(lookup_table_vars),
+             filename=None)
+
+     def _load_lookup_table_vars(executor, dirname, main_program,
+                                 lookup_table_vars):
        if not os.path.isdir(dirname):
            raise ValueError("There is no directory named '%s'" % dirname)
@@ -313,11 +388,96 @@ def __load_lookup_table_vars(executor, dirname, main_program,
                     dirname, time.ctime()))

    _load_persistable_vars(executor, dirname, program, [lookup_table_var_name])
-     __load_lookup_table_vars(executor, dirname, program,
-                              [lookup_table_var_name])
+     _load_lookup_table_vars(executor, dirname, program, [lookup_table_var_name])

    _logger.info("Finish Load Sparse Program With "
                 "Distributed Lookup Table Vars from {}, time = {}".format(
                     dirname, time.ctime()))

    return program
+
+
+ def get_inference_model(main_program, feeded_var_names, target_vars):
+     """
+     Prune the given `main_program` to build a new program especially for inference with a distributed lookup table,
+     and then add `feeded_vars` and `target_vars` to this program.
+
+     Args:
+         main_program(Program|None): The original program, which will be pruned to
+                                     build the inference model. If set to None,
+                                     the default main program will be used.
+                                     Default: None.
+         feeded_var_names(list[str]): Names of variables that need to be fed data
+                                      during inference.
+         target_vars(list[Variable]): Variables from which we can get inference
+                                      results.
+
+     Returns:
+         program(Program)
+
+     Raises:
+         ValueError: If `feeded_var_names` is not a list of basestring.
+         ValueError: If `target_vars` is not a list of Variable.
+
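+     Examples:
+         A minimal sketch, assuming `predict` is a target variable from a network
+         defined on the default main program (names are illustrative):
+
+         .. code-block:: python
+
+             import paddle.fluid as fluid
+
+             main_prog = fluid.default_main_program()
+             inference_program = get_inference_model(
+                 main_prog, feeded_var_names=['image'], target_vars=[predict])
+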
+     """
+
+     def prepend_feed_ops(inference_program,
+                          feed_target_names,
+                          feed_holder_name='feed'):
+         if len(feed_target_names) == 0:
+             return
+
+         global_block = inference_program.global_block()
+
+         # one FEED_MINIBATCH holder var feeds every target, one column per name
+         feed_var = global_block.create_var(
+             name=feed_holder_name,
+             type=core.VarDesc.VarType.FEED_MINIBATCH,
+             persistable=True)
+
+         for i, name in enumerate(feed_target_names):
+             out = global_block.var(name)
+             global_block._prepend_op(
+                 type='feed',
+                 inputs={'X': [feed_var]},
+                 outputs={'Out': [out]},
+                 attrs={'col': i})
+
+     def append_fetch_ops(inference_program,
+                          fetch_target_names,
+                          fetch_holder_name='fetch'):
+         global_block = inference_program.global_block()
+         # symmetric to the feed holder: one FETCH_LIST var collects all results
+         fetch_var = global_block.create_var(
+             name=fetch_holder_name,
+             type=core.VarDesc.VarType.FETCH_LIST,
+             persistable=True)
+
+         for i, name in enumerate(fetch_target_names):
+             global_block.append_op(
+                 type='fetch',
+                 inputs={'X': [name]},
+                 outputs={'Out': [fetch_var]},
+                 attrs={'col': i})
+
+     origin_program = main_program.clone()
+     # work on a clone so the caller's program is left unmodified
+     main_program = main_program.clone()
+     global_block = main_program.global_block()
+
+     # clear any feed/fetch ops left over from a previous setup
+     need_to_remove_op_index = []
+     for i, op in enumerate(global_block.ops):
+         op.desc.set_is_target(False)
+         if op.type == "feed" or op.type == "fetch":
+             need_to_remove_op_index.append(i)
+
+     for index in need_to_remove_op_index[::-1]:
+         global_block._remove_op(index)
+
+     main_program.desc.flush()
+
+     # keep only the ops needed to compute target_vars, then strip
+     # training-only constructs such as read ops
+     main_program = main_program._prune(targets=target_vars)
+     main_program = main_program._inference_optimize(prune_read_op=True)
+
+     fetch_var_names = [v.name for v in target_vars]
+
+     prepend_feed_ops(main_program, feeded_var_names)
+     append_fetch_ops(main_program, fetch_var_names)
+
+     return main_program