@@ -188,7 +188,7 @@ def __init__(self,
188
188
if not self .checkpoint .is_pserver :
189
189
epoch_id , step_id = io .load_trainer_args (
190
190
self .checkpoint .checkpoint_dir , self .checkpoint .load_serial ,
191
- self .trainer_id , [ "epoch_id" , "step_id" ] )
191
+ self .trainer_id , self . _get_checkpoint_load_args () )
192
192
self .checkpoint .epoch_id = int (epoch_id )
193
193
self .checkpoint .step_id = int (step_id )
194
194
@@ -432,22 +432,33 @@ def _clean_checkpoint(self):
432
432
return
433
433
io .clean_checkpoint (checkpoint_dir = self .checkpoint .checkpoint_dir )
434
434
435
+ def _get_checkpoint_load_args (self ):
436
+ """
437
+ epoch_id and step_id are runtime arguments, they are not variables, will load them independently.
438
+ """
439
+ return ["epoch_id" , "step_id" ]
440
+
441
+ def _get_checkpoint_save_args (self , epoch_id , step_id ):
442
+ """
443
+ epoch_id and step_id are runtime arguments, they are not variables, will save them independently.
444
+ """
445
+ trainer_args = {}
446
+ trainer_args ["epoch_id" ] = epoch_id
447
+ trainer_args ["step_id" ] = step_id
448
+ return trainer_args
449
+
435
450
def _save_checkpoint (self , epoch_id , step_id ):
436
451
if not self .checkpoint :
437
452
return
438
453
439
454
if epoch_id % self .checkpoint .epoch_interval == 0 and step_id % self .checkpoint .step_interval == 0 :
440
- trainer_args = {}
441
- trainer_args ["epoch_id" ] = epoch_id
442
- trainer_args ["step_id" ] = step_id
443
-
444
455
exe = executor .Executor (self .place )
445
456
io .save_checkpoint (
446
457
executor = exe ,
447
458
checkpoint_dir = self .checkpoint .checkpoint_dir ,
448
459
trainer_id = self .trainer_id ,
449
460
is_chief = self .chief ,
450
- trainer_args = trainer_args ,
461
+ trainer_args = self . _get_checkpoint_save_args ( epoch_id , step_id ) ,
451
462
main_program = self .train_program ,
452
463
max_num_checkpoints = self .checkpoint .max_num_checkpoints )
453
464
0 commit comments