
Commit 53885af

stancld, pre-commit-ci[bot], Borda, and carmocca authored
Fix mypy typing for utilities.apply_func (#8781)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka <[email protected]>
Co-authored-by: Carlos Mocholi <[email protected]>
1 parent dfffb94 commit 53885af

2 files changed: +30 −23 lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions

@@ -65,6 +65,7 @@ module = [
     "pytorch_lightning.loops.closure",
     "pytorch_lightning.trainer.evaluation_loop",
     "pytorch_lightning.trainer.connectors.logger_connector",
+    "pytorch_lightning.utilities.apply_func",
     "pytorch_lightning.utilities.argparse",
     "pytorch_lightning.utilities.cli",
     "pytorch_lightning.utilities.cloud_io",

pytorch_lightning/utilities/apply_func.py

Lines changed: 29 additions & 23 deletions

@@ -18,7 +18,7 @@
 from collections.abc import Mapping, Sequence
 from copy import copy
 from functools import partial
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -35,19 +35,23 @@
     Batch = type(None)


-def to_dtype_tensor(value, dtype: torch.dtype = None, device: torch.device = None):
+def to_dtype_tensor(
+    value: Union[int, float, List[Union[int, float]]],
+    dtype: Optional[torch.dtype] = None,
+    device: Union[str, torch.device] = None,
+) -> torch.Tensor:
     if device is None:
         raise MisconfigurationException("device (torch.device) should be provided.")
     return torch.tensor(value, dtype=dtype, device=device)


-def from_numpy(value, device: torch.device = None):
+def from_numpy(value: np.ndarray, device: Union[str, torch.device] = None) -> torch.Tensor:
     if device is None:
         raise MisconfigurationException("device (torch.device) should be provided.")
     return torch.from_numpy(value).to(device)


-CONVERSION_DTYPES = [
+CONVERSION_DTYPES: List[Tuple[Any, Callable[[Any], torch.Tensor]]] = [
     # bool -> uint8 as bool -> torch.bool triggers RuntimeError: Unsupported data type for NCCL process group
     (bool, partial(to_dtype_tensor, dtype=torch.uint8)),
     (int, partial(to_dtype_tensor, dtype=torch.int)),
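
For orientation, the two helpers above now carry explicit annotations; a minimal usage sketch, assuming the import path at this commit and an illustrative "cpu" device:

import numpy as np
import torch

from pytorch_lightning.utilities.apply_func import from_numpy, to_dtype_tensor

# Both helpers require an explicit device and raise MisconfigurationException when it is None.
ints = to_dtype_tensor([1, 2, 3], dtype=torch.int, device="cpu")  # tensor([1, 2, 3], dtype=torch.int32)
arr = from_numpy(np.ones((2, 2)), device="cpu")                   # 2x2 float64 tensor on the CPU
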
@@ -61,19 +65,19 @@ def _is_namedtuple(obj: object) -> bool:
     return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")


-def _is_dataclass_instance(obj):
+def _is_dataclass_instance(obj: object) -> bool:
     # https://docs.python.org/3/library/dataclasses.html#module-level-decorators-classes-and-functions
     return dataclasses.is_dataclass(obj) and not isinstance(obj, type)


 def apply_to_collection(
     data: Any,
-    dtype: Union[type, tuple],
+    dtype: Union[type, Any, Tuple[Union[type, Any]]],
     function: Callable,
-    *args,
-    wrong_dtype: Optional[Union[type, tuple]] = None,
+    *args: Any,
+    wrong_dtype: Optional[Union[type, Tuple[type]]] = None,
     include_none: bool = True,
-    **kwargs
+    **kwargs: Any,
 ) -> Any:
     """
     Recursively applies a function to all elements of a certain dtype.
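
As a reference for the annotated signature, a minimal sketch of how apply_to_collection is typically called (the batch contents are invented for illustration):

import torch

from pytorch_lightning.utilities.apply_func import apply_to_collection

batch = {"x": torch.zeros(2), "meta": {"idx": torch.arange(3)}, "name": "sample"}

# The function is applied recursively to every torch.Tensor leaf; the string is left untouched.
doubled = apply_to_collection(batch, dtype=torch.Tensor, function=lambda t: t * 2)
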
@@ -121,7 +125,7 @@ def apply_to_collection(
         return elem_type(*out) if is_namedtuple else elem_type(out)

     if _is_dataclass_instance(data):
-        out = {}
+        out_dict = {}
         for field in data.__dataclass_fields__:
             v = apply_to_collection(
                 getattr(data, field),
@@ -130,11 +134,11 @@
                 *args,
                 wrong_dtype=wrong_dtype,
                 include_none=include_none,
-                **kwargs
+                **kwargs,
             )
             if include_none or v is not None:
-                out[field] = v
-        return elem_type(**out)
+                out_dict[field] = v
+        return elem_type(**out_dict)

     # data is neither of dtype, nor a collection
     return data
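
The rename to out_dict only touches the dataclass branch shown above; a small sketch of that path, with a made-up dataclass:

from dataclasses import dataclass

import torch

from pytorch_lightning.utilities.apply_func import apply_to_collection


@dataclass
class Sample:
    image: torch.Tensor
    label: int


sample = Sample(image=torch.rand(3, 8, 8), label=1)

# Each field is transformed and the dataclass is rebuilt, so the result is again a Sample.
as_float = apply_to_collection(sample, dtype=torch.Tensor, function=lambda t: t.float())
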
@@ -143,11 +147,11 @@
 def apply_to_collections(
     data1: Optional[Any],
     data2: Optional[Any],
-    dtype: Union[type, tuple],
+    dtype: Union[type, Any, Tuple[Union[type, Any]]],
     function: Callable,
-    *args,
-    wrong_dtype: Optional[Union[type, tuple]] = None,
-    **kwargs
+    *args: Any,
+    wrong_dtype: Optional[Union[type, Tuple[type]]] = None,
+    **kwargs: Any,
 ) -> Any:
     """
     Zips two collections and applies a function to their items of a certain dtype.
@@ -169,7 +173,9 @@
     AssertionError:
         If sequence collections have different data sizes.
     """
-    if data1 is None and data2 is not None:
+    if data1 is None:
+        if data2 is None:
+            return
         # in case they were passed reversed
         data1, data2 = data2, None

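With the new early return, passing two None inputs now yields None; a hedged sketch of the zipped behaviour (values are illustrative):

from pytorch_lightning.utilities.apply_func import apply_to_collections

# Matching leaves of the two collections are zipped and the function applied pairwise.
sums = apply_to_collections({"a": 1, "b": 2}, {"a": 10, "b": 20}, int, lambda x, y: x + y)
# -> {"a": 11, "b": 22}

# Both inputs None: the function now returns None instead of falling through.
assert apply_to_collections(None, None, int, lambda x, y: x + y) is None
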
@@ -220,14 +226,14 @@ class TransferableDataType(ABC):
     """

     @classmethod
-    def __subclasshook__(cls, subclass):
+    def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
         if cls is TransferableDataType:
             to = getattr(subclass, "to", None)
             return callable(to)
         return NotImplemented


-def move_data_to_device(batch: Any, device: torch.device):
+def move_data_to_device(batch: Any, device: Union[str, torch.device]) -> Any:
     """
     Transfers a collection of data to the given device. Any object that defines a method
     ``to(device)`` will be moved and all other objects in the collection will be left untouched.
@@ -245,7 +251,7 @@ def move_data_to_device(batch: Any, device: torch.device):
     - :class:`torch.device`
     """

-    def batch_to(data):
+    def batch_to(data: Any) -> Any:
         # try to move torchtext data first
         if _TORCHTEXT_AVAILABLE and isinstance(data, Batch):

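Because __subclasshook__ only checks for a callable to attribute, anything exposing to(device) is treated as transferable by move_data_to_device; a sketch with an invented CustomBatch class:

import torch

from pytorch_lightning.utilities.apply_func import move_data_to_device


class CustomBatch:
    def __init__(self) -> None:
        self.data = torch.zeros(4)

    def to(self, device):  # a callable `to` is all TransferableDataType checks for
        self.data = self.data.to(device)
        return self


batch = {"wrapped": CustomBatch(), "plain": torch.ones(2), "label": "cat"}

# Tensors and CustomBatch are moved via their `to` method; the string is returned as-is.
moved = move_data_to_device(batch, device="cpu")
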
@@ -269,14 +275,14 @@ def batch_to(data):
     return apply_to_collection(batch, dtype=dtype, function=batch_to)


-def convert_to_tensors(data: Any, device: torch.device) -> Any:
+def convert_to_tensors(data: Any, device: Union[str, torch.device]) -> Any:
     if device is None:
         raise MisconfigurationException("`torch.device` should be provided.")

     for src_dtype, conversion_func in CONVERSION_DTYPES:
         data = apply_to_collection(data, src_dtype, conversion_func, device=device)

-    def _move_to_device_and_make_contiguous(t: torch.Tensor, device: torch.device) -> torch.Tensor:
+    def _move_to_device_and_make_contiguous(t: torch.Tensor, device: Union[str, torch.device]) -> torch.Tensor:
         return t.to(device).contiguous()

     data = apply_to_collection(data, torch.Tensor, _move_to_device_and_make_contiguous, device=device)
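
Putting the typed pieces together, a rough sketch of convert_to_tensors on a mixed payload (device string chosen only for illustration):

import numpy as np

from pytorch_lightning.utilities.apply_func import convert_to_tensors

payload = {"flag": True, "count": 3, "score": 0.5, "array": np.arange(4)}

# Scalars and arrays are converted via the CONVERSION_DTYPES table (e.g. bool -> uint8 tensor,
# int -> int tensor), then every resulting tensor is moved to the device and made contiguous.
tensors = convert_to_tensors(payload, device="cpu")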
