1818from collections .abc import Mapping , Sequence
1919from copy import copy
2020from functools import partial
21- from typing import Any , Callable , Optional , Union
21+ from typing import Any , Callable , List , Optional , Tuple , Union
2222
2323import numpy as np
2424import torch
3535 Batch = type (None )
3636
3737
38- def to_dtype_tensor (value , dtype : torch .dtype = None , device : torch .device = None ):
38+ def to_dtype_tensor (
39+ value : Union [int , float , List [Union [int , float ]]],
40+ dtype : Optional [torch .dtype ] = None ,
41+ device : Union [str , torch .device ] = None ,
42+ ) -> torch .Tensor :
3943 if device is None :
4044 raise MisconfigurationException ("device (torch.device) should be provided." )
4145 return torch .tensor (value , dtype = dtype , device = device )
4246
4347
44- def from_numpy (value , device : torch .device = None ):
48+ def from_numpy (value : np . ndarray , device : Union [ str , torch .device ] = None ) -> torch . Tensor :
4549 if device is None :
4650 raise MisconfigurationException ("device (torch.device) should be provided." )
4751 return torch .from_numpy (value ).to (device )
4852
4953
50- CONVERSION_DTYPES = [
54+ CONVERSION_DTYPES : List [ Tuple [ Any , Callable [[ Any ], torch . Tensor ]]] = [
5155 # bool -> uint8 as bool -> torch.bool triggers RuntimeError: Unsupported data type for NCCL process group
5256 (bool , partial (to_dtype_tensor , dtype = torch .uint8 )),
5357 (int , partial (to_dtype_tensor , dtype = torch .int )),
@@ -61,19 +65,19 @@ def _is_namedtuple(obj: object) -> bool:
6165 return isinstance (obj , tuple ) and hasattr (obj , "_asdict" ) and hasattr (obj , "_fields" )
6266
6367
64- def _is_dataclass_instance (obj ) :
68+ def _is_dataclass_instance (obj : object ) -> bool :
6569 # https://docs.python.org/3/library/dataclasses.html#module-level-decorators-classes-and-functions
6670 return dataclasses .is_dataclass (obj ) and not isinstance (obj , type )
6771
6872
6973def apply_to_collection (
7074 data : Any ,
71- dtype : Union [type , tuple ],
75+ dtype : Union [type , Any , Tuple [ Union [ type , Any ]] ],
7276 function : Callable ,
73- * args ,
74- wrong_dtype : Optional [Union [type , tuple ]] = None ,
77+ * args : Any ,
78+ wrong_dtype : Optional [Union [type , Tuple [ type ] ]] = None ,
7579 include_none : bool = True ,
76- ** kwargs
80+ ** kwargs : Any ,
7781) -> Any :
7882 """
7983 Recursively applies a function to all elements of a certain dtype.
@@ -121,7 +125,7 @@ def apply_to_collection(
121125 return elem_type (* out ) if is_namedtuple else elem_type (out )
122126
123127 if _is_dataclass_instance (data ):
124- out = {}
128+ out_dict = {}
125129 for field in data .__dataclass_fields__ :
126130 v = apply_to_collection (
127131 getattr (data , field ),
@@ -130,11 +134,11 @@ def apply_to_collection(
130134 * args ,
131135 wrong_dtype = wrong_dtype ,
132136 include_none = include_none ,
133- ** kwargs
137+ ** kwargs ,
134138 )
135139 if include_none or v is not None :
136- out [field ] = v
137- return elem_type (** out )
140+ out_dict [field ] = v
141+ return elem_type (** out_dict )
138142
139143 # data is neither of dtype, nor a collection
140144 return data
@@ -143,11 +147,11 @@ def apply_to_collection(
143147def apply_to_collections (
144148 data1 : Optional [Any ],
145149 data2 : Optional [Any ],
146- dtype : Union [type , tuple ],
150+ dtype : Union [type , Any , Tuple [ Union [ type , Any ]] ],
147151 function : Callable ,
148- * args ,
149- wrong_dtype : Optional [Union [type , tuple ]] = None ,
150- ** kwargs
152+ * args : Any ,
153+ wrong_dtype : Optional [Union [type , Tuple [ type ] ]] = None ,
154+ ** kwargs : Any ,
151155) -> Any :
152156 """
153157 Zips two collections and applies a function to their items of a certain dtype.
@@ -169,7 +173,9 @@ def apply_to_collections(
169173 AssertionError:
170174 If sequence collections have different data sizes.
171175 """
172- if data1 is None and data2 is not None :
176+ if data1 is None :
177+ if data2 is None :
178+ return
173179 # in case they were passed reversed
174180 data1 , data2 = data2 , None
175181
@@ -220,14 +226,14 @@ class TransferableDataType(ABC):
220226 """
221227
222228 @classmethod
223- def __subclasshook__ (cls , subclass ) :
229+ def __subclasshook__ (cls , subclass : Any ) -> Union [ bool , Any ] :
224230 if cls is TransferableDataType :
225231 to = getattr (subclass , "to" , None )
226232 return callable (to )
227233 return NotImplemented
228234
229235
230- def move_data_to_device (batch : Any , device : torch .device ) :
236+ def move_data_to_device (batch : Any , device : Union [ str , torch .device ]) -> Any :
231237 """
232238 Transfers a collection of data to the given device. Any object that defines a method
233239 ``to(device)`` will be moved and all other objects in the collection will be left untouched.
@@ -245,7 +251,7 @@ def move_data_to_device(batch: Any, device: torch.device):
245251 - :class:`torch.device`
246252 """
247253
248- def batch_to (data ) :
254+ def batch_to (data : Any ) -> Any :
249255 # try to move torchtext data first
250256 if _TORCHTEXT_AVAILABLE and isinstance (data , Batch ):
251257
@@ -269,14 +275,14 @@ def batch_to(data):
269275 return apply_to_collection (batch , dtype = dtype , function = batch_to )
270276
271277
272- def convert_to_tensors (data : Any , device : torch .device ) -> Any :
278+ def convert_to_tensors (data : Any , device : Union [ str , torch .device ] ) -> Any :
273279 if device is None :
274280 raise MisconfigurationException ("`torch.device` should be provided." )
275281
276282 for src_dtype , conversion_func in CONVERSION_DTYPES :
277283 data = apply_to_collection (data , src_dtype , conversion_func , device = device )
278284
279- def _move_to_device_and_make_contiguous (t : torch .Tensor , device : torch .device ) -> torch .Tensor :
285+ def _move_to_device_and_make_contiguous (t : torch .Tensor , device : Union [ str , torch .device ] ) -> torch .Tensor :
280286 return t .to (device ).contiguous ()
281287
282288 data = apply_to_collection (data , torch .Tensor , _move_to_device_and_make_contiguous , device = device )
0 commit comments