
Commit 24f2fdb

fix lint (#7793)
1 parent c8a1873 commit 24f2fdb

6 files changed: +46 -41 lines changed

mmdet/apis/train.py

Lines changed: 9 additions & 7 deletions
@@ -5,16 +5,16 @@
 import numpy as np
 import torch
 import torch.distributed as dist
-from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
 from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner,
                          Fp16OptimizerHook, OptimizerHook, build_optimizer,
                          build_runner, get_dist_info)

 from mmdet.core import DistEvalHook, EvalHook
 from mmdet.datasets import (build_dataloader, build_dataset,
                             replace_ImageToTensor)
-from mmdet.utils import (compat_cfg, find_latest_checkpoint, get_root_logger,
-                         build_ddp, build_dp)
+from mmdet.utils import (build_ddp, build_dp, compat_cfg,
+                         find_latest_checkpoint, get_root_logger)
+

 def init_random_seed(seed=None, device='cuda'):
     """Initialize random seed.
@@ -153,10 +153,12 @@ def train_detector(model,
         find_unused_parameters = cfg.get('find_unused_parameters', False)
         # Sets the `find_unused_parameters` parameter in
         # torch.nn.parallel.DistributedDataParallel
-        model = build_ddp(model, cfg.device,
-                          device_ids=[int(os.environ['LOCAL_RANK'])],
-                          broadcast_buffers=False,
-                          find_unused_parameters=find_unused_parameters)
+        model = build_ddp(
+            model,
+            cfg.device,
+            device_ids=[int(os.environ['LOCAL_RANK'])],
+            broadcast_buffers=False,
+            find_unused_parameters=find_unused_parameters)
     else:
         model = build_dp(model, cfg.device, device_ids=cfg.gpu_ids)

mmdet/datasets/samplers/distributed_sampler.py

Lines changed: 2 additions & 1 deletion
@@ -7,6 +7,7 @@
 from mmdet.core.utils import sync_random_seed
 from mmdet.utils import get_device

+
 class DistributedSampler(_DistributedSampler):

     def __init__(self,
@@ -23,7 +24,7 @@ def __init__(self,
         # is used to make sure that each rank shuffles the data indices
         # in the same order based on the same seed. Then different ranks
         # could use different indices to select non-overlapped data from the
-        # same data list.
+        # same data list.
         device = get_device()
         self.seed = sync_random_seed(seed, device)

mmdet/utils/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -10,6 +10,6 @@
 __all__ = [
     'get_root_logger', 'collect_env', 'find_latest_checkpoint',
     'update_data_root', 'setup_multi_processes', 'get_caller_name',
-    'log_img_scale', 'compat_cfg', 'split_batch', 'build_ddp',
-    'build_dp', 'get_device'
+    'log_img_scale', 'compat_cfg', 'split_batch', 'build_ddp', 'build_dp',
+    'get_device'
 ]

mmdet/utils/util_distribution.py

Lines changed: 23 additions & 23 deletions
@@ -2,28 +2,24 @@
 import torch
 from mmcv.parallel import MMDataParallel, MMDistributedDataParallel

-dp_factory = {
-    'cuda' : MMDataParallel,
-    'cpu' : MMDataParallel
-}
+dp_factory = {'cuda': MMDataParallel, 'cpu': MMDataParallel}
+
+ddp_factory = {'cuda': MMDistributedDataParallel}

-ddp_factory = {
-    'cuda' : MMDistributedDataParallel
-}

 def build_dp(model, device='cuda', dim=0, *args, **kwargs):
     """build DataParallel module by device type.
-
+
     if device is cuda, return a MMDataParallel model; if device is mlu,
     return a MLUDataParallel model.

     Args:
-        model(:class:`nn.Module`): model to be parallelized.
-        device(str): device type, cuda, cpu or mlu. Defaults to cuda.
-        dim(int): Dimension used to scatter the data. Defaults to 0.
+        model (:class:`nn.Module`): model to be parallelized.
+        device (str): device type, cuda, cpu or mlu. Defaults to cuda.
+        dim (int): Dimension used to scatter the data. Defaults to 0.

     Returns:
-        model(nn.Module): the model to be parallelized.
+        nn.Module: the model to be parallelized.
     """
     if device == 'cuda':
         model = model.cuda()
@@ -38,15 +34,15 @@ def build_dp(model, device='cuda', dim=0, *args, **kwargs):
 def build_ddp(model, device='cuda', *args, **kwargs):
     """Build DistributedDataParallel module by device type.

-    If device is cuda, return a MMDistributedDataParallel model; if device is mlu,
-    return a MLUDistributedDataParallel model.
+    If device is cuda, return a MMDistributedDataParallel model;
+    if device is mlu, return a MLUDistributedDataParallel model.

     Args:
-        model(:class:`nn.Module`): module to be parallelized.
-        device(str): device type, mlu or cuda.
+        model (:class:`nn.Module`): module to be parallelized.
+        device (str): device type, mlu or cuda.

     Returns:
-        model(:class:`nn.Module`): the module to be parallelized
+        :class:`nn.Module`: the module to be parallelized

     References:
         .. [1] https://pytorch.org/docs/stable/generated/torch.nn.parallel.
@@ -56,19 +52,23 @@ def build_ddp(model, device='cuda', *args, **kwargs):
     if device == 'cuda':
         model = model.cuda()
     elif device == 'mlu':
-        from mmcv.device.mlu import MLUDistributedDataParallel
+        from mmcv.device.mlu import MLUDistributedDataParallel
         ddp_factory['mlu'] = MLUDistributedDataParallel
         model = model.mlu()

     return ddp_factory[device](model, *args, **kwargs)

+
 def is_mlu_available():
-    """ Returns a bool indicating if MLU is currently available. """
+    """Returns a bool indicating if MLU is currently available."""
     return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()

+
 def get_device():
-    """ Returns an available device, cpu, cuda or mlu. """
-    is_device_available = {'cuda': torch.cuda.is_available(),
-                           'mlu': is_mlu_available()}
-    device_list = [k for k, v in is_device_available.items() if v ]
+    """Returns an available device, cpu, cuda or mlu."""
+    is_device_available = {
+        'cuda': torch.cuda.is_available(),
+        'mlu': is_mlu_available()
+    }
+    device_list = [k for k, v in is_device_available.items() if v]
     return device_list[0] if len(device_list) == 1 else 'cpu'
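
For context, a minimal usage sketch of the helpers refactored above (not part of this commit; the toy nn.Conv2d model and the device_ids value are assumptions for illustration only):

import torch.nn as nn

from mmdet.utils import build_dp, get_device

# Pick whichever backend is available ('cuda', 'mlu' or 'cpu', per get_device above).
device = get_device()

# Stand-in module for a real detector; purely illustrative.
model = nn.Conv2d(3, 8, kernel_size=3)

# Wrap for non-distributed execution: MMDataParallel on cuda/cpu,
# MLUDataParallel when the device is 'mlu' (see build_dp above).
model = build_dp(model, device, device_ids=[0])

The distributed counterpart, build_ddp, is what mmdet/apis/train.py above and tools/test.py below call with device_ids=[int(os.environ['LOCAL_RANK'])] when running under torch.distributed.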

tools/test.py

Lines changed: 8 additions & 6 deletions
@@ -9,16 +9,16 @@
 import torch
 from mmcv import Config, DictAction
 from mmcv.cnn import fuse_conv_bn
-from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
 from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
                          wrap_fp16_model)

 from mmdet.apis import multi_gpu_test, single_gpu_test
 from mmdet.datasets import (build_dataloader, build_dataset,
                             replace_ImageToTensor)
 from mmdet.models import build_detector
-from mmdet.utils import (compat_cfg, setup_multi_processes, update_data_root,
-                         build_ddp, build_dp, get_device)
+from mmdet.utils import (build_ddp, build_dp, compat_cfg, get_device,
+                         setup_multi_processes, update_data_root)
+

 def parse_args():
     parser = argparse.ArgumentParser(
@@ -234,9 +234,11 @@ def main():
         outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,
                                   args.show_score_thr)
     else:
-        model = build_ddp(model, cfg.device,
-                          device_ids=[int(os.environ['LOCAL_RANK'])],
-                          broadcast_buffers=False)
+        model = build_ddp(
+            model,
+            cfg.device,
+            device_ids=[int(os.environ['LOCAL_RANK'])],
+            broadcast_buffers=False)
         outputs = multi_gpu_test(model, data_loader, args.tmpdir,
                                  args.gpu_collect)

tools/train.py

Lines changed: 2 additions & 2 deletions
@@ -17,8 +17,8 @@
 from mmdet.apis import init_random_seed, set_random_seed, train_detector
 from mmdet.datasets import build_dataset
 from mmdet.models import build_detector
-from mmdet.utils import (collect_env, get_root_logger, setup_multi_processes,
-                         update_data_root, get_device)
+from mmdet.utils import (collect_env, get_device, get_root_logger,
+                         setup_multi_processes, update_data_root)


 def parse_args():
