Commit b6b9593

Feature/dcp-partitioning (#327)

Enhance the Distributed Checkpointing (DCP) feature by implementing prefix support, enabling the distribution of checkpoints across multiple prefixes based on worker rank.

1 parent a1702df commit b6b9593

File tree: 8 files changed, +684 −18 lines


README.md

Lines changed: 128 additions & 0 deletions
@@ -171,6 +171,134 @@ DCP.load(
model.load_state_dict(model_state_dict)
```

## S3 Prefix Strategies for Distributed Checkpointing

S3StorageWriter supports several prefix strategies for organizing checkpoints in S3 buckets.
These strategies are designed to prevent throttling (503 Slow Down errors) in high-throughput scenarios
by following the S3 key naming best practices outlined in
[Best practices design patterns: optimizing Amazon S3 performance](https://docs.aws.amazon.com/AmazonS3/latest/userguide/optimizing-performance.html).
When many distributed training processes write checkpoints simultaneously, the prefix strategies help distribute
the load across multiple S3 partitions.

### Available Strategies

#### 1. RoundRobinPrefixStrategy

Distributes checkpoints across the specified prefixes in round-robin fashion, ideal for balancing data across multiple storage locations.

```python
from s3torchconnector.dcp import RoundRobinPrefixStrategy, S3StorageWriter

model = torchvision.models.resnet18()

# Initialize with multiple prefixes and optional epoch tracking
strategy = RoundRobinPrefixStrategy(
    user_prefixes=["shard1", "shard2", "shard3"],
    epoch_num=5  # Optional: for checkpoint versioning
)

writer = S3StorageWriter(
    region=REGION,
    path="CHECKPOINT_URI",
    prefix_strategy=strategy
)

# Save checkpoint
DCP.save(
    state_dict=model.state_dict(),
    storage_writer=writer
)
```

Output Structure:
```
CHECKPOINT_URI
├── shard1/
│   └── epoch_5/
│       ├── __0_0.distcp
│       ├── __3_0.distcp
│       └── ...
├── shard2/
│   └── epoch_5/
│       ├── __1_0.distcp
│       ├── __4_0.distcp
│       └── ...
└── shard3/
    └── epoch_5/
        ├── __2_0.distcp
        ├── __5_0.distcp
        └── ...
```
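A checkpoint saved with a prefix strategy can be loaded back without one, as the stateful example later in this commit does; a short sketch reusing the names above:

```python
from s3torchconnector.dcp import S3StorageReader

# No prefix strategy is needed on the read path; the reader resolves
# the shards from the checkpoint metadata under the same URI.
reader = S3StorageReader(region=REGION, path="CHECKPOINT_URI")
model_state_dict = model.state_dict()
DCP.load(
    state_dict=model_state_dict,
    storage_reader=reader,
)
model.load_state_dict(model_state_dict)
```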
#### 2. BinaryPrefixStrategy

Generates binary (base-2) prefixes for optimal partitioning in distributed environments.

```python
from s3torchconnector.dcp import BinaryPrefixStrategy

strategy = BinaryPrefixStrategy(
    epoch_num=1,       # Optional: for checkpoint versioning
    min_prefix_len=10  # Optional: minimum prefix length
)
```

Output Structure:
```
s3://my-bucket/checkpoints/
├── 0000000000/
│   └── epoch_1/
│       └── __0_0.distcp
├── 1000000000/
│   └── epoch_1/
│       └── __1_0.distcp
├── 0100000000/
│   └── epoch_1/
│       └── __2_0.distcp
└── ...
```
#### 3. HexPrefixStrategy

Uses hexadecimal (base-16) prefixes for a balance of partitioning efficiency and readability.

```python
from s3torchconnector.dcp import HexPrefixStrategy

strategy = HexPrefixStrategy(
    epoch_num=1,      # Optional: for checkpoint versioning
    min_prefix_len=4  # Optional: minimum prefix length
)
```

Output Structure:
```
s3://my-bucket/checkpoints/
├── 0000/
│   └── epoch_1/
│       └── __0_0.distcp
├── 1000/
│   └── epoch_1/
│       └── __1_0.distcp
...
├── f000/
│   └── epoch_1/
│       └── __15_0.distcp
└── ...
```
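Judging from the output trees above, the numeric strategies appear to encode the writer's rank in the chosen base with the digit order reversed and zero-padded to `min_prefix_len`, so the leading character varies fastest across ranks — exactly the property the S3 performance guide rewards. A minimal sketch of that idea, not the library's actual implementation:

```python
def numeric_prefix(rank: int, base: int, min_len: int) -> str:
    """Illustrative only: encode `rank` in `base`, least-significant digit
    first, zero-padded so leading characters differ across ranks."""
    digits = "0123456789abcdef"
    out = []
    while rank:
        out.append(digits[rank % base])
        rank //= base
    return "".join(out).ljust(min_len, "0")

# Matches the trees above:
assert numeric_prefix(2, base=2, min_len=10) == "0100000000"
assert numeric_prefix(15, base=16, min_len=4) == "f000"
```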
### Creating Custom Strategies

You can implement custom prefix strategies by extending the S3PrefixStrategyBase class:

```python
from s3torchconnector.dcp import S3PrefixStrategyBase


class CustomPrefixStrategy(S3PrefixStrategyBase):
    def __init__(self, custom_param):
        super().__init__()
        self.custom_param = custom_param

    def generate_prefix(self, rank: int) -> str:
        return f"custom_{self.custom_param}/{rank}/"
```
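A custom strategy plugs into S3StorageWriter exactly like the built-in ones, via the `prefix_strategy` argument added in this commit; a short usage sketch (region and URI are placeholders):

```python
writer = S3StorageWriter(
    region="us-east-1",                  # placeholder region
    path="s3://my-bucket/checkpoints/",  # placeholder checkpoint URI
    prefix_strategy=CustomPrefixStrategy("experiment_a"),
)
```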
## Parallel/Distributed Training

Amazon S3 Connector for PyTorch provides support for parallel and distributed training with PyTorch,

examples/dcp/stateful_example.py

Lines changed: 38 additions & 3 deletions
@@ -17,7 +17,9 @@
 from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from s3torchconnector import S3ClientConfig
 from s3torchconnector.dcp import S3StorageWriter, S3StorageReader
+from s3torchconnector.dcp.s3_prefix_strategy import RoundRobinPrefixStrategy


 class Model(torch.nn.Module):
@@ -98,6 +100,10 @@ def _setup(rank, world_size):
     torch.cuda.set_device(rank)


+def _cleanup():
+    dist.destroy_process_group()
+
+
 def _train_initial_model(device, rank, world_size):
     print(f"Train initial model on rank:{rank}")
     model, optim = _init_model(device, world_size)
@@ -126,8 +132,34 @@ def run(rank, world_size, region, s3_uri, device="cuda"):
     model, optim = _train_initial_model(device, rank, world_size)

     print(f"Saving checkpoint on rank:{rank}")
-    # initialize S3StorageWriter with region and bucket name, before passing to dcp.save as writer
-    storage_writer = S3StorageWriter(region, s3_uri)
+    # S3ClientConfig configuration for optimized data transfer to S3
+    s3config = S3ClientConfig(
+        # Set the size of each part in a multipart upload to 16 MB
+        # (16 * 1024 * 1024 bytes), a reasonable default for large file transfers.
+        part_size=16 * 1024 * 1024,
+        # Target a throughput of 600 Gbps for data transfer, suitable for
+        # high-bandwidth environments (P5/trn1 instances) and large model transfers.
+        throughput_target_gbps=600,
+        # Maximum number of retry attempts for failed operations; helps handle
+        # transient network issues or S3 throttling.
+        max_attempts=20,
+    )
+
+    # RoundRobinPrefixStrategy distributes checkpoint data across multiple
+    # prefixes in round-robin fashion.
+    strategy = RoundRobinPrefixStrategy(
+        # Prefix strings used in rotation for storing checkpoint shards. Each
+        # prefix is a separate "path" in S3 where checkpoint data will be
+        # stored; using multiple prefixes lowers the TPS per prefix.
+        user_prefixes=["0000000000", "1000000000", "0100000000", "1100000000"],
+        # Optional integer for versioning checkpoints across training epochs;
+        # if provided, the epoch number is appended to the prefix paths.
+        epoch_num=5,
+    )
+    # initialize S3StorageWriter with region, bucket name, and s3config, before
+    # passing it to dcp.save as the writer
+    storage_writer = S3StorageWriter(
+        region=region, path=s3_uri, s3client_config=s3config, prefix_strategy=strategy
+    )
     dcp.save(
         state_dict={"model": model, "optimizer": optim},
         storage_writer=storage_writer,
@@ -139,13 +171,16 @@ def run(rank, world_size, region, s3_uri, device="cuda"):
     )
     print(f"Load previously saved checkpoint on rank:{rank}")
     # initialize S3StorageReader with region and bucket name, before passing to dcp.load as reader
-    storage_reader = S3StorageReader(region, s3_uri)
+    storage_reader = S3StorageReader(
+        region=region, path=s3_uri, s3client_config=s3config
+    )
     dcp.load(
         state_dict={"model": modified_model, "optimizer": modified_optim},
         storage_reader=storage_reader,
     )
     _continue_training_loaded_model(modified_model, modified_optim, model, rank)
     print(f"Quitting on rank:{rank}")
+    _cleanup()


 if __name__ == "__main__":
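For reference, distributed examples like this one are typically launched with `torch.multiprocessing.spawn`; a minimal, hypothetical launcher (the region and URI are placeholders, and the example's actual `__main__` block may differ):

```python
import torch
import torch.multiprocessing as mp

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    # spawn calls run(rank, *args) once per process, so rank is supplied implicitly
    mp.spawn(
        run,
        args=(world_size, "us-east-1", "s3://my-bucket/checkpoints/"),
        nprocs=world_size,
        join=True,
    )
```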

s3torchconnector/src/s3torchconnector/dcp/__init__.py

Lines changed: 12 additions & 0 deletions
@@ -2,9 +2,21 @@
 # // SPDX-License-Identifier: BSD

 from .s3_file_system import S3FileSystem, S3StorageReader, S3StorageWriter
+from .s3_prefix_strategy import (
+    S3PrefixStrategyBase,
+    DefaultPrefixStrategy,
+    NumericPrefixStrategy,
+    BinaryPrefixStrategy,
+    HexPrefixStrategy,
+)

 __all__ = [
     "S3FileSystem",
     "S3StorageReader",
     "S3StorageWriter",
+    "S3PrefixStrategyBase",
+    "DefaultPrefixStrategy",
+    "NumericPrefixStrategy",
+    "BinaryPrefixStrategy",
+    "HexPrefixStrategy",
 ]

s3torchconnector/src/s3torchconnector/dcp/s3_file_system.py

Lines changed: 37 additions & 3 deletions
@@ -8,6 +8,7 @@
 from contextlib import contextmanager
 from pathlib import Path
 from typing import Generator, Union, Optional
+from typing import List

 from s3torchconnectorclient._mountpoint_s3_client import S3Exception
 from tenacity import (
@@ -28,6 +29,7 @@
 from s3torchconnector._s3client import S3Client
 from s3torchconnector._s3dataset_common import parse_s3_uri
 from .. import S3ClientConfig
+from .s3_prefix_strategy import S3PrefixStrategyBase, DefaultPrefixStrategy
 from .._user_agent import UserAgent

 logger = logging.getLogger(__name__)
@@ -43,11 +45,11 @@ def __init__(
         self._path: Union[str, os.PathLike] = ""
         user_agent = UserAgent(["dcp", torch.__version__])
         self._client = (
-            s3_client
-            if s3_client is not None
-            else S3Client(
+            S3Client(
                 region=region, user_agent=user_agent, s3client_config=s3client_config
             )
+            if s3_client is None
+            else s3_client
         )

     @contextmanager
@@ -227,12 +229,25 @@ def _escape_path(string):
     return "/".join(parts)


+from torch.distributed.checkpoint.planner import SavePlan
+import dataclasses
+from dataclasses import dataclass
+
+
+@dataclass
+class StorageMetadata:
+    """Metadata for S3 storage prefix."""
+
+    prefix: str
+
+
 class S3StorageWriter(FileSystemWriter):
     def __init__(
         self,
         region: str,
         path: str,
         s3client_config: Optional[S3ClientConfig] = None,
+        prefix_strategy: Optional[S3PrefixStrategyBase] = None,
         **kwargs,
     ) -> None:
         """
@@ -241,6 +256,7 @@ def __init__(
         Args:
             region (str): The AWS region for S3.
             path (str): The S3 URI to write checkpoints to.
+            prefix_strategy: Strategy for generating S3 prefixes.
             kwargs (dict): Keyword arguments to pass to the parent :class:`FileSystemWriter`.
         """
         super().__init__(
@@ -250,6 +266,24 @@ def __init__(
         )
         self.fs = S3FileSystem(region, s3client_config=s3client_config)  # type: ignore
         self.path = self.fs.init_path(path)
+        self.prefix_strategy = prefix_strategy or DefaultPrefixStrategy()
+
+    def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]:
+        """
+        Prepare save plans with S3-specific storage metadata.
+
+        Args:
+            plans: List of save plans to be processed.
+
+        Returns:
+            Modified save plans with S3 storage metadata.
+        """
+        return [
+            dataclasses.replace(
+                plan, storage_data=StorageMetadata(self.prefix_strategy(idx))
+            )
+            for idx, plan in enumerate(plans)
+        ]

     @classmethod
     def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
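The call `self.prefix_strategy(idx)` in `prepare_global_plan` suggests the strategy base class is callable and delegates to the `generate_prefix(rank)` method seen in the README's custom-strategy example. A plausible sketch of that contract — an assumption about the interface, not the shipped implementation:

```python
from abc import ABC, abstractmethod


class S3PrefixStrategyBase(ABC):
    """Sketch of the inferred contract: callable, mapping a rank to a prefix."""

    def __call__(self, rank: int) -> str:
        # prepare_global_plan() invokes the strategy with the plan index,
        # i.e. the writer's rank in the coordinated save.
        return self.generate_prefix(rank)

    @abstractmethod
    def generate_prefix(self, rank: int) -> str:
        """Return the S3 prefix under which this rank's shard is written."""
```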
