@@ -100,6 +100,7 @@ def _executor_dataset_train(self, model_dict, context):
100100 fetch_period = int(
101101 envs.get_global_env("runner." + context["runner_name"] +
102102 ".print_interval", 20))
103+
103104 scope = context["model"][model_name]["scope"]
104105 program = context["model"][model_name]["main_program"]
105106 reader = context["dataset"][reader_name]
@@ -139,6 +140,9 @@ def _executor_dataloader_train(self, model_dict, context):
139140 fetch_period = int(
140141 envs.get_global_env("runner." + context["runner_name"] +
141142 ".print_interval", 20))
143+ save_step_interval = int(
144+ envs.get_global_env("runner." + context["runner_name"] +
145+ ".save_step_interval", -1))
142146 if context["is_infer"]:
143147 metrics = model_class.get_infer_results()
144148 else:
@@ -202,6 +206,25 @@ def _executor_dataloader_train(self, model_dict, context):
202206 metrics_logging = metrics.insert(1, seconds)
203207 begin_time = end_time
204208 logging.info(metrics_format.format(*metrics))
209+
210+ if save_step_interval >= 1 and batch_id % save_step_interval == 0 and context[
211+ "is_infer"] == False:
212+ if context["fleet_mode"]:
213+ if context["fleet_mode"].upper() == "PS":
214+ train_prog = context["model"][model_dict[
215+ "name"]]["main_program"]
216+ elif not context["is_fleet"] or context[
217+ "fleet_mode"].upper() == "COLLECTIVE":
218+ train_prog = context["model"][model_dict["name"]][
219+ "default_main_program"]
220+ startup_prog = context["model"][model_dict["name"]][
221+ "startup_program"]
222+ with fluid.program_guard(train_prog, startup_prog):
223+ self.save(
224+ context,
225+ is_fleet=context["is_fleet"],
226+ epoch_id=None,
227+ batch_id=batch_id)
205228 batch_id += 1
206229 except fluid.core.EOFException:
207230 reader.reset()
@@ -314,7 +337,7 @@ def _get_ps_program(self, model_dict, context):
314337 exec_strategy=_exe_strategy)
315338 return program
316339
317- def save(self, epoch_id, context, is_fleet=False):
340+ def save(self, context, is_fleet=False, epoch_id=None, batch_id=None):
318341 def need_save(epoch_id, epoch_interval, is_last=False):
319342 name = "runner." + context["runner_name"] + "."
320343 total_epoch = int(envs.get_global_env(name + "epochs", 1))
@@ -371,7 +394,8 @@ def save_inference_model():
371394
372395 assert dirname is not None
373396 dirname = os.path.join(dirname, str(epoch_id))
374-
397+ logging.info("\tsave epoch_id:%d model into: \"%s\"" %
398+ (epoch_id, dirname))
375399 if is_fleet:
376400 warnings.warn(
377401 "Save inference model in cluster training is not recommended! Using save checkpoint instead.",
@@ -394,14 +418,35 @@ def save_persistables():
394418 if dirname is None or dirname == "":
395419 return
396420 dirname = os.path.join(dirname, str(epoch_id))
421+ logging.info("\tsave epoch_id:%d model into: \"%s\"" %
422+ (epoch_id, dirname))
423+ if is_fleet:
424+ if context["fleet"].worker_index() == 0:
425+ context["fleet"].save_persistables(context["exe"], dirname)
426+ else:
427+ fluid.io.save_persistables(context["exe"], dirname)
428+
429+ def save_checkpoint_step():
430+ name = "runner." + context["runner_name"] + "."
431+ save_interval = int(
432+ envs.get_global_env(name + "save_step_interval", -1))
433+ dirname = envs.get_global_env(name + "save_step_path", None)
434+ if dirname is None or dirname == "":
435+ return
436+ dirname = os.path.join(dirname, str(batch_id))
437+ logging.info("\tsave batch_id:%d model into: \"%s\"" %
438+ (batch_id, dirname))
397439 if is_fleet:
398440 if context["fleet"].worker_index() == 0:
399441 context["fleet"].save_persistables(context["exe"], dirname)
400442 else:
401443 fluid.io.save_persistables(context["exe"], dirname)
402444
403- save_persistables()
404- save_inference_model()
445+ if isinstance(epoch_id, int):
446+ save_persistables()
447+ save_inference_model()
448+ if isinstance(batch_id, int):
449+ save_checkpoint_step()
405450
406451
407452 class SingleRunner(RunnerBase):
@@ -453,7 +498,7 @@ def run(self, context):
453498 startup_prog = context["model"][model_dict["name"]][
454499 "startup_program"]
455500 with fluid.program_guard(train_prog, startup_prog):
456- self.save(epoch, context)
501+ self.save(context=context, epoch_id=epoch)
457502 context["status"] = "terminal_pass"
458503
@@ -506,7 +551,7 @@ def run(self, context):
506551 startup_prog = context["model"][model_dict["name"]][
507552 "startup_program"]
508553 with fluid.program_guard(train_prog, startup_prog):
509- self.save(epoch, context, True)
554+ self.save(context=context, is_fleet=True, epoch_id=epoch)
510555 context["status"] = "terminal_pass"
511556
512557
@@ -539,7 +584,7 @@ def run(self, context):
539584 startup_prog = context["model"][model_dict["name"]][
540585 "startup_program"]
541586 with fluid.program_guard(train_prog, startup_prog):
542- self.save(epoch, context, True)
587+ self.save(context=context, is_fleet=True, epoch_id=epoch)
543588 context["status"] = "terminal_pass"
544589
545590