@@ -50,6 +50,7 @@ def _executor_dataset_train(self, model_dict, context):
         reader_name = model_dict["dataset_name"]
         model_name = model_dict["name"]
         model_class = context["model"][model_dict["name"]]["model"]
+
         fetch_vars = []
         fetch_alias = []
         fetch_period = int(
@@ -89,19 +90,7 @@ def _executor_dataset_train(self, model_dict, context):
     def _executor_dataloader_train(self, model_dict, context):
         model_name = model_dict["name"]
         model_class = context["model"][model_dict["name"]]["model"]
-
-        if context["is_infer"]:
-            program = context["model"][model_name]["main_program"]
-        elif context["is_fleet"]:
-            if context["fleet_mode"].upper() == "PS":
-                program = self._get_ps_program(model_dict, context)
-            elif context["fleet_mode"].upper() == "COLLECTIVE":
-                program = context["model"][model_name]["main_program"]
-        elif not context["is_fleet"]:
-            if context["device"].upper() == "CPU":
-                program = self._get_single_cpu_program(model_dict, context)
-            elif context["device"].upper() == "GPU":
-                program = self._get_single_gpu_program(model_dict, context)
+        program = self._get_dataloader_program(model_dict, context)
 
         reader_name = model_dict["dataset_name"]
         fetch_vars = []
@@ -143,6 +132,24 @@ def _executor_dataloader_train(self, model_dict, context):
             except fluid.core.EOFException:
                 reader.reset()
 
+    def _get_dataloader_program(self, model_dict, context):
+        model_name = model_dict["name"]
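+        # build the program only on the first call, then reuse the cached compiled_program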
+        if context["model"][model_name]["compiled_program"] == None:
+            if context["is_infer"]:
+                program = context["model"][model_name]["main_program"]
+            elif context["is_fleet"]:
+                if context["fleet_mode"].upper() == "PS":
+                    program = self._get_ps_program(model_dict, context)
+                elif context["fleet_mode"].upper() == "COLLECTIVE":
+                    program = context["model"][model_name]["main_program"]
+            elif not context["is_fleet"]:
+                if context["device"].upper() == "CPU":
+                    program = self._get_single_cpu_program(model_dict, context)
+                elif context["device"].upper() == "GPU":
+                    program = self._get_single_gpu_program(model_dict, context)
+            context["model"][model_name]["compiled_program"] = program
+        return context["model"][model_name]["compiled_program"]
+
     def _get_strategy(self, model_dict, context):
         _build_strategy = fluid.BuildStrategy()
         _exe_strategy = fluid.ExecutionStrategy()
@@ -218,12 +225,17 @@ def _get_ps_program(self, model_dict, context):
 
     def save(self, epoch_id, context, is_fleet=False):
         def need_save(epoch_id, epoch_interval, is_last=False):
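+            # the final epoch is always saved; other epochs are saved every epoch_interval epochs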
+            name = "runner." + context["runner_name"] + "."
+            total_epoch = int(envs.get_global_env(name + "epochs", 1))
+            if epoch_id + 1 == total_epoch:
+                is_last = True
+
             if is_last:
                 return True
             if epoch_id == -1:
                 return False
 
-            return epoch_id % epoch_interval == 0
+            return (epoch_id + 1) % epoch_interval == 0
 
         def save_inference_model():
             name = "runner." + context["runner_name"] + "."
@@ -415,3 +427,53 @@ def run(self, context):
 
         """
         context["status"] = "terminal_pass"
+
+
+class SingleInferRunner(RunnerBase):
+    def __init__(self, context):
+        print("Running SingleInferRunner.")
+        pass
+
+    def run(self, context):
+        self._dir_check(context)
+
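+        # run every phase once against each saved epoch model and report the elapsed time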
+        for index, epoch_name in enumerate(self.epoch_model_name_list):
+            for model_dict in context["phases"]:
+                self._load(context, model_dict,
+                           self.epoch_model_path_list[index])
+                begin_time = time.time()
+                self._run(context, model_dict)
+                end_time = time.time()
+                seconds = end_time - begin_time
+                print("Infer {} of {} done, use time: {}".format(model_dict[
+                    "name"], epoch_name, seconds))
+        context["status"] = "terminal_pass"
+
+    def _load(self, context, model_dict, model_path):
+        if model_path is None or model_path == "":
+            return
+        print("load persistables from", model_path)
+
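+        # restore persistable variables from model_path into this model's scope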
+        with fluid.scope_guard(context["model"][model_dict["name"]]["scope"]):
+            train_prog = context["model"][model_dict["name"]]["main_program"]
+            startup_prog = context["model"][model_dict["name"]][
+                "startup_program"]
+            with fluid.program_guard(train_prog, startup_prog):
+                fluid.io.load_persistables(
+                    context["exe"], model_path, main_program=train_prog)
+
+    def _dir_check(self, context):
+        dirname = envs.get_global_env(
+            "runner." + context["runner_name"] + ".init_model_path", None)
+        self.epoch_model_path_list = []
+        self.epoch_model_name_list = []
+
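+        # treat each sub-directory under init_model_path as one epoch's saved model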
+        for file in os.listdir(dirname):
+            file_path = os.path.join(dirname, file)
+            if os.path.isdir(file_path):
+                self.epoch_model_path_list.append(file_path)
+                self.epoch_model_name_list.append(file)
+
+        if len(self.epoch_model_path_list) == 0:
+            self.epoch_model_path_list.append(dirname)
+            self.epoch_model_name_list.append(dirname)