Skip to content

Commit c5dae95

Browse files
authored
Add load rewriter (#43)
* Add load rewriter
* Fix resume
* Fix typo
* Fix test code
1 parent 4d64042 commit c5dae95

File tree

6 files changed

+60
-2
lines changed

6 files changed

+60
-2
lines changed

configs/mmtune/_base_/context/train.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
src_key='searched_cfg',
1313
dst_key='base_cfg',
1414
key='cfg'),
15+
dict(type='ResumeFromCkpt'),
1516
dict(
1617
type='CustomHookRegister',
1718
key='cfg',

mmtune/mm/context/rewriters/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
from .patch import BatchConfigPatcher, SequeunceConfigPatcher
77
from .path import AppendTrialIDtoPath
88
from .register import CustomHookRegister
9+
from .resume import ResumeFromCkpt
910

1011
__all__ = [
1112
'BaseRewriter', 'REWRITERS', 'build_rewriter', 'Dump', 'MergeConfig',
1213
'AppendTrialIDtoPath', 'BatchConfigPatcher', 'SequeunceConfigPatcher',
13-
'CustomHookRegister', 'InstantiateCfg'
14+
'CustomHookRegister', 'InstantiateCfg', 'ResumeFromCkpt'
1415
]
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from typing import Dict

from .base import BaseRewriter
from .builder import REWRITERS


@REWRITERS.register_module()
class ResumeFromCkpt(BaseRewriter):
    """Specify the checkpoint for resuming training.

    Ray Tune places the latest trial checkpoint directory under the
    ``checkpoint_dir`` key of the context; this rewriter moves that path
    onto the argparse namespace so the training entry point resumes from it.
    """

    def __init__(self, arg_name: str = 'resume_from') -> None:
        """Initialize the rewriter.

        Args:
            arg_name (str): Name of the attribute on the argparse
                namespace that receives the checkpoint path.
        """
        self.arg_name = arg_name

    def __call__(self, context: Dict) -> Dict:
        """Set with checkpoints specified by Ray.

        Args:
            context (Dict): The context to be rewritten. Expected to hold
                ``args`` (an argparse namespace); ``checkpoint_dir`` is
                removed from the context if present.

        Returns:
            Dict: The context after rewriting.
        """
        # Default to None so a fresh trial (no checkpoint yet) sets
        # resume_from=None ("do not resume") instead of raising KeyError.
        setattr(
            context.get('args'), self.arg_name,
            context.pop('checkpoint_dir', None))
        return context

mmtune/mm/hooks/checkpoint.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from mmcv.runner.dist_utils import master_only
1111
from mmcv.runner.hooks import CheckpointHook as _CheckpointHook
1212
from ray.tune.integration.torch import distributed_checkpoint_dir
13+
from torch.optim import Optimizer
1314

1415

1516
@HOOKS.register_module()
@@ -85,8 +86,13 @@ def _save_checkpoint(self, runner: BaseRunner) -> None:
8586
The runner to save checkpoints.
8687
"""
8788
model = runner.model
89+
optimizer = runner.optimizer
8890

89-
meta = dict(mmcv_version=mmcv.__version__, time=time.asctime())
91+
meta = dict(
92+
mmcv_version=mmcv.__version__,
93+
time=time.asctime(),
94+
epoch=runner.epoch + 1,
95+
iter=runner.iter)
9096
if is_module_wrapper(model):
9197
model = model.module
9298
if hasattr(model, 'CLASSES') and model.CLASSES is not None:
@@ -97,6 +103,13 @@ def _save_checkpoint(self, runner: BaseRunner) -> None:
97103
'state_dict': weights_to_cpu(get_state_dict(model))
98104
}
99105

106+
if isinstance(optimizer, Optimizer):
107+
checkpoint['optimizer'] = optimizer.state_dict()
108+
elif isinstance(optimizer, dict):
109+
checkpoint['optimizer'] = {}
110+
for name, optim in optimizer.items():
111+
checkpoint['optimizer'][name] = optim.state_dict()
112+
100113
with distributed_checkpoint_dir(
101114
step=self.get_iter(runner)) as checkpoint_dir:
102115
path = os.path.join(checkpoint_dir, 'ray_checkpoint.pth')

tests/test_mm/test_hooks.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,15 @@ def test_raycheckpointhook():
2121
mock_runner = MagicMock()
2222
mock_runner.inner_iter = 3
2323
mock_runner.iter = 5
24+
mock_runner.epoch = 5
2425

2526
cur_iter = hook.get_iter(mock_runner, False)
2627
assert cur_iter == 6
2728
cur_iter = hook.get_iter(mock_runner, True)
2829
assert cur_iter == 4
2930

3031
mock_runner.model = torch.nn.Linear(2, 2)
32+
mock_runner.optimizer = torch.optim.Adam(mock_runner.model.parameters())
3133

3234
hook._save_checkpoint(mock_runner)
3335
assert os.path.exists('ray_checkpoint.pth')

tests/test_mm/test_rewriters.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
BaseRewriter, BatchConfigPatcher,
1010
CustomHookRegister, Dump,
1111
InstantiateCfg, MergeConfig,
12+
ResumeFromCkpt,
1213
SequeunceConfigPatcher)
1314
from mmtune.mm.context.rewriters.builder import build_rewriter
1415
from mmtune.utils import dump_cfg
@@ -116,5 +117,15 @@ def test_register():
116117
cfg = MagicMock()
117118
cfg.custom_hooks = []
118119
context = dict(cfg=cfg)
120+
119121
context = register(context)
120122
assert context['cfg'].custom_hooks == post_custom_hooks
123+
124+
125+
def test_resume_ckpt():
    """ResumeFromCkpt must copy ``checkpoint_dir`` onto the args namespace."""
    namespace = MagicMock()
    context = dict(args=namespace, checkpoint_dir='test')

    rewriter = ResumeFromCkpt()
    context = rewriter(context)
    assert context.get('args').resume_from == 'test'

0 commit comments

Comments
 (0)