@@ -37,7 +37,7 @@ def __init__(self, base_path, model, model_opt, fields, optim,
         if keep_checkpoint > 0:
             self.checkpoint_queue = deque([], maxlen=keep_checkpoint)
 
-    def save(self, step, moving_average=None):
+    def save(self, step, is_best, moving_average=None):
         """Main entry point for model saver
 
         It wraps the `_save` method with checks and apply `keep_checkpoint`
@@ -54,7 +54,7 @@ def save(self, step, moving_average=None):
                 model_params_data.append(param.data)
                 param.data = avg.data
 
-        chkpt, chkpt_name = self._save(step, save_model)
+        chkpt, chkpt_name = self._save(step, is_best, save_model)
         self.last_saved_step = step
 
         if moving_average:
@@ -97,7 +97,7 @@ def _rm_checkpoint(self, name):
 class ModelSaver(ModelSaverBase):
     """Simple model saver to filesystem"""
 
-    def _save(self, step, model):
+    def _save(self, step, is_best, model):
         model_state_dict = model.state_dict()
         model_state_dict = {k: v for k, v in model_state_dict.items()
                             if 'generator' not in k}
@@ -128,6 +128,10 @@ def _save(self, step, model):
         logger.info("Saving checkpoint %s_step_%d.pt" % (self.base_path, step))
         checkpoint_path = '%s_step_%d.pt' % (self.base_path, step)
         torch.save(checkpoint, checkpoint_path)
+        if is_best:
+            logger.info("Obtained best checkpoint, saving...")
+            best_checkpoint_path = '%s_best.pt' % (self.base_path)
+            torch.save(checkpoint, best_checkpoint_path)
         return checkpoint, checkpoint_path
 
     def _rm_checkpoint(self, name):
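
For context, a minimal caller-side sketch of how a training loop might drive the new `is_best` flag. The `model_saver` instance, `validate()` helper, and the `train_steps`/`valid_steps` counters here are hypothetical placeholders, not part of this change; only the `save(step, is_best)` signature comes from the diff above.

    # Hypothetical sketch: `model_saver` is a ModelSaver instance and
    # validate() returns a validation perplexity; both are placeholders.
    best_ppl = float('inf')
    for step in range(1, train_steps + 1):
        if step % valid_steps == 0:
            valid_ppl = validate()
            is_best = valid_ppl < best_ppl  # lower perplexity is better
            if is_best:
                best_ppl = valid_ppl
            # With this change, save() also writes <base_path>_best.pt
            # whenever is_best is True.
            model_saver.save(step, is_best)

Note that the best checkpoint is written to the fixed path `<base_path>_best.pt`, so it is overwritten in place; because `_save` still returns only the step-numbered `checkpoint_path`, the best file never enters `checkpoint_queue` and is not deleted by the `keep_checkpoint` rotation.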