
Commit a59cd54

matthieu-d4r authored and IsaevIlya committed
feat(dcp): add support for S3FileSystem (#251)
Add `S3FileSystem`, `S3StorageWriter`, and `S3StorageReader` classes, along with related unit and (basic) e2e tests.
1 parent 7f00661 commit a59cd54
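For orientation, here is a minimal usage sketch of the new writer and reader with `torch.distributed.checkpoint`, mirroring the e2e test added in this commit. The region and S3 URI below are placeholders, not values taken from the change.

import torch
import torch.distributed.checkpoint as dcp

from s3torchconnector.dcp import S3StorageReader, S3StorageWriter

# Placeholder values; substitute your own AWS region and S3 URI.
REGION = "eu-west-1"
S3_URI = "s3://my-bucket/my-checkpoint/"

# Save a state dict to S3 ...
state_dict = {"random": torch.randn(10)}
dcp.save(state_dict, storage_writer=S3StorageWriter(REGION, S3_URI))

# ... then load it back into a pre-allocated state dict.
loaded = {"random": torch.zeros(10)}
dcp.load(loaded, storage_reader=S3StorageReader(REGION, S3_URI))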

File tree

10 files changed (+545, -4 lines)

.github/workflows/python-checks.yml

Lines changed: 11 additions & 1 deletion
@@ -62,6 +62,16 @@ jobs:
       - name: s3torchconnector lightning unit tests
         run: pytest s3torchconnector/tst/unit/lightning --hypothesis-profile ci --hypothesis-show-statistics -c ./

+      - name: Install DCP dependencies
+        run: |
+          python -m pip install './s3torchconnector[dcp-test]'
+      - name: Run s3torchconnector DCP unit tests
+        run: |
+          CI_REGION=${{ matrix.test-run.region }} \
+          CI_BUCKET=${{ matrix.test-run.bucket }} \
+          CI_STORAGE_CLASS=${{ matrix.test-run.storage-class }} \
+          pytest s3torchconnector/tst/unit/dcp
+
   lint:
     name: Python lints
     runs-on: ubuntu-22.04
@@ -89,7 +99,7 @@ jobs:
           python -m pip install flake8 black mypy
           python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
           python -m pip install ./s3torchconnectorclient
-          python -m pip install ./s3torchconnector[lightning]
+          python -m pip install ./s3torchconnector[lightning,dcp]

       - name: Lint with flake8
         run: |

.github/workflows/python-integration.yml

Lines changed: 14 additions & 2 deletions
@@ -24,8 +24,8 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        runner: [ubuntu-22.04, macos-13]
-        python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
+        runner: [ ubuntu-22.04, macos-13 ]
+        python-version: [ "3.8", "3.9", "3.10", "3.11", "3.12" ]
         test-run:
           - name: "S3"
             bucket: ${{ vars.S3_BUCKET }}
@@ -100,6 +100,18 @@ jobs:
           CI_STORAGE_CLASS=${{ matrix.test-run.storage-class }} \
           pytest s3torchconnector/tst/e2e/test_e2e_s3_lightning_checkpoint.py -n auto

+      - name: Install DCP dependencies
+        if: matrix.runner != 'macos-13'
+        run: |
+          python -m pip install './s3torchconnector[dcp-test]'
+      - name: Run s3torchconnector DCP e2e tests
+        if: matrix.runner != 'macos-13'
+        run: |
+          CI_REGION=${{ matrix.test-run.region }} \
+          CI_BUCKET=${{ matrix.test-run.bucket }} \
+          CI_STORAGE_CLASS=${{ matrix.test-run.storage-class }} \
+          pytest s3torchconnector/tst/e2e/dcp -n auto
+
       - name: s3torchconnectorclient ${{ matrix.test-run.name }} integration tests
         run: |
           CI_REGION=${{ matrix.test-run.region }} \

s3torchconnector/pyproject.toml

Lines changed: 11 additions & 1 deletion
@@ -54,9 +54,19 @@ lightning-tests = [
     "s3fs"
 ]

+dcp = [
+    "tenacity",
+    "torch >= 2.3, < 2.5",  # TODO: remove "< 2.5" restriction once https://github.com/pytorch/pytorch/issues/138333 is fixed
+]
+
+dcp-test = [
+    "s3torchconnector[dcp]",
+    "pytest",
+]
+
 [tool.setuptools.packages]
 # Pure Python packages/modules
 find = { where = ["src"] }

 [tool.setuptools]
-license-files = [ "LICENSE", "THIRD-PARTY-LICENSES", "NOTICE"]
+license-files = ["LICENSE", "THIRD-PARTY-LICENSES", "NOTICE"]
Lines changed: 10 additions & 0 deletions (new file; likely the `s3torchconnector.dcp` package `__init__.py`)
@@ -0,0 +1,10 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# // SPDX-License-Identifier: BSD

from .s3_file_system import S3FileSystem, S3StorageReader, S3StorageWriter

__all__ = [
    "S3FileSystem",
    "S3StorageReader",
    "S3StorageWriter",
]
Lines changed: 252 additions & 0 deletions (new file; the `s3_file_system` module)
@@ -0,0 +1,252 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# // SPDX-License-Identifier: BSD

import io
import logging
import os
from contextlib import contextmanager
from pathlib import Path
from typing import Generator, Union, Optional

from s3torchconnectorclient._mountpoint_s3_client import S3Exception
from tenacity import (
    retry,
    stop_after_attempt,
    retry_if_exception_type,
    before_sleep_log,
    after_log,
    wait_random_exponential,
)
from torch.distributed.checkpoint.filesystem import (
    FileSystemReader,
    FileSystemWriter,
    FileSystemBase,
)

from s3torchconnector._s3client import S3Client
from s3torchconnector._s3dataset_common import parse_s3_uri

logger = logging.getLogger(__name__)


class S3FileSystem(FileSystemBase):
    def __init__(self, region: str, s3_client: Optional[S3Client] = None) -> None:
        self._path: Union[str, os.PathLike] = ""
        self._client = s3_client if s3_client is not None else S3Client(region)

    @contextmanager
    def create_stream(
        self, path: Union[str, os.PathLike], mode: str
    ) -> Generator[io.IOBase, None, None]:
        """
        Create a stream for reading or writing to S3.

        Args:
            path (Union[str, os.PathLike]): The S3 path to read or write.
            mode (str): The mode for the stream. Supports 'rb' for read mode and 'wb' for write mode.

        Yields:
            io.BufferedIOBase: A stream for reading or writing to S3.

        Raises:
            ValueError: If the mode is not 'rb' or 'wb'.
        """
        path_str = _path_or_str_to_str(path)
        bucket, key = parse_s3_uri(path_str)

        if mode == "wb":  # write mode
            logger.debug("create_stream writable for %s", path_str)
            with self._client.put_object(bucket, key) as stream:
                yield stream
        elif mode == "rb":  # read mode
            logger.debug("create_stream readable for %s", path_str)
            with self._client.get_object(bucket, key) as stream:
                yield stream
        else:
            raise ValueError(
                f"Invalid {mode=} mode argument: create_stream only supports rb (read mode) & wb (write mode)"
            )

    def concat_path(self, path: Union[str, os.PathLike], suffix: str) -> str:
        """
        Concatenate a suffix to the given path.

        Args:
            path (Union[str, os.PathLike]): The base path.
            suffix (str): The suffix to concatenate.

        Returns:
            str: The concatenated path.
        """
        logger.debug("concat paths %s and %s", path, suffix)
        path_str = os.fspath(path)
        result = os.path.join(path_str, suffix)
        return result

    def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
        """
        Initialize the path for the filesystem.

        Args:
            path (Union[str, os.PathLike]): The path to initialize.

        Returns:
            Union[str, os.PathLike]: The initialized path.
        """
        logger.debug("init_path for %s", path)
        self._path = path
        return self._path

    def rename(
        self, old_path: Union[str, os.PathLike], new_path: Union[str, os.PathLike]
    ) -> None:
        """Rename an object in S3.

        This is emulated by copying it to a new path and deleting the old path. The deletion part is retried (see also
        :func:`S3FileSystem._delete_with_retry`).

        Args:
            old_path (Union[str, os.PathLike]): The current path of the object.
            new_path (Union[str, os.PathLike]): The new path for the object.

        Raises:
            ValueError: If the old and new paths point to different buckets.
            S3Exception: If there is an error with the S3 client.
        """
        logger.debug("rename %s to %s", old_path, new_path)

        old_path_str = _path_or_str_to_str(old_path)
        new_path_str = _path_or_str_to_str(new_path)

        old_bucket, old_key = parse_s3_uri(old_path_str)
        new_bucket, new_key = parse_s3_uri(new_path_str)

        if old_bucket != new_bucket:
            raise ValueError(
                "Source and destination buckets cannot be different (rename does not support cross-bucket operations)"
            )

        self._client.copy_object(
            src_bucket=old_bucket,
            src_key=old_key,
            dst_bucket=new_bucket,
            dst_key=new_key,
        )
        logger.debug("rename: copied %s to %s successfully", old_path_str, new_path_str)
        self._delete_with_retry(old_bucket, old_key)
        logger.debug("rename: deleted s3://%s/%s successfully", old_bucket, old_key)

    def mkdir(self, path: Union[str, os.PathLike]) -> None:
        """No-op method for creating directories in S3 (not needed)."""
        pass

    def exists(self, path: Union[str, os.PathLike]) -> bool:
        """Check whether an object exists in S3 (via a HeadObject call)."""
        logger.debug("exists %s", path)

        path_str = _path_or_str_to_str(path)
        bucket, key = parse_s3_uri(path_str)
        try:
            self._client.head_object(bucket, key)
        except S3Exception as e:
            if str(e) != "Service error: The object was not found":
                raise
            return False
        return True

    def rm_file(self, path: Union[str, os.PathLike]) -> None:
        """Delete an object from S3, logging (but not raising) S3 client errors."""
        logger.debug("remove %s", path)

        path_str = _path_or_str_to_str(path)
        bucket, key = parse_s3_uri(path_str)
        try:
            self._client.delete_object(bucket, key)
        except S3Exception:
            logger.exception("Failed to remove object from S3")

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        """Accept `Path` objects and parseable S3 URIs as checkpoint IDs."""
        logger.debug("validate_checkpoint_id for %s", checkpoint_id)

        if isinstance(checkpoint_id, Path):
            return True

        try:
            parse_s3_uri(_path_or_str_to_str(checkpoint_id))
        except ValueError:
            return False
        return True

    @retry(
        retry=retry_if_exception_type(S3Exception),
        stop=stop_after_attempt(3),
        wait=wait_random_exponential(multiplier=1, max=5),
        before_sleep=before_sleep_log(logger, logging.WARNING),
        after=after_log(logger, logging.ERROR),
        reraise=True,
    )
    def _delete_with_retry(self, bucket_name: str, old_key: str):
        """Wrapper around :func:`S3Client.delete_object` that retries the deletion.

        Retries a maximum of 3 times, only for `S3Exception`s, waiting between attempts. The caught exception is
        re-raised, and retries and the final error (if any) are logged."""
        self._client.delete_object(bucket_name, old_key)


class S3StorageWriter(FileSystemWriter):
    def __init__(
        self,
        region: str,
        path: Union[str, os.PathLike],
        single_file_per_rank: bool = True,
        thread_count: int = 1,
        per_thread_copy_ahead: int = 10_000_000,
        overwrite: bool = False,
    ) -> None:
        """
        Initialize an S3 writer for distributed checkpointing.

        Args:
            region (str): The AWS region for S3.
            path (Union[str, os.PathLike]): The S3 path to write checkpoints to.
            single_file_per_rank (bool, optional): Whether to write a single file per rank. Defaults to True.
            thread_count (int, optional): The number of threads to use for writing. Defaults to 1.
            per_thread_copy_ahead (int, optional): The number of bytes to copy ahead per thread. Defaults to 10_000_000.
            overwrite (bool, optional): Whether to overwrite existing checkpoints. Defaults to False.
        """
        super().__init__(
            path=path,
            single_file_per_rank=single_file_per_rank,
            sync_files=False,
            thread_count=thread_count,
            per_thread_copy_ahead=per_thread_copy_ahead,
            overwrite=overwrite,
        )
        self.fs = S3FileSystem(region)  # type: ignore
        self.path = self.fs.init_path(path)

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        return S3FileSystem.validate_checkpoint_id(checkpoint_id)


class S3StorageReader(FileSystemReader):
    def __init__(self, region: str, path: Union[str, os.PathLike]) -> None:
        """
        Initialize an S3 reader for distributed checkpointing.

        Args:
            region (str): The AWS region for S3.
            path (Union[str, os.PathLike]): The S3 path to read checkpoints from.
        """
        super().__init__(path)
        self.fs = S3FileSystem(region)  # type: ignore
        self.path = self.fs.init_path(path)
        self.sync_files = False

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        return S3FileSystem.validate_checkpoint_id(checkpoint_id)


def _path_or_str_to_str(path: Union[str, os.PathLike]) -> str:
    return path if isinstance(path, str) else str(path)
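As a rough illustration of the filesystem layer on its own (not part of the commit; the bucket, key, and region below are placeholders), `create_stream` can be used directly to write an object and read it back:

from s3torchconnector.dcp import S3FileSystem

fs = S3FileSystem(region="eu-west-1")  # placeholder region

# Write an object through a writable stream ...
with fs.create_stream("s3://my-bucket/checkpoints/example.bin", "wb") as stream:
    stream.write(b"hello from DCP")

# ... read it back through a readable stream ...
with fs.create_stream("s3://my-bucket/checkpoints/example.bin", "rb") as stream:
    data = stream.read()

# ... and confirm the object exists.
assert fs.exists("s3://my-bucket/checkpoints/example.bin")
assert data == b"hello from DCP"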
Lines changed: 2 additions & 0 deletions (new file)
@@ -0,0 +1,2 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# // SPDX-License-Identifier: BSD
Lines changed: 47 additions & 0 deletions (new file; DCP e2e tests)
@@ -0,0 +1,47 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# // SPDX-License-Identifier: BSD

import pytest
import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint import CheckpointException

from s3torchconnector.dcp import S3StorageWriter, S3StorageReader


def test_fsdp_filesystem_when_single_thread(checkpoint_directory):
    # TODO: implement me
    pass


def test_fsdp_filesystem_when_multiple_threads(checkpoint_directory):
    # TODO: implement me
    pass


# Inspired by https://github.com/pytorch/pytorch/blob/main/test/distributed/checkpoint/test_fsspec.py.
def test_overwrite(checkpoint_directory):
    t1, t2 = torch.randn(10), torch.randn(10)
    region = checkpoint_directory.region
    s3_uri = checkpoint_directory.s3_uri

    dcp.save(
        {"random": t1},
        storage_writer=S3StorageWriter(region, s3_uri, overwrite=False),
    )
    dcp.save(
        {"random": t2},
        storage_writer=S3StorageWriter(region, s3_uri, overwrite=True),
    )

    sd = {"random": torch.zeros(10)}
    dcp.load(sd, checkpoint_id=s3_uri, storage_reader=S3StorageReader(region, s3_uri))
    assert torch.allclose(sd["random"], t2)

    with pytest.raises(CheckpointException) as excinfo:
        dcp.save(
            {"random": t2},
            storage_writer=S3StorageWriter(region, s3_uri, overwrite=False),
        )

    assert "Checkpoint already exists" in str(excinfo.value)
Lines changed: 2 additions & 0 deletions (new file)
@@ -0,0 +1,2 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# // SPDX-License-Identifier: BSD
