22
33from collections import abc
44from dataclasses import is_dataclass
5- from typing import Any , Mapping , NamedTuple , Optional , Sequence , Union , overload
5+ from typing import Any , Dict , Mapping , NamedTuple , Optional , Sequence , Union , overload
66
77import torch
88from torch import Tensor
1717
1818
1919@overload
20- def prepare_batch (batch : Sequence [Mapping [str , Any ]]) -> Mapping [str , Any ]:
21- ...
20+ def prepare_batch (batch : Mapping ) -> Dict [str , Any ]: ...
2221
2322
2423@overload
25- def prepare_batch (batch : Sequence [Dataclass ]) -> Dataclass :
26- ...
24+ def prepare_batch (batch : Dataclass ) -> Dataclass : ...
2725
2826
2927@overload
30- def prepare_batch (batch : Sequence [NamedTuple ]) -> NamedTuple :
31- ...
28+ def prepare_batch (batch : NamedTuple ) -> NamedTuple : ...
3229
3330
3431def prepare_batch (
35- batch : Batch ,
32+ batch : Union [ Mapping , Dataclass , NamedTuple ] ,
3633 device : Optional [Union [Device , str ]] = None ,
3734 non_blocking : bool = False ,
3835 memory_format = torch .preserve_format ,
3936) -> Batch :
4037 r"""Move batch data to execution device."""
41- names = sample_field_names (batch )
38+ names = sample_field_names (batch ) # type: ignore[arg-type]
4239 values = []
4340 for name in names :
44- value = sample_field_value (batch , name )
41+ value = sample_field_value (batch , name ) # type: ignore[arg-type]
4542 value = prepare_item (
4643 value , device = device , non_blocking = non_blocking , memory_format = memory_format
4744 )
4845 values .append (value )
49- return replace_all_sample_field_values (batch , values )
46+ return replace_all_sample_field_values (batch , values ) # type: ignore[arg-type]
5047
5148
5249def prepare_item (
@@ -56,11 +53,11 @@ def prepare_item(
5653 memory_format = torch .preserve_format ,
5754) -> Any :
5855 r"""Move batch item data to execution device."""
59- kwargs = dict (device = device , non_blocking = non_blocking , memory_format = memory_format )
56+ kwargs : dict = dict (device = device , non_blocking = non_blocking , memory_format = memory_format )
6057 if isinstance (value , Tensor ):
6158 value = value .to (** kwargs )
6259 elif isinstance (value , abc .Mapping ) or is_dataclass (value ) or is_namedtuple (value ):
63- value = prepare_batch (value , ** kwargs )
60+ value = prepare_batch (value , ** kwargs ) # type: ignore[arg-type]
6461 elif isinstance (value , Sequence ) and not isinstance (value , str ):
6562 value = [prepare_item (item , ** kwargs ) for item in value ]
6663 return value
0 commit comments