Skip to content

Commit 6483691

Browse files
committed
add tuning code
1 parent 8374002 commit 6483691

File tree

8 files changed

+66
-8
lines changed

8 files changed

+66
-8
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ This is the official implementation of the paper "[DETRs Beat YOLOs on Real-time
3636

3737
## Updates!!!
3838
---
39+
- \[2023.10.12\] Add tuning code for the pytorch version; you can now fine-tune rtdetr based on pretrained weights
3940
- \[2023.09.19\] Upload [*pytorch weights*](https://github.com/lyuwenyu/RT-DETR/issues/42) converted from the paddle version
4041
- \[2023.08.24\] Release rtdetr-18 pretrained models on objects365. *49.2 mAP* and *217 FPS*
4142
- \[2023.08.22\] Upload *[rtdetr_pytorch](./rtdetr_pytorch/)* source code. Please enjoy it ❤️

rtdetr_pytorch/README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
- [x] Upload source code
66
- [x] Upload weights converted from paddle, see [links](https://github.com/lyuwenyu/RT-DETR/issues/42)
77
- [x] Align training details with the [paddle version](../rtdetr_paddle/)
8-
8+
- [x] Tuning rtdetr based on [pretrained weights](https://github.com/lyuwenyu/RT-DETR/issues/42)
99

1010
## Quick start
1111

@@ -79,6 +79,7 @@ python tools/export_onnx.py -c configs/rtdetr/rtdetr_r18vd_6x_coco.yml -r path/t
7979
<details open>
8080
<summary>Train custom data</summary>
8181

82-
set `remap_mscoco_category: False`. This variable only works for ms-coco dataset.
82+
1. set `remap_mscoco_category: False`. This variable only works for ms-coco dataset.
8383

84+
2. add `-t path/to/checkpoint` (optional) to tune rtdetr based on a pretrained checkpoint. See [training script details](./tools/README.md).
8485
</details>

rtdetr_pytorch/src/core/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def __init__(self) -> None:
5050

5151
# runtime
5252
self.resume :str = None
53+
self.tuning :str = None
5354

5455
self.epoches :int = None
5556
self.last_epoch :int = -1

rtdetr_pytorch/src/core/yaml_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def __init__(self, cfg_path: str, **kwargs) -> None:
2626
self.checkpoint_step = cfg.get('checkpoint_step', 1)
2727
self.epoches = cfg.get('epoches', -1)
2828
self.resume = cfg.get('resume', '')
29+
self.tuning = cfg.get('tuning', '')
2930
self.sync_bn = cfg.get('sync_bn', False)
3031
self.output_dir = cfg.get('output_dir', None)
3132

rtdetr_pytorch/src/misc/dist.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from torch.utils.data.dataloader import DataLoader
2121

2222

23-
def init_distributed(backend='nccl'):
23+
def init_distributed():
2424
'''
2525
distributed setup
2626
args:
@@ -32,7 +32,7 @@ def init_distributed(backend='nccl'):
3232
# RANK = int(os.getenv('RANK', -1))
3333
# WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
3434

35-
tdist.init_process_group(backend=backend, init_method='env://', )
35+
tdist.init_process_group(init_method='env://', )
3636
torch.distributed.barrier()
3737

3838
rank = get_rank()

rtdetr_pytorch/src/solver/solver.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from datetime import datetime
88
from pathlib import Path
9+
from typing import Dict
910

1011
from src.misc import dist
1112
from src.core import BaseConfig
@@ -28,6 +29,11 @@ def setup(self, ):
2829
self.criterion = cfg.criterion.to(device)
2930
self.postprocessor = cfg.postprocessor
3031

32+
# NOTE (lvwenyu): should load_tuning_state before ema instance building
33+
if self.cfg.tuning:
34+
print(f'Tuning checkpoint from {self.cfg.tuning}')
35+
self.load_tuning_state(self.cfg.tuning)
36+
3137
self.scaler = cfg.scaler
3238
self.ema = cfg.ema.to(device) if cfg.ema is not None else None
3339

@@ -133,10 +139,44 @@ def resume(self, path):
133139
state = torch.load(path, map_location='cpu')
134140
self.load_state_dict(state)
135141

142+
def load_tuning_state(self, path,):
143+
"""only load model for tuning and skip missed/dismatched keys
144+
"""
145+
if 'http' in path:
146+
state = torch.hub.load_state_dict_from_url(path, map_location='cpu')
147+
else:
148+
state = torch.load(path, map_location='cpu')
149+
150+
module = dist.de_parallel(self.model)
151+
152+
# TODO hard code
153+
if 'ema' in state:
154+
stat, infos = self._matched_state(module.state_dict(), state['ema']['module'])
155+
else:
156+
stat, infos = self._matched_state(module.state_dict(), state['model'])
157+
158+
module.load_state_dict(stat, strict=False)
159+
print(f'Load model.state_dict, {infos}')
160+
161+
@staticmethod
162+
def _matched_state(state: Dict[str, torch.Tensor], params: Dict[str, torch.Tensor]):
163+
missed_list = []
164+
unmatched_list = []
165+
matched_state = {}
166+
for k, v in state.items():
167+
if k in params:
168+
if v.shape == params[k].shape:
169+
matched_state[k] = params[k]
170+
else:
171+
unmatched_list.append(k)
172+
else:
173+
missed_list.append(k)
174+
175+
return matched_state, {'missed': missed_list, 'unmatched': unmatched_list}
176+
136177

137178
def fit(self, ):
138179
raise NotImplementedError('')
139180

140-
141181
def val(self, ):
142182
raise NotImplementedError('')

rtdetr_pytorch/tools/README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ Train/test script examples
77
- `--test-only`
88

99

10+
Tuning script examples
11+
- `torchrun --master_port=8844 --nproc_per_node=4 tools/train.py -c configs/rtdetr/rtdetr_r18vd_6x_coco.yml -t https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r18vd_5x_coco_objects365_from_paddle.pth`
12+
13+
1014
Export script examples
1115
- `python tools/export_onnx.py -c path/to/config -r path/to/checkpoint --check`
1216

rtdetr_pytorch/tools/train.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,18 @@
1414
def main(args, ) -> None:
1515
'''main
1616
'''
17-
dist.init_distributed(backend='nccl')
18-
cfg = YAMLConfig(args.config, resume=args.resume, use_amp=args.amp)
17+
dist.init_distributed()
18+
19+
assert not all([args.tuning, args.resume]), \
20+
'Only support from_scrach or resume or tuning at one time'
21+
22+
cfg = YAMLConfig(
23+
args.config,
24+
resume=args.resume,
25+
use_amp=args.amp,
26+
tuning=args.tuning
27+
)
28+
1929
solver = TASKS[cfg.yaml_cfg['task']](cfg)
2030

2131
if args.test_only:
@@ -24,12 +34,12 @@ def main(args, ) -> None:
2434
solver.fit()
2535

2636

27-
2837
if __name__ == '__main__':
2938

3039
parser = argparse.ArgumentParser()
3140
parser.add_argument('--config', '-c', type=str, )
3241
parser.add_argument('--resume', '-r', type=str, )
42+
parser.add_argument('--tuning', '-t', type=str, )
3343
parser.add_argument('--test-only', action='store_true', default=False,)
3444
parser.add_argument('--amp', action='store_true', default=False,)
3545

0 commit comments

Comments
 (0)