Skip to content

Commit fe314ff

Browse files
authored
Validate type hints with mypy, and use mypy in our CI (#164)
* Add mypy support to s3torchconnector * Add mypy type checking to python-checks.yml * Update DEVELOPMENT.md to add mypy tests --------- Co-authored-by: Simon Beal <[email protected]>
1 parent 22f5d89 commit fe314ff

File tree

16 files changed

+70
-34
lines changed

16 files changed

+70
-34
lines changed

.github/workflows/python-checks.yml

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,18 +88,25 @@ jobs:
8888
- name: Install Python dependencies
8989
run: |
9090
python -m pip install --upgrade pip
91-
python -m pip install flake8 black
91+
python -m pip install flake8 black mypy
92+
python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
93+
python -m pip install ./s3torchconnectorclient
94+
python -m pip install ./s3torchconnector
95+
9296
- name: Lint with flake8
9397
run: |
9498
# stop the build if there are Python syntax errors or undefined names
9599
flake8 s3torchconnector/ --count --select=E9,F63,F7,F82 --show-source --statistics
96100
flake8 s3torchconnectorclient/python --count --select=E9,F63,F7,F82 --show-source --statistics
97-
98101
- name: Lint with Black
99102
uses: psf/black@stable
100103
with:
101104
options: "--check --verbose"
102105
src: "."
106+
- name: Typecheck with mypy
107+
run: |
108+
mypy s3torchconnector/src
109+
mypy s3torchconnectorclient/python/src
103110
104111
dependencies:
105112
name: Python dependencies checks

doc/DEVELOPMENT.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,11 @@ pushing new Rust commits.
6767

6868
For Python code changes, run
6969
```bash
70-
black --verbose python/
70+
black --verbose .
7171
flake8 s3torchconnector/ --count --select=E9,F63,F7,F82 --show-source --statistics
7272
flake8 s3torchconnectorclient/python --count --select=E9,F63,F7,F82 --show-source --statistics
73+
mypy s3torchconnector/src
74+
mypy s3torchconnectorclient/python/src
7375
```
7476
to lint.
7577

s3torchconnector/pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ test = [
3232
"pytest-timeout",
3333
"hypothesis",
3434
"flake8",
35-
"black"
35+
"black",
36+
"mypy"
3637
]
3738

3839
e2e = [

s3torchconnector/src/s3torchconnector/_s3client/_mock_s3client.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# // SPDX-License-Identifier: BSD
33

4+
from typing import Optional
5+
46
from s3torchconnectorclient._mountpoint_s3_client import (
57
MockMountpointS3Client,
68
MountpointS3Client,
@@ -21,7 +23,7 @@ def __init__(
2123
region: str,
2224
bucket: str,
2325
part_size: int = 8 * 1024 * 1024,
24-
user_agent: UserAgent = None,
26+
user_agent: Optional[UserAgent] = None,
2527
):
2628
super().__init__(region, user_agent=user_agent)
2729
self._mock_client = MockMountpointS3Client(

s3torchconnector/src/s3torchconnector/_s3client/_s3client.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,16 @@ def _identity(obj: Any) -> Any:
3232

3333

3434
class S3Client:
35-
def __init__(self, region: str, endpoint: str = None, user_agent: UserAgent = None):
35+
def __init__(
36+
self,
37+
region: str,
38+
endpoint: Optional[str] = None,
39+
user_agent: Optional[UserAgent] = None,
40+
):
3641
self._region = region
3742
self._endpoint = endpoint
38-
self._real_client = None
39-
self._client_pid = None
43+
self._real_client: Optional[MountpointS3Client] = None
44+
self._client_pid: Optional[int] = None
4045
user_agent = user_agent or UserAgent()
4146
self._user_agent_prefix = user_agent.prefix
4247

@@ -46,6 +51,7 @@ def _client(self) -> MountpointS3Client:
4651
self._client_pid = os.getpid()
4752
# `MountpointS3Client` does not survive forking, so re-create it if the PID has changed.
4853
self._real_client = self._client_builder()
54+
assert self._real_client is not None
4955
return self._real_client
5056

5157
@property

s3torchconnector/src/s3torchconnector/_user_agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# // SPDX-License-Identifier: BSD
3-
from typing import List
3+
from typing import List, Optional
44

55
from ._version import __version__
66

77
# https://www.rfc-editor.org/rfc/rfc9110#name-user-agent
88

99

1010
class UserAgent:
11-
def __init__(self, comments: List[str] = None):
11+
def __init__(self, comments: Optional[List[str]] = None):
1212
if comments is not None and not isinstance(comments, list):
1313
raise ValueError("Argument comments must be a List[str]")
1414
self._user_agent_prefix = f"{__package__}/{__version__}"

s3torchconnector/src/s3torchconnector/s3checkpoint.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# // SPDX-License-Identifier: BSD
3+
from typing import Optional
34

45
from ._s3dataset_common import parse_s3_uri
56
from ._s3client import S3Client
@@ -16,7 +17,7 @@ class S3Checkpoint:
1617
torch.load, and torch.save.
1718
"""
1819

19-
def __init__(self, region: str, endpoint: str = None):
20+
def __init__(self, region: str, endpoint: Optional[str] = None):
2021
self.region = region
2122
self.endpoint = endpoint
2223
self._client = S3Client(region, endpoint)

s3torchconnector/src/s3torchconnector/s3iterable_dataset.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# // SPDX-License-Identifier: BSD
33
from functools import partial
4-
from typing import Iterator, Any, Union, Iterable, Callable
4+
from typing import Iterator, Any, Union, Iterable, Callable, Optional
55
import logging
66

77
import torch.utils.data
@@ -29,7 +29,7 @@ def __init__(
2929
self,
3030
region: str,
3131
get_dataset_objects: Callable[[S3Client], Iterable[S3BucketKeyData]],
32-
endpoint: str = None,
32+
endpoint: Optional[str] = None,
3333
transform: Callable[[S3Reader], Any] = identity,
3434
):
3535
self._get_dataset_objects = get_dataset_objects
@@ -52,7 +52,7 @@ def from_objects(
5252
object_uris: Union[str, Iterable[str]],
5353
*,
5454
region: str,
55-
endpoint: str = None,
55+
endpoint: Optional[str] = None,
5656
transform: Callable[[S3Reader], Any] = identity,
5757
):
5858
"""Returns an instance of S3IterableDataset using the S3 URI(s) provided.
@@ -83,7 +83,7 @@ def from_prefix(
8383
s3_uri: str,
8484
*,
8585
region: str,
86-
endpoint: str = None,
86+
endpoint: Optional[str] = None,
8787
transform: Callable[[S3Reader], Any] = identity,
8888
):
8989
"""Returns an instance of S3IterableDataset using the S3 URI provided.

s3torchconnector/src/s3torchconnector/s3map_dataset.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# // SPDX-License-Identifier: BSD
33
from functools import partial
4-
from typing import List, Any, Callable, Iterable, Union
4+
from typing import List, Any, Callable, Iterable, Union, Optional
55
import logging
66

77
import torch.utils.data
@@ -30,15 +30,15 @@ def __init__(
3030
self,
3131
region: str,
3232
get_dataset_objects: Callable[[S3Client], Iterable[S3BucketKeyData]],
33-
endpoint: str = None,
33+
endpoint: Optional[str] = None,
3434
transform: Callable[[S3Reader], Any] = identity,
3535
):
3636
self._get_dataset_objects = get_dataset_objects
3737
self._transform = transform
3838
self._region = region
3939
self._endpoint = endpoint
4040
self._client = None
41-
self._bucket_key_pairs = None
41+
self._bucket_key_pairs: Optional[List[S3BucketKeyData]] = None
4242

4343
@property
4444
def region(self):
@@ -52,6 +52,7 @@ def endpoint(self):
5252
def _dataset_bucket_key_pairs(self) -> List[S3BucketKeyData]:
5353
if self._bucket_key_pairs is None:
5454
self._bucket_key_pairs = list(self._get_dataset_objects(self._get_client()))
55+
assert self._bucket_key_pairs is not None
5556
return self._bucket_key_pairs
5657

5758
@classmethod
@@ -60,7 +61,7 @@ def from_objects(
6061
object_uris: Union[str, Iterable[str]],
6162
*,
6263
region: str,
63-
endpoint: str = None,
64+
endpoint: Optional[str] = None,
6465
transform: Callable[[S3Reader], Any] = identity,
6566
):
6667
"""Returns an instance of S3MapDataset using the S3 URI(s) provided.
@@ -91,7 +92,7 @@ def from_prefix(
9192
s3_uri: str,
9293
*,
9394
region: str,
94-
endpoint: str = None,
95+
endpoint: Optional[str] = None,
9596
transform: Callable[[S3Reader], Any] = identity,
9697
):
9798
"""Returns an instance of S3MapDataset using the S3 URI provided.

s3torchconnector/src/s3torchconnector/s3reader.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import io
55
from functools import cached_property
66
from io import SEEK_CUR, SEEK_END, SEEK_SET
7-
from typing import Callable, Optional
7+
from typing import Callable, Optional, Iterable, Iterator
88

99
from s3torchconnectorclient._mountpoint_s3_client import ObjectInfo, GetObjectStream
1010

@@ -16,18 +16,18 @@ def __init__(
1616
self,
1717
bucket: str,
1818
key: str,
19-
get_object_info: Callable[[], ObjectInfo] = None,
20-
get_stream: Callable[[], GetObjectStream] = None,
19+
get_object_info: Callable[[], ObjectInfo],
20+
get_stream: Callable[[], GetObjectStream],
2121
):
2222
if not bucket:
2323
raise ValueError("Bucket should be specified")
2424
self._bucket = bucket
2525
self._key = key
2626
self._get_object_info = get_object_info
2727
self._get_stream = get_stream
28-
self._stream = None
28+
self._stream: Optional[Iterator[bytes]] = None
2929
self._buffer = io.BytesIO()
30-
self._size = None
30+
self._size: Optional[int] = None
3131
# Invariant: _position == _buffer._tell() unless _position_at_end()
3232
self._position = 0
3333

@@ -77,6 +77,7 @@ def read(self, size: Optional[int] = None) -> bytes:
7777
return b""
7878

7979
self.prefetch()
80+
assert self._stream is not None
8081
cur_pos = self._position
8182
if size is None or size < 0:
8283
# Special case read() all to use O(n) algorithm
@@ -138,6 +139,7 @@ def seek(self, offset: int, whence: int = SEEK_SET, /) -> int:
138139

139140
def _prefetch_to_offset(self, offset: int) -> None:
140141
self.prefetch()
142+
assert self._stream is not None
141143
buf_size = self._buffer.seek(0, SEEK_END)
142144
try:
143145
while offset > buf_size:

0 commit comments

Comments
 (0)