Skip to content

Commit bdc0852

Browse files
committed
fix: Type annotation of prepare_batch()
1 parent 83d4b48 commit bdc0852

File tree

3 files changed

+15
-18
lines changed

3 files changed

+15
-18
lines changed

.flake8

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
[flake8]
22
max-line-length = 100
33
select = C,E,F,W,B,B950
4-
ignore = E203, E402, E501, W503
4+
ignore = E203, E402, E501, E704, W503, BLK100

src/deepali/data/prepare.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from collections import abc
44
from dataclasses import is_dataclass
5-
from typing import Any, Mapping, NamedTuple, Optional, Sequence, Union, overload
5+
from typing import Any, Dict, Mapping, NamedTuple, Optional, Sequence, Union, overload
66

77
import torch
88
from torch import Tensor
@@ -17,36 +17,33 @@
1717

1818

1919
@overload
20-
def prepare_batch(batch: Sequence[Mapping[str, Any]]) -> Mapping[str, Any]:
21-
...
20+
def prepare_batch(batch: Mapping) -> Dict[str, Any]: ...
2221

2322

2423
@overload
25-
def prepare_batch(batch: Sequence[Dataclass]) -> Dataclass:
26-
...
24+
def prepare_batch(batch: Dataclass) -> Dataclass: ...
2725

2826

2927
@overload
30-
def prepare_batch(batch: Sequence[NamedTuple]) -> NamedTuple:
31-
...
28+
def prepare_batch(batch: NamedTuple) -> NamedTuple: ...
3229

3330

3431
def prepare_batch(
35-
batch: Batch,
32+
batch: Union[Mapping, Dataclass, NamedTuple],
3633
device: Optional[Union[Device, str]] = None,
3734
non_blocking: bool = False,
3835
memory_format=torch.preserve_format,
3936
) -> Batch:
4037
r"""Move batch data to execution device."""
41-
names = sample_field_names(batch)
38+
names = sample_field_names(batch) # type: ignore[arg-type]
4239
values = []
4340
for name in names:
44-
value = sample_field_value(batch, name)
41+
value = sample_field_value(batch, name) # type: ignore[arg-type]
4542
value = prepare_item(
4643
value, device=device, non_blocking=non_blocking, memory_format=memory_format
4744
)
4845
values.append(value)
49-
return replace_all_sample_field_values(batch, values)
46+
return replace_all_sample_field_values(batch, values) # type: ignore[arg-type]
5047

5148

5249
def prepare_item(
@@ -56,11 +53,11 @@ def prepare_item(
5653
memory_format=torch.preserve_format,
5754
) -> Any:
5855
r"""Move batch item data to execution device."""
59-
kwargs = dict(device=device, non_blocking=non_blocking, memory_format=memory_format)
56+
kwargs: dict = dict(device=device, non_blocking=non_blocking, memory_format=memory_format)
6057
if isinstance(value, Tensor):
6158
value = value.to(**kwargs)
6259
elif isinstance(value, abc.Mapping) or is_dataclass(value) or is_namedtuple(value):
63-
value = prepare_batch(value, **kwargs)
60+
value = prepare_batch(value, **kwargs) # type: ignore[arg-type]
6461
elif isinstance(value, Sequence) and not isinstance(value, str):
6562
value = [prepare_item(item, **kwargs) for item in value]
6663
return value

src/deepali/data/sample.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@
1616
)
1717

1818

19-
def sample_field_names(sample: Sample) -> Tuple[str]:
19+
def sample_field_names(sample: Sample) -> Tuple[str, ...]:
2020
r"""Get names of fields in data sample."""
2121
if is_dataclass(sample):
22-
return tuple((field.name for field in fields(sample)))
22+
return tuple((field.name for field in fields(sample))) # type: ignore[arg-type]
2323
if is_namedtuple(sample):
24-
return sample._fields
24+
return sample._fields # type: ignore[arg-type]
2525
if not isinstance(sample, Mapping):
2626
raise TypeError("Dataset 'sample' must be dataclass, Mapping, or NamedTuple")
2727
return tuple(sample.keys())
@@ -45,7 +45,7 @@ def replace_all_sample_field_values(sample: Sample, values: Sequence[Any]) -> Sa
4545
setattr(result, name, value)
4646
return result
4747
if is_namedtuple(sample):
48-
return sample._replace(**{name: value for name, value in zip(names, values)})
48+
return sample._replace(**{name: value for name, value in zip(names, values)}) # type: ignore[arg-type]
4949
if isinstance(sample, OrderedDict):
5050
return OrderedDict([(name, value) for name, value in zip(names, values)])
5151
return {name: value for name, value in zip(names, values)}

0 commit comments

Comments
 (0)