Commit b062468

triple-Mu, HAOCHENYE, and zhouzaida authored
Add type hints for mmcv/parallel (#2031)
* Add type hints
* Fix
* Fix
* Update mmcv/parallel/distributed_deprecated.py
  Co-authored-by: Mashiro <[email protected]>
* Fix
* add type hints to scatter
* fix ScatterInputs
* Update mmcv/parallel/_functions.py
  Co-authored-by: Zaida Zhou <[email protected]>
* Fix
* refine type hints
* minor fix

Co-authored-by: Mashiro <[email protected]>
Co-authored-by: HAOCHENYE <[email protected]>
Co-authored-by: Zaida Zhou <[email protected]>
Co-authored-by: zhouzaida <[email protected]>
1 parent 9110df9 commit b062468

8 files changed: +78 -47 lines changed

mmcv/parallel/_functions.py

Lines changed: 13 additions & 7 deletions
@@ -1,9 +1,14 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Union
+
 import torch
+from torch import Tensor
 from torch.nn.parallel._functions import _get_stream


-def scatter(input, devices, streams=None):
+def scatter(input: Union[List, Tensor],
+            devices: List,
+            streams: Optional[List] = None) -> Union[List, Tensor]:
     """Scatters tensor across multiple GPUs."""
     if streams is None:
         streams = [None] * len(devices)
@@ -15,7 +20,7 @@ def scatter(input, devices, streams=None):
                     [streams[i // chunk_size]]) for i in range(len(input))
         ]
         return outputs
-    elif isinstance(input, torch.Tensor):
+    elif isinstance(input, Tensor):
         output = input.contiguous()
         # TODO: copy to a pinned buffer first (if copying from CPU)
         stream = streams[0] if output.numel() > 0 else None
@@ -28,14 +33,15 @@ def scatter(input, devices, streams=None):
         raise Exception(f'Unknown type {type(input)}.')


-def synchronize_stream(output, devices, streams):
+def synchronize_stream(output: Union[List, Tensor], devices: List,
+                       streams: List) -> None:
     if isinstance(output, list):
         chunk_size = len(output) // len(devices)
         for i in range(len(devices)):
             for j in range(chunk_size):
                 synchronize_stream(output[i * chunk_size + j], [devices[i]],
                                    [streams[i]])
-    elif isinstance(output, torch.Tensor):
+    elif isinstance(output, Tensor):
         if output.numel() != 0:
             with torch.cuda.device(devices[0]):
                 main_stream = torch.cuda.current_stream()
@@ -45,14 +51,14 @@ def synchronize_stream(output, devices, streams):
         raise Exception(f'Unknown type {type(output)}.')


-def get_input_device(input):
+def get_input_device(input: Union[List, Tensor]) -> int:
     if isinstance(input, list):
         for item in input:
             input_device = get_input_device(item)
             if input_device != -1:
                 return input_device
         return -1
-    elif isinstance(input, torch.Tensor):
+    elif isinstance(input, Tensor):
         return input.get_device() if input.is_cuda else -1
     else:
         raise Exception(f'Unknown type {type(input)}.')
@@ -61,7 +67,7 @@ def get_input_device(input):
 class Scatter:

     @staticmethod
-    def forward(target_gpus, input):
+    def forward(target_gpus: List[int], input: Union[List, Tensor]) -> tuple:
         input_device = get_input_device(input)
         streams = None
         if input_device == -1 and target_gpus != [-1]:
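
For context, a minimal CPU-only sketch (not part of the commit) of how the helpers annotated above are called; it assumes mmcv is installed and uses the [-1] placeholder device id visible in the code above, so no GPU is required.

import torch

from mmcv.parallel._functions import Scatter, get_input_device

x = torch.randn(4, 3)                    # a CPU tensor
print(get_input_device(x))               # -1: not on any CUDA device
print(get_input_device([x, x.clone()]))  # lists are resolved recursively, still -1

outputs = Scatter.forward([-1], x)       # annotated above to return a plain tuple
print(type(outputs), len(outputs))       # a 1-tuple here, matching the single pseudo device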

mmcv/parallel/collate.py

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@
 from .data_container import DataContainer


-def collate(batch, samples_per_gpu=1):
+def collate(batch: Sequence, samples_per_gpu: int = 1):
     """Puts each data field into a tensor/DataContainer with outer dimension
     batch size.

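
As a hedged illustration (not from the commit), collate is typically fed a list of dataset samples whose tensor fields are wrapped in DataContainer; the field names below are made up.

import torch

from mmcv.parallel import DataContainer, collate

batch = [
    dict(img=DataContainer(torch.rand(3, 8, 8), stack=True),
         label=torch.tensor(i))
    for i in range(4)
]
result = collate(batch, samples_per_gpu=2)
# result['img'] is a DataContainer whose .data is a list with one stacked
# tensor per GPU chunk (two chunks of two samples each here); plain tensor
# fields fall through to torch's default_collate.
print(type(result['img']), result['label'])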

mmcv/parallel/data_container.py

Lines changed: 18 additions & 16 deletions
@@ -1,10 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import functools
+from typing import Callable, Type, Union

+import numpy as np
 import torch


-def assert_tensor_type(func):
+def assert_tensor_type(func: Callable) -> Callable:

     @functools.wraps(func)
     def wrapper(*args, **kwargs):
@@ -35,55 +37,55 @@ class DataContainer:
     """

     def __init__(self,
-                 data,
-                 stack=False,
-                 padding_value=0,
-                 cpu_only=False,
-                 pad_dims=2):
+                 data: Union[torch.Tensor, np.ndarray],
+                 stack: bool = False,
+                 padding_value: int = 0,
+                 cpu_only: bool = False,
+                 pad_dims: int = 2):
         self._data = data
         self._cpu_only = cpu_only
         self._stack = stack
         self._padding_value = padding_value
         assert pad_dims in [None, 1, 2, 3]
         self._pad_dims = pad_dims

-    def __repr__(self):
+    def __repr__(self) -> str:
         return f'{self.__class__.__name__}({repr(self.data)})'

-    def __len__(self):
+    def __len__(self) -> int:
         return len(self._data)

     @property
-    def data(self):
+    def data(self) -> Union[torch.Tensor, np.ndarray]:
         return self._data

     @property
-    def datatype(self):
+    def datatype(self) -> Union[Type, str]:
         if isinstance(self.data, torch.Tensor):
             return self.data.type()
         else:
             return type(self.data)

     @property
-    def cpu_only(self):
+    def cpu_only(self) -> bool:
         return self._cpu_only

     @property
-    def stack(self):
+    def stack(self) -> bool:
         return self._stack

     @property
-    def padding_value(self):
+    def padding_value(self) -> int:
         return self._padding_value

     @property
-    def pad_dims(self):
+    def pad_dims(self) -> int:
         return self._pad_dims

     @assert_tensor_type
-    def size(self, *args, **kwargs):
+    def size(self, *args, **kwargs) -> torch.Size:
         return self.data.size(*args, **kwargs)

     @assert_tensor_type
-    def dim(self):
+    def dim(self) -> int:
         return self.data.dim()
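
A small sketch (not part of the diff) of the annotated properties; the guard added by assert_tensor_type only permits size()/dim() on tensor payloads.

import numpy as np
import torch

from mmcv.parallel import DataContainer

dc = DataContainer(torch.zeros(2, 3), stack=True)
print(dc.datatype)   # a type string such as 'torch.FloatTensor'
print(dc.size())     # torch.Size([2, 3]), allowed because the payload is a tensor
print(dc.dim())      # 2

meta = DataContainer(np.array([1, 2, 3]), cpu_only=True)
print(meta.datatype)  # <class 'numpy.ndarray'> for non-tensor payloads
# meta.size() would raise, since @assert_tensor_type rejects non-tensor data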

mmcv/parallel/data_parallel.py

Lines changed: 5 additions & 3 deletions
@@ -1,9 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from itertools import chain
+from typing import List, Tuple

 from torch.nn.parallel import DataParallel

-from .scatter_gather import scatter_kwargs
+from .scatter_gather import ScatterInputs, scatter_kwargs


 class MMDataParallel(DataParallel):
@@ -31,7 +32,7 @@ class MMDataParallel(DataParallel):
         dim (int): Dimension used to scatter the data. Defaults to 0.
     """

-    def __init__(self, *args, dim=0, **kwargs):
+    def __init__(self, *args, dim: int = 0, **kwargs):
         super().__init__(*args, dim=dim, **kwargs)
         self.dim = dim

@@ -49,7 +50,8 @@ def forward(self, *inputs, **kwargs):
         else:
             return super().forward(*inputs, **kwargs)

-    def scatter(self, inputs, kwargs, device_ids):
+    def scatter(self, inputs: ScatterInputs, kwargs: ScatterInputs,
+                device_ids: List[int]) -> Tuple[tuple, tuple]:
         return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

     def train_step(self, *inputs, **kwargs):
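
A hedged sketch (not from the commit) of the annotated scatter method on the CPU pseudo device id -1; the 'mode' keyword is only an illustrative placeholder, not part of the API.

import torch
import torch.nn as nn

from mmcv.parallel import MMDataParallel

model = MMDataParallel(nn.Linear(4, 2), dim=0)   # single-process wrapper
inputs, kwargs = model.scatter((torch.randn(8, 4),), dict(mode='train'), [-1])
# scatter returns Tuple[tuple, tuple]: one positional-args tuple and one
# kwargs dict per target device (a single entry for the CPU case).
print(len(inputs), len(kwargs), sorted(kwargs[0]))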

mmcv/parallel/distributed.py

Lines changed: 7 additions & 3 deletions
@@ -1,11 +1,13 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Tuple
+
 import torch
 from torch.nn.parallel.distributed import (DistributedDataParallel,
                                            _find_tensors)

 from mmcv import print_log
 from mmcv.utils import TORCH_VERSION, digit_version
-from .scatter_gather import scatter_kwargs
+from .scatter_gather import ScatterInputs, scatter_kwargs


 class MMDistributedDataParallel(DistributedDataParallel):
@@ -18,12 +20,14 @@ class MMDistributedDataParallel(DistributedDataParallel):
     - It implement two APIs ``train_step()`` and ``val_step()``.
     """

-    def to_kwargs(self, inputs, kwargs, device_id):
+    def to_kwargs(self, inputs: ScatterInputs, kwargs: ScatterInputs,
+                  device_id: int) -> Tuple[tuple, tuple]:
         # Use `self.to_kwargs` instead of `self.scatter` in pytorch1.8
         # to move all tensors to device_id
         return scatter_kwargs(inputs, kwargs, [device_id], dim=self.dim)

-    def scatter(self, inputs, kwargs, device_ids):
+    def scatter(self, inputs: ScatterInputs, kwargs: ScatterInputs,
+                device_ids: List[int]) -> Tuple[tuple, tuple]:
         return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

     def train_step(self, *inputs, **kwargs):
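
A hedged distributed sketch (not from the commit), assuming the process group is set up externally, e.g. by launching with torchrun, and one GPU per process.

import os

import torch
import torch.distributed as dist
import torch.nn as nn

from mmcv.parallel import MMDistributedDataParallel

dist.init_process_group(backend='nccl')   # rendezvous env vars come from torchrun
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)

model = MMDistributedDataParallel(
    nn.Linear(4, 2).cuda(), device_ids=[local_rank])
out = model(torch.randn(8, 4).cuda())     # plain forward behaves like stock DDP
print(out.shape)
# train_step()/val_step(), mentioned in the docstring above, are the entry
# points the mmcv runners call on this wrapper.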

mmcv/parallel/distributed_deprecated.py

Lines changed: 12 additions & 8 deletions
@@ -1,4 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Sequence, Tuple
+
 import torch
 import torch.distributed as dist
 import torch.nn as nn
@@ -7,17 +9,17 @@

 from mmcv.utils import TORCH_VERSION, digit_version
 from .registry import MODULE_WRAPPERS
-from .scatter_gather import scatter_kwargs
+from .scatter_gather import ScatterInputs, scatter_kwargs


 @MODULE_WRAPPERS.register_module()
 class MMDistributedDataParallel(nn.Module):

     def __init__(self,
-                 module,
-                 dim=0,
-                 broadcast_buffers=True,
-                 bucket_cap_mb=25):
+                 module: nn.Module,
+                 dim: int = 0,
+                 broadcast_buffers: bool = True,
+                 bucket_cap_mb: int = 25):
         super().__init__()
         self.module = module
         self.dim = dim
@@ -26,15 +28,16 @@ def __init__(self,
         self.broadcast_bucket_size = bucket_cap_mb * 1024 * 1024
         self._sync_params()

-    def _dist_broadcast_coalesced(self, tensors, buffer_size):
+    def _dist_broadcast_coalesced(self, tensors: Sequence[torch.Tensor],
+                                  buffer_size: int) -> None:
         for tensors in _take_tensors(tensors, buffer_size):
             flat_tensors = _flatten_dense_tensors(tensors)
             dist.broadcast(flat_tensors, 0)
             for tensor, synced in zip(
                     tensors, _unflatten_dense_tensors(flat_tensors, tensors)):
                 tensor.copy_(synced)

-    def _sync_params(self):
+    def _sync_params(self) -> None:
         module_states = list(self.module.state_dict().values())
         if len(module_states) > 0:
             self._dist_broadcast_coalesced(module_states,
@@ -49,7 +52,8 @@ def _sync_params(self):
             self._dist_broadcast_coalesced(buffers,
                                            self.broadcast_bucket_size)

-    def scatter(self, inputs, kwargs, device_ids):
+    def scatter(self, inputs: ScatterInputs, kwargs: ScatterInputs,
+                device_ids: List[int]) -> Tuple[tuple, tuple]:
         return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

     def forward(self, *inputs, **kwargs):
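
For completeness, a hedged sketch (not from the commit) of the deprecated wrapper; it assumes a process group is already initialized for it to broadcast into (the gloo backend and torchrun-provided rendezvous env vars are assumptions that keep the example CPU-friendly).

import torch.distributed as dist
import torch.nn as nn

from mmcv.parallel.distributed_deprecated import MMDistributedDataParallel

dist.init_process_group(backend='gloo')   # rendezvous env vars assumed to be set
model = MMDistributedDataParallel(
    nn.Linear(4, 2),        # module: nn.Module
    dim=0,
    broadcast_buffers=True,
    bucket_cap_mb=25)       # coalesced broadcasts in 25 MiB buckets
# Construction already calls _sync_params(), which broadcasts the state dict
# from rank 0 via _dist_broadcast_coalesced(), as shown in the diff above.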

mmcv/parallel/scatter_gather.py

Lines changed: 19 additions & 8 deletions
@@ -1,20 +1,26 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import torch
+from typing import List, Tuple, Union
+
+from torch import Tensor
 from torch.nn.parallel._functions import Scatter as OrigScatter

 from ._functions import Scatter
 from .data_container import DataContainer

+ScatterInputs = Union[Tensor, DataContainer, tuple, list, dict]
+

-def scatter(inputs, target_gpus, dim=0):
+def scatter(inputs: ScatterInputs,
+            target_gpus: List[int],
+            dim: int = 0) -> list:
     """Scatter inputs to target gpus.

     The only difference from original :func:`scatter` is to add support for
     :type:`~mmcv.parallel.DataContainer`.
     """

     def scatter_map(obj):
-        if isinstance(obj, torch.Tensor):
+        if isinstance(obj, Tensor):
             if target_gpus != [-1]:
                 return OrigScatter.apply(target_gpus, None, dim, obj)
             else:
@@ -33,7 +39,7 @@ def scatter_map(obj):
         if isinstance(obj, dict) and len(obj) > 0:
             out = list(map(type(obj), zip(*map(scatter_map, obj.items()))))
             return out
-        return [obj for targets in target_gpus]
+        return [obj for _ in target_gpus]

     # After scatter_map is called, a scatter_map cell will exist. This cell
     # has a reference to the actual function scatter_map, which has references
@@ -43,17 +49,22 @@ def scatter_map(obj):
     try:
         return scatter_map(inputs)
     finally:
-        scatter_map = None
+        scatter_map = None  # type: ignore


-def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
+def scatter_kwargs(inputs: ScatterInputs,
+                   kwargs: ScatterInputs,
+                   target_gpus: List[int],
+                   dim: int = 0) -> Tuple[tuple, tuple]:
     """Scatter with support for kwargs dictionary."""
     inputs = scatter(inputs, target_gpus, dim) if inputs else []
     kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
     if len(inputs) < len(kwargs):
-        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
+        length = len(kwargs) - len(inputs)
+        inputs.extend([() for _ in range(length)])  # type: ignore
     elif len(kwargs) < len(inputs):
-        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
+        length = len(inputs) - len(kwargs)
+        kwargs.extend([{} for _ in range(length)])  # type: ignore
     inputs = tuple(inputs)
     kwargs = tuple(kwargs)
     return inputs, kwargs
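
A hedged sketch (not from the commit) of the new ScatterInputs alias and scatter_kwargs, run entirely on CPU via the [-1] target; the filename is a made-up placeholder.

import torch

from mmcv.parallel import DataContainer
from mmcv.parallel.scatter_gather import ScatterInputs, scatter_kwargs

batch: ScatterInputs = dict(
    img=torch.randn(2, 3),
    # cpu_only containers carry one payload per target device, as produced
    # by ``collate``; a single CPU "device" here.
    metas=DataContainer([dict(filename='a.jpg')], cpu_only=True))

inputs, kwargs = scatter_kwargs((), batch, target_gpus=[-1])
print(sorted(kwargs[0]))   # ['img', 'metas']
print(inputs)              # ((),) -- padded to match len(kwargs), as in the code above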

mmcv/parallel/utils.py

Lines changed: 3 additions & 1 deletion
@@ -1,8 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from torch import nn
+
 from .registry import MODULE_WRAPPERS


-def is_module_wrapper(module):
+def is_module_wrapper(module: nn.Module) -> bool:
     """Check if a module is a module wrapper.

     The following 3 modules in MMCV (and their subclasses) are regarded as
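
A short sketch (not from the commit) of the predicate in use; MMDataParallel counts because it subclasses one of the classes registered in MODULE_WRAPPERS.

import torch.nn as nn

from mmcv.parallel import MMDataParallel, is_module_wrapper

model = nn.Linear(4, 2)
wrapped = MMDataParallel(model)

print(is_module_wrapper(model))    # False: a bare nn.Module
print(is_module_wrapper(wrapped))  # True: subclass of a registered wrapper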
