
Commit ae458f2

Add support for distributed training to S3IterableDataset (#243)
Add support for multi-process/multi-node sharding to S3IterableDataset
1 parent 2123dfe commit ae458f2

File tree

8 files changed: +533 / -35 lines changed


CHANGELOG.md

Lines changed: 9 additions & 0 deletions
@@ -1,9 +1,18 @@
+ ## v1.x.x (TBD)
+ * Add support for distributed training to S3IterableDataset
+
+ ### Breaking changes
+ * No breaking changes.
+
  ## v1.2.7 (October 29, 2024)

  ### New features
  * Add support for CRT retries (awslabs/mountpoint-s3#1069).
  * Add support for `CopyObject` API (#242).

+ ### Breaking changes
+ * No breaking changes.
+
  ## v1.2.6 (October 9, 2024)

  ### New features

README.md

Lines changed: 30 additions & 1 deletion
@@ -28,7 +28,7 @@ Amazon S3, without first saving to local storage.
  pip install s3torchconnector
  ```

- Amazon S3 Connector for PyTorch supports only Linux via Pip for now. For other platforms,
+ Amazon S3 Connector for PyTorch supports pre-built wheels via Pip only for Linux and macOS for now. For other platforms,
  see [DEVELOPMENT](DEVELOPMENT.md) for build instructions.

  ### Configuration
@@ -114,7 +114,35 @@ For example, assuming the following directory bucket name `my-test-bucket--usw2-
  usw2-az1, then the URI used will look like: `s3://my-test-bucket--usw2-az1--x-s3/<PREFIX>` (**please note that the
  prefix for Amazon S3 Express One Zone should end with '/'**), paired with region us-west-2.

+ ## Parallel/Distributed Training

+ Amazon S3 Connector for PyTorch supports parallel and distributed training with PyTorch,
+ allowing you to leverage multiple processes and nodes for efficient data loading and training.
+ Both S3IterableDataset and S3MapDataset can be used for this purpose.
+
+ ### S3IterableDataset
+
+ The S3IterableDataset can be passed directly to PyTorch's DataLoader for parallel and distributed training.
+ By default, all worker processes share the same list of training objects. However,
+ if you need each worker to have access to a unique portion of the dataset for better parallelization,
+ you can enable dataset sharding with the `enable_sharding` parameter.
+ ```
+ dataset = S3IterableDataset.from_prefix(DATASET_URI, region=REGION, enable_sharding=True)
+ dataloader = DataLoader(dataset, num_workers=4)
+ ```
+ When `enable_sharding` is set to True, the dataset is automatically sharded across the available workers.
+ This sharding mechanism supports both parallel training on a single host and distributed training across multiple hosts.
+ Each worker, regardless of its host, will load and process a distinct subset of the dataset.
+
+ ### S3MapDataset
+
+ For the S3MapDataset, pass it to the DataLoader together with a [DistributedSampler](https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler) wrapped around it.
+ The DistributedSampler ensures that each worker or node receives a unique subset of the dataset,
+ enabling efficient parallel and distributed training.
+ ```
+ dataset = S3MapDataset.from_prefix(DATASET_URI, region=REGION)
+ sampler = DistributedSampler(dataset)
+ dataloader = DataLoader(dataset, sampler=sampler, num_workers=4)
+ ```
  ## Lightning Integration

  Amazon S3 Connector for PyTorch includes an integration for PyTorch Lightning, featuring S3LightningCheckpoint, an
@@ -183,3 +211,4 @@ See [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md) for more details.
  ## License

  Amazon S3 Connector for PyTorch has a BSD 3-Clause License, as found in the [LICENSE](LICENSE) file.
+
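
For additional context beyond the diff, here is a minimal sketch of how the S3MapDataset example above might be wired into a distributed run. It assumes the job is launched with torchrun (which sets RANK/WORLD_SIZE); DATASET_URI, REGION, the transform, and the training step are placeholders, not part of this commit.
```
# Hypothetical end-to-end sketch; assumes launch via: torchrun --nproc_per_node=2 train.py
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from s3torchconnector import S3MapDataset

DATASET_URI = "s3://my-bucket/my-prefix/"  # placeholder
REGION = "us-east-1"                       # placeholder

def key_and_bytes(s3reader):
    # Read each object eagerly so the default collate_fn only sees (str, bytes) pairs.
    return s3reader.key, s3reader.read()

def main():
    dist.init_process_group(backend="gloo")  # "nccl" is typical on GPU hosts

    dataset = S3MapDataset.from_prefix(DATASET_URI, region=REGION, transform=key_and_bytes)
    sampler = DistributedSampler(dataset)  # gives each rank a disjoint set of indices
    dataloader = DataLoader(dataset, sampler=sampler, num_workers=4)

    for epoch in range(3):
        sampler.set_epoch(epoch)  # reshuffle the per-rank shards every epoch
        for keys, payloads in dataloader:
            ...  # decode payloads and run the training step

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```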

s3torchconnector/src/s3torchconnector/s3iterable_dataset.py

Lines changed: 58 additions & 2 deletions
@@ -5,6 +5,7 @@
  import logging

  import torch.utils.data
+ import torch

  from . import S3Reader
  from ._s3bucket_key_data import S3BucketKeyData
@@ -32,13 +33,21 @@ def __init__(
      endpoint: Optional[str] = None,
      transform: Callable[[S3Reader], Any] = identity,
      s3client_config: Optional[S3ClientConfig] = None,
+     enable_sharding: bool = False,
  ):
      self._get_dataset_objects = get_dataset_objects
      self._transform = transform
      self._region = region
      self._endpoint = endpoint
      self._s3client_config = s3client_config
      self._client = None
+     self._enable_sharding = enable_sharding
+
+     self._rank = 0
+     self._world_size = 1
+     if torch.distributed.is_initialized():
+         self._rank = torch.distributed.get_rank()
+         self._world_size = torch.distributed.get_world_size()

  @property
  def region(self):
@@ -57,6 +66,7 @@ def from_objects(
      endpoint: Optional[str] = None,
      transform: Callable[[S3Reader], Any] = identity,
      s3client_config: Optional[S3ClientConfig] = None,
+     enable_sharding: bool = False,
  ):
      """Returns an instance of S3IterableDataset using the S3 URI(s) provided.
@@ -66,6 +76,7 @@ def from_objects(
      endpoint(str): AWS endpoint of the S3 bucket where the objects are stored.
      transform: Optional callable which is used to transform an S3Reader into the desired type.
      s3client_config: Optional S3ClientConfig with parameters for S3 client.
+     enable_sharding: If True, shard the dataset across multiple workers for parallel data loading. If False (default), each worker loads the entire dataset independently.

      Returns:
          S3IterableDataset: An IterableStyle dataset created from S3 objects.
@@ -80,6 +91,7 @@ def from_objects(
      endpoint,
      transform=transform,
      s3client_config=s3client_config,
+     enable_sharding=enable_sharding,
  )

  @classmethod
@@ -91,6 +103,7 @@ def from_prefix(
      endpoint: Optional[str] = None,
      transform: Callable[[S3Reader], Any] = identity,
      s3client_config: Optional[S3ClientConfig] = None,
+     enable_sharding: bool = False,
  ):
      """Returns an instance of S3IterableDataset using the S3 URI provided.
@@ -100,6 +113,7 @@ def from_prefix(
      endpoint(str): AWS endpoint of the S3 bucket where the objects are stored.
      transform: Optional callable which is used to transform an S3Reader into the desired type.
      s3client_config: Optional S3ClientConfig with parameters for S3 client.
+     enable_sharding: If True, shard the dataset across multiple workers for parallel data loading. If False (default), each worker loads the entire dataset independently.

      Returns:
          S3IterableDataset: An IterableStyle dataset created from S3 objects.
@@ -114,6 +128,7 @@ def from_prefix(
      endpoint,
      transform=transform,
      s3client_config=s3client_config,
+     enable_sharding=enable_sharding,
  )

  def _get_client(self):
@@ -133,6 +148,47 @@ def _get_transformed_object(self, bucket_key: S3BucketKeyData) -> Any:
      )

  def __iter__(self) -> Iterator[Any]:
-     return map(
-         self._get_transformed_object, self._get_dataset_objects(self._get_client())
+     worker_id = 0
+     num_workers = 1
+     if self._enable_sharding:
+         worker_info = torch.utils.data.get_worker_info()
+         if worker_info is not None:
+             worker_id = worker_info.id
+             num_workers = worker_info.num_workers
+
+     if not self._enable_sharding or (self._world_size == 1 and num_workers == 1):
+         # sharding disabled or only one shard is available, so return the entire dataset
+         return map(
+             self._get_transformed_object,
+             self._get_dataset_objects(self._get_client()),
+         )
+
+     """In a multi-process setting (e.g., distributed training), the dataset needs to be
+     sharded across multiple processes. The following variables control this sharding:
+
+     _rank: The rank (index) of the current process within the world (group of processes).
+     _world_size: The total number of processes in the world (group).
+
+     In addition, within each process, the dataset may be further sharded across multiple
+     worker threads or processes (e.g., for data loading). The following variables control
+     this intra-process sharding:
+
+     worker_id: The ID of the current worker thread/process within the process.
+     num_workers: The total number of worker threads/processes within the process.
+     """
+
+     # First, distribute objects across ranks
+     rank_sharded_objects = (
+         obj
+         for idx, obj in enumerate(self._get_dataset_objects(self._get_client()))
+         if idx % self._world_size == self._rank
      )
+
+     # Then, distribute objects within each rank across workers
+     worker_sharded_objects = (
+         obj
+         for idx, obj in enumerate(rank_sharded_objects)
+         if idx % num_workers == worker_id
+     )
+
+     return map(self._get_transformed_object, worker_sharded_objects)
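
To make the two-level sharding in `__iter__` concrete, here is a small standalone illustration (not part of the commit) of which objects each (rank, worker) pair receives under the same modulo scheme:
```
# Standalone illustration of the sharding arithmetic used in __iter__ above.
# With world_size=2 ranks and num_workers=2 DataLoader workers per rank,
# 12 objects split into 4 disjoint shards of 3 objects each.
objects = [f"obj-{i:02d}" for i in range(12)]

def shard(objects, rank, world_size, worker_id, num_workers):
    # First level: keep every world_size-th object, offset by this process's rank.
    rank_sharded = (o for i, o in enumerate(objects) if i % world_size == rank)
    # Second level: within the rank's slice, keep every num_workers-th object.
    return [o for i, o in enumerate(rank_sharded) if i % num_workers == worker_id]

for rank in range(2):
    for worker_id in range(2):
        print(rank, worker_id, shard(objects, rank, 2, worker_id, 2))
# (0, 0) -> ['obj-00', 'obj-04', 'obj-08']
# (0, 1) -> ['obj-02', 'obj-06', 'obj-10']
# (1, 0) -> ['obj-01', 'obj-05', 'obj-09']
# (1, 1) -> ['obj-03', 'obj-07', 'obj-11']
```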

s3torchconnector/tst/e2e/conftest.py

Lines changed: 57 additions & 14 deletions
@@ -18,41 +18,67 @@ def getenv(var: str, optional: bool = False) -> str:
      return v


- class BucketPrefixFixture(object):
+ class BucketPrefixData(object):
      """An S3 bucket/prefix and its contents for use in a single unit test. The prefix will be unique
      to this instance, so other concurrent tests won't affect its state."""

      region: str
      bucket: str
      prefix: str
      storage_class: str = None
+     contents: dict

      def __init__(
-         self, region: str, bucket: str, prefix: str, storage_class: str = None
+         self,
+         region: str,
+         bucket: str,
+         prefix: str,
+         storage_class: str = None,
+         contents: dict = None,
      ):
          self.bucket = bucket
          self.prefix = prefix
          self.region = region
          self.storage_class = storage_class
-         self.contents = {}
-         session = boto3.Session(region_name=region)
-         self.s3 = session.client("s3")
+         self.contents = contents or {}

      @property
      def s3_uri(self):
          return f"s3://{self.bucket}/{self.prefix}"

+     def __getitem__(self, index):
+         return self.contents[index]
+
+     def __iter__(self):
+         return iter(self.contents)
+
+
+ class BucketPrefixFixture(BucketPrefixData):
+     """An S3 bucket/prefix and its contents for use in a single unit test. The prefix will be unique
+     to this instance, so other concurrent tests won't affect its state."""
+
+     def __init__(
+         self, region: str, bucket: str, prefix: str, storage_class: str = None
+     ):
+         super().__init__(region, bucket, prefix, storage_class)
+         session = boto3.Session(region_name=region)
+         self.s3 = session.client("s3")
+
      def add(self, key: str, contents: bytes, **kwargs):
          """Upload an S3 object to this prefix of the bucket."""
          full_key = f"{self.prefix}{key}"
          self.s3.put_object(Bucket=self.bucket, Key=full_key, Body=contents, **kwargs)
          self.contents[full_key] = contents

-     def __getitem__(self, index):
-         return self.contents[index]
+     def get_data_snapshot(self):
+         """Returns a read-only copy of the current instance's data.

-     def __iter__(self):
-         return iter(self.contents)
+         The returned object cannot modify the actual S3 bucket.
+         Useful when passing data to another process without serializing the S3 client.
+         """
+         return BucketPrefixData(
+             self.region, self.bucket, self.prefix, self.storage_class, self.contents
+         )


  def get_test_bucket_prefix(name: str) -> BucketPrefixFixture:
@@ -71,13 +97,30 @@ def get_test_bucket_prefix(name: str) -> BucketPrefixFixture:

  @pytest.fixture
  def image_directory(request) -> BucketPrefixFixture:
-     """Create a bucket/prefix fixture that contains a directory of random JPG image files."""
      NUM_IMAGES = 10
      IMAGE_SIZE = 100
-     fixture = get_test_bucket_prefix(f"{request.node.name}/image_directory")
-     for i in range(NUM_IMAGES):
-         data = np.random.randint(0, 256, IMAGE_SIZE * IMAGE_SIZE * 3, np.uint8)
-         data = data.reshape(IMAGE_SIZE, IMAGE_SIZE, 3)
+     return _create_image_directory_fixture(NUM_IMAGES, IMAGE_SIZE, request.node.name)
+
+
+ @pytest.fixture
+ def image_directory_for_dp(request) -> BucketPrefixFixture:
+     """When conducting distributed training tests, be cautious about the number of files (images) in the test dataset.
+     If the total number of images cannot be evenly divided by the number of workers,
+     the DistributedSampler will duplicate a subset of the images across workers to ensure an equal
+     distribution of data among all processes. This duplication of images will cause
+     the distributed training integration test to fail.
+     """
+     NUM_IMAGES = 36
+     IMAGE_SIZE = 100
+     return _create_image_directory_fixture(NUM_IMAGES, IMAGE_SIZE, request.node.name)
+
+
+ def _create_image_directory_fixture(num_image: int, image_size: int, node_name: str):
+     """Create a bucket/prefix fixture that contains a directory of random JPG image files."""
+     fixture = get_test_bucket_prefix(f"{node_name}/image_directory")
+     for i in range(num_image):
+         data = np.random.randint(0, 256, image_size * image_size * 3, np.uint8)
+         data = data.reshape(image_size, image_size, 3)
          image = Image.fromarray(data, "RGB")
          image_bytes = io.BytesIO()
          image.save(image_bytes, "jpeg")
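
The reason get_data_snapshot exists is that the boto3 client held by BucketPrefixFixture cannot be pickled, so a test that spawns extra processes can pass the plain-data snapshot instead. A sketch of that pattern follows; the test and worker function below are hypothetical, not code from this commit.
```
# Hypothetical usage of get_data_snapshot; `image_directory_for_dp` is the fixture above.
import torch.multiprocessing as mp

def test_distributed_dataloading(image_directory_for_dp):
    # The snapshot carries region/bucket/prefix/contents but no boto3 client,
    # so it can be pickled and handed to spawned processes.
    snapshot = image_directory_for_dp.get_data_snapshot()
    mp.spawn(_dataloading_worker, args=(snapshot,), nprocs=2, join=True)

def _dataloading_worker(rank, snapshot):
    expected_keys = set(snapshot)  # BucketPrefixData.__iter__ yields the uploaded keys
    uri = snapshot.s3_uri          # prefix URI to build a dataset from
    ...                            # build a dataset for this rank and assert coverage
```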
Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # // SPDX-License-Identifier: BSD
+
+ import platform
+ import torch
+ from s3torchconnector import S3Reader
+
+ from typing import Tuple, List
+
+
+ def _get_fork_methods() -> List[str]:
+     """Get a list of valid start methods for PyTorch's multiprocessing.
+     On macOS, the 'fork' and 'forkserver' start methods are known to crash,
+     despite being reported as usable by PyTorch. This function filters out
+     those methods for macOS systems.
+
+     Returns:
+         List[str]: A list of valid start methods for the current platform.
+     """
+     methods = set(torch.multiprocessing.get_all_start_methods())
+
+     if platform.system() == "Darwin":
+         # fork and forkserver crash on macOS, even though they are reported as usable.
+         # https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
+         # https://bugs.python.org/issue?@action=redirect&bpo=33725
+         methods -= {"fork", "forkserver"}
+     return list(methods)
+
+
+ def _set_start_method(start_method: str):
+     torch.multiprocessing.set_start_method(start_method, force=True)
+
+
+ def _read_data(s3reader: S3Reader) -> Tuple[str, bytes]:
+     return s3reader.key, s3reader.read()
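
For context, a sketch of how these helpers might be exercised in an end-to-end test (an assumed shape, not shown in this commit): _get_fork_methods feeds a pytest parametrization, _set_start_method configures multiprocessing before the DataLoader starts its workers, and _read_data serves as a picklable transform.
```
# Hypothetical test sketch built on the helpers above; the image_directory
# fixture and the assertion are assumptions, not code from this commit.
import pytest
from torch.utils.data import DataLoader
from s3torchconnector import S3IterableDataset

@pytest.mark.parametrize("start_method", _get_fork_methods())
def test_sharded_iteration(start_method, image_directory):
    _set_start_method(start_method)  # force fork/spawn/forkserver per parameter
    dataset = S3IterableDataset.from_prefix(
        image_directory.s3_uri,
        region=image_directory.region,
        enable_sharding=True,
        transform=_read_data,  # (key, bytes) tuples are easy to collect
    )
    loader = DataLoader(dataset, num_workers=2, batch_size=None)
    keys = {key for key, _ in loader}
    # With sharding enabled, the two workers together should cover every object exactly once.
    assert keys == set(image_directory.contents)
```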
