@@ -100,6 +100,7 @@ def _executor_dataset_train(self, model_dict, context):
100100 fetch_period = int (
101101 envs .get_global_env ("runner." + context ["runner_name" ] +
102102 ".print_interval" , 20 ))
103+
103104 scope = context ["model" ][model_name ]["scope" ]
104105 program = context ["model" ][model_name ]["main_program" ]
105106 reader = context ["dataset" ][reader_name ]
@@ -139,6 +140,9 @@ def _executor_dataloader_train(self, model_dict, context):
139140 fetch_period = int (
140141 envs .get_global_env ("runner." + context ["runner_name" ] +
141142 ".print_interval" , 20 ))
143+ save_step_interval = int (
144+ envs .get_global_env ("runner." + context ["runner_name" ] +
145+ ".save_step_interval" , - 1 ))
142146 if context ["is_infer" ]:
143147 metrics = model_class .get_infer_results ()
144148 else :
@@ -202,6 +206,24 @@ def _executor_dataloader_train(self, model_dict, context):
202206 metrics_logging .insert (1 , seconds )
203207 begin_time = end_time
204208 logging .info (metrics_format .format (* metrics_logging ))
209+
# Step-wise checkpointing: every `save_step_interval` batches during
# training (never during inference), snapshot the model under a
# program guard so save ops attach to the right program.
# NOTE(review): `batch_id % save_step_interval == 0` is also true at
# batch_id == 0, so a checkpoint is written on the very first batch,
# before any optimization step — confirm this is intended.
if (save_step_interval >= 1 and batch_id % save_step_interval == 0
        and not context["is_infer"]):
    if context["fleet_mode"].upper() == "PS":
        # Parameter-server mode trains the transpiled main program.
        train_prog = context["model"][model_dict["name"]]["main_program"]
    else:
        train_prog = context["model"][model_dict["name"]][
            "default_main_program"]
    startup_prog = context["model"][model_dict["name"]]["startup_program"]
    with fluid.program_guard(train_prog, startup_prog):
        # epoch_id=None / batch_id=<int> selects the step-save path in save().
        self.save(
            context,
            is_fleet=context["is_fleet"],
            epoch_id=None,
            batch_id=batch_id)
226+
205227 batch_id += 1
206228 except fluid .core .EOFException :
207229 reader .reset ()
@@ -314,7 +336,7 @@ def _get_ps_program(self, model_dict, context):
314336 exec_strategy = _exe_strategy )
315337 return program
316338
317- def save (self , epoch_id , context , is_fleet = False ):
339+ def save (self , context , is_fleet = False , epoch_id = None , batch_id = None ):
318340 def need_save (epoch_id , epoch_interval , is_last = False ):
319341 name = "runner." + context ["runner_name" ] + "."
320342 total_epoch = int (envs .get_global_env (name + "epochs" , 1 ))
@@ -371,7 +393,8 @@ def save_inference_model():
371393
372394 assert dirname is not None
373395 dirname = os .path .join (dirname , str (epoch_id ))
374-
396+ logging .info ("\t save epoch_id:%d model into: \" %s\" " %
397+ (epoch_id , dirname ))
375398 if is_fleet :
376399 warnings .warn (
377400 "Save inference model in cluster training is not recommended! Using save checkpoint instead." ,
@@ -394,14 +417,35 @@ def save_persistables():
394417 if dirname is None or dirname == "" :
395418 return
396419 dirname = os .path .join (dirname , str (epoch_id ))
420+ logging .info ("\t save epoch_id:%d model into: \" %s\" " %
421+ (epoch_id , dirname ))
422+ if is_fleet :
423+ if context ["fleet" ].worker_index () == 0 :
424+ context ["fleet" ].save_persistables (context ["exe" ], dirname )
425+ else :
426+ fluid .io .save_persistables (context ["exe" ], dirname )
427+
def save_checkpoint_step():
    """Persist model variables for the current training step.

    Writes into ``<save_step_path>/<batch_id>``; silently returns when
    the runner config defines no ``save_step_path``. In fleet mode only
    worker 0 saves; otherwise ``fluid.io.save_persistables`` is used.
    Reads ``context``, ``batch_id`` and ``is_fleet`` from the enclosing
    ``save()`` scope.
    """
    name = "runner." + context["runner_name"] + "."
    # NOTE(review): the previous code also fetched "save_step_interval"
    # here into an unused local — the interval is enforced by the
    # training loop before save() is called, so that read was dead
    # code and has been removed.
    dirname = envs.get_global_env(name + "save_step_path", None)
    if dirname is None or dirname == "":
        # Step-wise checkpointing is not configured for this runner.
        return
    dirname = os.path.join(dirname, str(batch_id))
    logging.info("\tsave batch_id:%d model into: \"%s\"" %
                 (batch_id, dirname))
    if is_fleet:
        # Distributed training: only the 0th worker writes checkpoints.
        if context["fleet"].worker_index() == 0:
            context["fleet"].save_persistables(context["exe"], dirname)
    else:
        fluid.io.save_persistables(context["exe"], dirname)
402443
403- save_persistables ()
404- save_inference_model ()
444+ if isinstance (epoch_id , int ):
445+ save_persistables ()
446+ save_inference_model ()
447+ if isinstance (batch_id , int ):
448+ save_checkpoint_step ()
405449
406450
407451class SingleRunner (RunnerBase ):
@@ -453,7 +497,7 @@ def run(self, context):
453497 startup_prog = context ["model" ][model_dict ["name" ]][
454498 "startup_program" ]
455499 with fluid .program_guard (train_prog , startup_prog ):
456- self .save (epoch , context )
500+ self .save (context = context , epoch_id = epoch )
457501 context ["status" ] = "terminal_pass"
458502
459503
@@ -506,7 +550,7 @@ def run(self, context):
506550 startup_prog = context ["model" ][model_dict ["name" ]][
507551 "startup_program" ]
508552 with fluid .program_guard (train_prog , startup_prog ):
509- self .save (epoch , context , True )
553+ self .save (context = context , is_fleet = True , epoch_id = epoch )
510554 context ["status" ] = "terminal_pass"
511555
512556
@@ -539,7 +583,7 @@ def run(self, context):
539583 startup_prog = context ["model" ][model_dict ["name" ]][
540584 "startup_program" ]
541585 with fluid .program_guard (train_prog , startup_prog ):
542- self .save (epoch , context , True )
586+ self .save (context = context , is_fleet = True , epoch_id = epoch )
543587 context ["status" ] = "terminal_pass"
544588
545589
0 commit comments