Skip to content

Commit 58d176a

Browse files
muddyfishdnnanuti
authored andcommitted
Lightning support (#132)
- Add testing which splits out utils for checkpointing into a common module - Add lightning optional dependency list for lightning checkpoints - Add delete tests - Add tests that validate the error message we get if lightning is not installed --------- Co-authored-by: dnnanuti <[email protected]> Co-authored-by: Simon Beal <[email protected]>
1 parent ffd6683 commit 58d176a

File tree

14 files changed

+403
-69
lines changed

14 files changed

+403
-69
lines changed

.github/workflows/python-checks.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,12 @@ jobs:
6060
- name: s3torchbenchmarking unit tests
6161
run: pytest s3torchbenchmarking/tst --hypothesis-profile ci --hypothesis-show-statistics -c ./
6262

63+
- name: Install Lightning dependency
64+
run: |
65+
python -m pip install -e "s3torchconnector[lightning]"
66+
- name: s3torchconnector lightning unit tests
67+
run: pytest s3torchconnector/tst/unit/lightning --hypothesis-profile ci --hypothesis-show-statistics
68+
6369
lint:
6470
name: Python lints
6571
runs-on: ubuntu-22.04

s3torchconnector/pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@ e2e = [
4545
"pytest-xdist"
4646
]
4747

48+
lightning = [
49+
"lightning >= 2.0"
50+
]
51+
4852
[tool.setuptools.packages]
4953
# Pure Python packages/modules
5054
find = { where = ["src"] }

s3torchconnector/src/s3torchconnector/_s3client/_s3client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,7 @@ def list_objects(
105105
def head_object(self, bucket: str, key: str) -> ObjectInfo:
106106
log.debug(f"HeadObject s3://{bucket}/{key}")
107107
return self._client.head_object(bucket, key)
108+
109+
def delete_object(self, bucket: str, key: str) -> None:
110+
log.debug(f"DeleteObject s3://{bucket}/{key}")
111+
self._client.delete_object(bucket, key)
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# // SPDX-License-Identifier: BSD
3+
4+
# Get a nice error message if lightning isn't available.
5+
import lightning
6+
7+
from .s3_lightning_checkpoint import S3LightningCheckpoint
8+
9+
__all__ = [
10+
"S3LightningCheckpoint",
11+
]
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# // SPDX-License-Identifier: BSD
3+
4+
from typing import Optional, Dict, Any
5+
6+
import torch
7+
8+
from lightning.pytorch.plugins.io import CheckpointIO
9+
10+
from .._s3client import S3Client
11+
from .._s3dataset_common import parse_s3_uri
12+
13+
14+
class S3LightningCheckpoint(CheckpointIO):
15+
"""A checkpoint manager for S3 using the :class:`CheckpointIO` interface."""
16+
17+
def __init__(self, region: str):
18+
self.region = region
19+
self._client = S3Client(region)
20+
21+
def save_checkpoint(
22+
self,
23+
checkpoint: Dict[str, Any],
24+
path: str,
25+
storage_options: Optional[Any] = None,
26+
) -> None:
27+
"""Save model/training states as a checkpoint file through state-dump and upload to S3.
28+
29+
Args:
30+
checkpoint (Dict[str, Any]): Containing model and trainer state
31+
path (str): Write-target S3 uri
32+
storage_options: Optional parameters when saving the model/training states.
33+
"""
34+
self._validate_path(path)
35+
bucket, key = parse_s3_uri(path)
36+
with self._client.put_object(bucket, key) as s3writer:
37+
torch.save(checkpoint, s3writer)
38+
39+
def load_checkpoint(
40+
self,
41+
path: str,
42+
map_location: Optional[Any] = None,
43+
) -> Dict[str, Any]:
44+
"""Load checkpoint from an S3 location when resuming or loading ckpt for test/validate/predict stages.
45+
46+
Args:
47+
path (str): S3 uri to checkpoint
48+
map_location: A function, :class:`torch.device`, string or a dict specifying how to remap storage locations.
49+
50+
Returns:
51+
Dict[str, Any]: The loaded checkpoint
52+
53+
Raises:
54+
S3Exception: An error occurred accessing S3.
55+
"""
56+
self._validate_path(path)
57+
bucket, key = parse_s3_uri(path)
58+
s3reader = self._client.get_object(bucket, key)
59+
return torch.load(s3reader, map_location)
60+
61+
def remove_checkpoint(self, path: str) -> None:
62+
"""Remove checkpoint file from the S3 uri.
63+
64+
Args:
65+
path (str): S3 uri to checkpoint
66+
67+
Raises:
68+
S3Exception: An error occurred accessing S3.
69+
"""
70+
self._validate_path(path)
71+
bucket, key = parse_s3_uri(path)
72+
self._client.delete_object(bucket, key)
73+
74+
def teardown(self) -> None:
75+
"""This method is called to teardown the process."""
76+
pass
77+
78+
@staticmethod
79+
def _validate_path(path: str) -> None:
80+
if not isinstance(path, str):
81+
raise TypeError(
82+
f"{type(path).__name__!r} is not a supported type for 'path'. Must be a string formatted as an S3 uri."
83+
)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# // SPDX-License-Identifier: BSD
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# // SPDX-License-Identifier: BSD
3+
4+
from contextlib import contextmanager
5+
from unittest.mock import patch
6+
import torch
7+
from hypothesis.strategies import one_of, just
8+
9+
byteorders = one_of(just("little"), just("big"))
10+
11+
12+
@contextmanager
13+
def _patch_byteorder(byteorder: str):
14+
with patch("torch.serialization.sys") as mock_sys:
15+
mock_sys.byteorder = byteorder
16+
yield
17+
18+
19+
def save_with_byteorder(data, fobj, byteorder: str, use_modern_pytorch_format: bool):
20+
with _patch_byteorder(byteorder):
21+
torch.save(data, fobj, _use_new_zipfile_serialization=use_modern_pytorch_format)
22+
23+
24+
def load_with_byteorder(fobj, byteorder):
25+
with _patch_byteorder(byteorder):
26+
return torch.load(fobj)
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# // SPDX-License-Identifier: BSD
3+
4+
from hypothesis.strategies import (
5+
integers,
6+
binary,
7+
none,
8+
characters,
9+
complex_numbers,
10+
floats,
11+
booleans,
12+
decimals,
13+
fractions,
14+
deferred,
15+
frozensets,
16+
tuples,
17+
dictionaries,
18+
lists,
19+
uuids,
20+
sets,
21+
text,
22+
)
23+
24+
scalars = (
25+
none()
26+
| booleans()
27+
| integers()
28+
# Disallow nan as it doesn't have self-equality
29+
| floats(allow_nan=False)
30+
| complex_numbers(allow_nan=False)
31+
| decimals(allow_nan=False)
32+
| fractions()
33+
| characters()
34+
| binary(max_size=10)
35+
| text(max_size=10)
36+
| uuids()
37+
)
38+
39+
hashable = deferred(
40+
lambda: (scalars | frozensets(hashable, max_size=5) | tuples(hashable))
41+
)
42+
43+
python_primitives = deferred(
44+
lambda: (
45+
hashable
46+
| sets(hashable, max_size=5)
47+
| lists(python_primitives, max_size=5)
48+
| dictionaries(keys=hashable, values=python_primitives, max_size=3)
49+
)
50+
)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# // SPDX-License-Identifier: BSD
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# // SPDX-License-Identifier: BSD
3+
4+
from io import BytesIO
5+
from operator import eq
6+
from pathlib import Path
7+
from typing import Callable, Any
8+
9+
import hypothesis
10+
import pytest
11+
import torch
12+
from hypothesis import given, HealthCheck
13+
from lightning.fabric.plugins import CheckpointIO
14+
from lightning.pytorch.plugins import AsyncCheckpointIO
15+
16+
from s3torchconnector._s3client import MockS3Client
17+
from s3torchconnector.lightning import S3LightningCheckpoint
18+
from .._checkpoint_byteorder_patch import (
19+
byteorders,
20+
save_with_byteorder,
21+
load_with_byteorder,
22+
_patch_byteorder,
23+
)
24+
from .._hypothesis_python_primitives import python_primitives
25+
from s3torchconnectorclient import S3Exception
26+
27+
TEST_BUCKET = "test-bucket"
28+
TEST_KEY = "test-key"
29+
TEST_REGION = "us-east-1"
30+
31+
32+
@pytest.fixture()
33+
def client() -> MockS3Client:
34+
return MockS3Client(TEST_REGION, TEST_BUCKET)
35+
36+
37+
@pytest.fixture()
38+
def lightning_checkpoint(client) -> S3LightningCheckpoint:
39+
s3_lightning_checkpoint = S3LightningCheckpoint(TEST_REGION)
40+
s3_lightning_checkpoint._client = client
41+
return s3_lightning_checkpoint
42+
43+
44+
@pytest.fixture()
45+
def async_lightning_checkpoint(lightning_checkpoint) -> Callable[[], AsyncCheckpointIO]:
46+
return lambda: AsyncCheckpointIO(lightning_checkpoint)
47+
48+
49+
@hypothesis.settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
50+
@given(python_primitives, byteorders)
51+
def test_lightning_checkpointing_saves_python_primitives(
52+
client, lightning_checkpoint, data, byteorder
53+
):
54+
_test_save(client, lightning_checkpoint, data, byteorder)
55+
56+
57+
@hypothesis.settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
58+
@given(byteorders)
59+
def test_lightning_checkpointing_saves_tensor(client, lightning_checkpoint, byteorder):
60+
tensor = torch.rand(2, 4)
61+
_test_save(client, lightning_checkpoint, tensor, byteorder, equal=torch.equal)
62+
63+
64+
@hypothesis.settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
65+
@given(byteorders)
66+
def test_async_lightning_checkpointing_saves_tensor(
67+
client, async_lightning_checkpoint, byteorder
68+
):
69+
tensor = torch.rand(2, 4)
70+
_test_save(
71+
client, async_lightning_checkpoint(), tensor, byteorder, equal=torch.equal
72+
)
73+
74+
75+
@hypothesis.settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
76+
@given(byteorders)
77+
def test_lightning_checkpointing_saves_untyped_storage(
78+
client, lightning_checkpoint, byteorder
79+
):
80+
storage = torch.UntypedStorage([1, 2, 3])
81+
_test_save(
82+
client,
83+
lightning_checkpoint,
84+
storage,
85+
byteorder,
86+
equal=lambda a, b: list(a) == list(b),
87+
)
88+
89+
90+
@hypothesis.settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
91+
@given(python_primitives, byteorders)
92+
def test_lightning_checkpointing_loads_python_primitives(
93+
client, lightning_checkpoint, data, byteorder
94+
):
95+
_test_load(client, lightning_checkpoint, data, byteorder)
96+
97+
98+
@hypothesis.settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
99+
@given(byteorders)
100+
def test_lightning_checkpointing_loads_tensor(client, lightning_checkpoint, byteorder):
101+
tensor = torch.rand(2, 4)
102+
_test_load(client, lightning_checkpoint, tensor, byteorder, equal=torch.equal)
103+
104+
105+
@hypothesis.settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
106+
@given(byteorders)
107+
def test_lightning_checkpointing_loads_untyped_storage(
108+
client, lightning_checkpoint, byteorder
109+
):
110+
storage = torch.UntypedStorage([1, 2, 3])
111+
_test_load(
112+
client,
113+
lightning_checkpoint,
114+
storage,
115+
byteorder,
116+
equal=lambda a, b: list(a) == list(b),
117+
)
118+
119+
120+
def test_removes_checkpoint(client, lightning_checkpoint):
121+
lightning_checkpoint.remove_checkpoint(f"s3://{TEST_BUCKET}/{TEST_KEY}")
122+
123+
with pytest.raises(S3Exception) as error:
124+
client.get_object(TEST_BUCKET, TEST_KEY).read()
125+
assert str(error.value) == "Service error: The key does not exist"
126+
127+
128+
@pytest.mark.parametrize(
129+
"checkpoint_method_name, kwargs",
130+
[
131+
("save_checkpoint", {"path": Path("bucket", "key"), "checkpoint": None}),
132+
("load_checkpoint", {"path": Path("/", "bucket", "key")}),
133+
("remove_checkpoint", {"path": Path()}),
134+
("remove_checkpoint", {"path": ["not", "a", "string"]}),
135+
],
136+
)
137+
def test_invalid_path(lightning_checkpoint, checkpoint_method_name, kwargs):
138+
checkpoint_method = getattr(lightning_checkpoint, checkpoint_method_name)
139+
with pytest.raises(
140+
TypeError,
141+
match="is not a supported type for 'path'. Must be a string formatted as an S3 uri",
142+
):
143+
checkpoint_method(**kwargs)
144+
145+
146+
def test_teardown(lightning_checkpoint):
147+
lightning_checkpoint.teardown()
148+
# Assert no exception is thrown - implicit
149+
150+
151+
def _test_save(
152+
client,
153+
checkpoint: CheckpointIO,
154+
data,
155+
byteorder: str,
156+
*,
157+
equal: Callable[[Any, Any], bool] = eq,
158+
):
159+
with _patch_byteorder(byteorder):
160+
checkpoint.save_checkpoint(data, f"s3://{TEST_BUCKET}/{TEST_KEY}")
161+
# For async checkpointing, ensure that we finish writing the checkpoint before we un-patch the byteorder
162+
checkpoint.teardown()
163+
164+
serialised = BytesIO(b"".join(client.get_object(TEST_BUCKET, TEST_KEY)))
165+
assert equal(load_with_byteorder(serialised, byteorder), data)
166+
167+
168+
def _test_load(
169+
client,
170+
checkpoint: CheckpointIO,
171+
data,
172+
byteorder: str,
173+
*,
174+
equal: Callable[[Any, Any], bool] = eq,
175+
):
176+
# Put some data to mock bucket and use mock client
177+
serialised = BytesIO()
178+
save_with_byteorder(data, serialised, byteorder, use_modern_pytorch_format=True)
179+
serialised.seek(0)
180+
client.add_object(TEST_KEY, serialised.read())
181+
182+
with _patch_byteorder(byteorder):
183+
returned_data = checkpoint.load_checkpoint(f"s3://{TEST_BUCKET}/{TEST_KEY}")
184+
185+
assert equal(returned_data, data)

0 commit comments

Comments
 (0)