 from mmcv.runner.dist_utils import master_only
 from mmcv.runner.hooks import CheckpointHook as _CheckpointHook
 from ray.tune.integration.torch import distributed_checkpoint_dir
+from torch.optim import Optimizer


 @HOOKS.register_module()
@@ -85,8 +86,13 @@ def _save_checkpoint(self, runner: BaseRunner) -> None:
                 The runner to save checkpoints.
         """
         model = runner.model
+        optimizer = runner.optimizer

-        meta = dict(mmcv_version=mmcv.__version__, time=time.asctime())
+        meta = dict(
+            mmcv_version=mmcv.__version__,
+            time=time.asctime(),
+            epoch=runner.epoch + 1,
+            iter=runner.iter)
         if is_module_wrapper(model):
             model = model.module
         if hasattr(model, 'CLASSES') and model.CLASSES is not None:
@@ -97,6 +103,13 @@ def _save_checkpoint(self, runner: BaseRunner) -> None:
             'state_dict': weights_to_cpu(get_state_dict(model))
         }

+        if isinstance(optimizer, Optimizer):
+            checkpoint['optimizer'] = optimizer.state_dict()
+        elif isinstance(optimizer, dict):
+            checkpoint['optimizer'] = {}
+            for name, optim in optimizer.items():
+                checkpoint['optimizer'][name] = optim.state_dict()
+
         with distributed_checkpoint_dir(
                 step=self.get_iter(runner)) as checkpoint_dir:
             path = os.path.join(checkpoint_dir, 'ray_checkpoint.pth')
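
For reference, a minimal sketch of the restore side of this change. The load_ray_checkpoint helper below is hypothetical (not part of this diff); it mirrors the Optimizer-vs-dict branches added above when reloading 'optimizer' from the saved ray_checkpoint.pth, and returns the 'meta' block (epoch, iter) that this change starts recording so a trial can resume where it left off. Only torch.load and load_state_dict are assumed from PyTorch; everything else is illustrative.

import torch
from torch.optim import Optimizer


def load_ray_checkpoint(path, model, optimizer=None):
    """Hypothetical helper: restore a checkpoint written by the hook above.

    ``checkpoint['optimizer']`` is either a single optimizer state dict or a
    ``{name: state_dict}`` mapping, matching the ``Optimizer`` / ``dict``
    branches in ``_save_checkpoint``.
    """
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    if optimizer is not None and 'optimizer' in checkpoint:
        if isinstance(optimizer, Optimizer):
            # Single optimizer: one state dict was saved.
            optimizer.load_state_dict(checkpoint['optimizer'])
        elif isinstance(optimizer, dict):
            # Multiple named optimizers: restore each by name.
            for name, optim in optimizer.items():
                optim.load_state_dict(checkpoint['optimizer'][name])
    # ``meta`` carries the resume point saved by the hook (epoch, iter).
    return checkpoint.get('meta', {})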