Commit c4f4826

feat(dcp): make DCPOptimizedS3Reader the default for S3StorageReader (#419)
This commit makes `S3ReaderConstructor.dcp_optimized()` the default reader constructor for DCP loading with `S3StorageReader`. It also adds default-constructor and zstandard tests, introduces a new troubleshooting doc for `DCPOptimizedS3Reader`, and references that doc from error messages, the README, and the CHANGELOG.

- Change the `S3StorageReader` default from `S3ReaderConstructor.default()` to `S3ReaderConstructor.dcp_optimized()`
- Adjust docs and error messages
- Add a new TROUBLESHOOTING.md doc with DCPOptimizedS3Reader troubleshooting notes
- Add FALLBACK_GUIDANCE to error messages, linking to the TROUBLESHOOTING.md doc with fallback instructions
- Update the README to point to TROUBLESHOOTING.md in the DCP section and simplify the DCP examples
- Update the CHANGELOG with soft breaking change documentation, also pointing to TROUBLESHOOTING.md
- Adjust tests
  - Add a unit test verifying the default constructor
  - Add zstandard tests (see the PR description's additional context for explanation)
    - Add zstandard to the dcp-test dependencies
    - Add an e2e test for ZStandard compression with all reader types
1 parent 4010356 commit c4f4826

File tree

9 files changed: +186 −27 lines changed

CHANGELOG.md

Lines changed: 3 additions & 3 deletions
```diff
@@ -1,7 +1,7 @@
-## v1.5.0 (February 17, 2026)
+## v1.5.0 (February 20, 2026)

 ### New features
-* Add DCPOptimizedS3Reader for faster and partial DCP loading (#378)
+* Add DCPOptimizedS3Reader as new default for faster and partial DCP loading (#378, #419)
 * Add support for Python 3.14 (#408)
 * Add weights_only parameter support for Lightning 2.6.0 compatibility (#388)
@@ -19,7 +19,7 @@
 * Add macOS x86_64 and Python 3.8 deprecation warnings (#400)

 ### Breaking changes
-* No breaking changes.
+* No breaking changes, but DCPOptimizedS3Reader as the new default reader for `S3StorageReader` might lead to behavioral changes. See [DCPOptimizedS3Reader Errors](https://github.com/awslabs/s3-connector-for-pytorch/blob/main/docs/TROUBLESHOOTING.md#dcpoptimizeds3reader-errors) for more details.

 ## v1.4.3 (July 25, 2025)
```

README.md

Lines changed: 9 additions & 9 deletions
````diff
@@ -132,7 +132,8 @@ Amazon S3 Connector for PyTorch provides robust support for PyTorch distributed
 - `S3StorageReader`: Implementation of PyTorch's StorageReader interface.
   - Supports configurable reading strategies via the `reader_constructor` parameter (see [Reader Configurations](#reader-configurations)).
-  - `S3ReaderConstructor.dcp_optimized()` is recommended for faster loading with partial checkpoint optimizations.
+  - Uses `DCPOptimizedS3Reader` by default for faster loading and partial checkpoint optimizations.
+  - Please refer to [DCPOptimizedS3Reader Errors](https://github.com/awslabs/s3-connector-for-pytorch/blob/main/docs/TROUBLESHOOTING.md#dcpoptimizeds3reader-errors) for troubleshooting.
 - `S3FileSystem`: An implementation of PyTorch's FileSystemBase.

 These tools enable seamless integration of Amazon S3 with
@@ -155,7 +156,6 @@ can be found in the [examples/dcp](https://github.com/awslabs/s3-connector-for-p
 ```py
 from s3torchconnector.dcp import S3StorageWriter, S3StorageReader
-from s3torchconnector import S3ReaderConstructor

 import torchvision
 import torch.distributed.checkpoint as DCP
@@ -178,14 +178,12 @@ DCP.save(
 )

 # Load distributed checkpoint from S3
+# S3StorageReader uses DCPOptimizedS3Reader by default for improved performance
 model = torchvision.models.resnet18()
 model_state_dict = model.state_dict()
-# Use DCP-optimized reader for faster loading
-reader_constructor = S3ReaderConstructor.dcp_optimized()
 s3_storage_reader = S3StorageReader(
     region=REGION,
     path=CHECKPOINT_URI,
-    reader_constructor=reader_constructor, # optional; constructor for S3Reader types
 )
 DCP.load(
     state_dict=model_state_dict,
@@ -424,8 +422,9 @@ Amazon S3 Connector for PyTorch supports three types of readers, configurable th
 ### Reader Types

-#### 1. Sequential Reader (Default)
+#### 1. Sequential Reader

+- Default for non-DCP use cases.
 - Downloads and buffers the entire S3 object in memory.
 - Prioritizes performance over memory usage by buffering entire objects.
@@ -437,9 +436,9 @@ Amazon S3 Connector for PyTorch supports three types of readers, configurable th
 - **Small reads** (< `buffer_size`): Use internal buffer to reduce S3 API calls.
 - **Large reads** (≥ `buffer_size`): Bypass buffer for direct transfer.

-#### 3. DCP-Optimized Reader (DCP only)
+#### 3. DCP-Optimized Reader

-- Specialized usage for PyTorch Distributed Checkpoint (DCP) loading.
+- Default for PyTorch Distributed Checkpoint (DCP) loading with `S3StorageReader`.
 - Provides performance improvements through per-item buffers and zero-copy buffer management.
 - Enables efficient partial checkpoint loading (e.g. model-only) through selective data fetching with range coalescing.
 - Automatically handles range metadata injection from DCP load plan.
@@ -449,7 +448,7 @@ Amazon S3 Connector for PyTorch supports three types of readers, configurable th
 - **Sequential Reader**: For processing entire objects, and when repeated access to the data is required. Best for most general use cases.
 - **Range-based Reader**: For larger objects (100MB+) that require sparse partial reads, and in memory-constrained environments.
-- **DCP-Optimized Reader**: For typical PyTorch Distributed Checkpoint loading scenarios for highest performance and memory-efficiency.
+- **DCP-Optimized Reader**: For typical PyTorch Distributed Checkpoint loading scenarios for highest performance and memory-efficiency. (Default for `S3StorageReader`)

 **Note**: S3Reader instances are not thread-safe and should not be shared across threads. For multiprocessing with DataLoader, each worker process creates its own S3Reader instance automatically.
@@ -484,6 +483,7 @@ DCP interface - `S3StorageReader` usage with dcp-optimized reader:
 from s3torchconnector.dcp import S3StorageReader
 from s3torchconnector import S3ReaderConstructor

+# dcp_optimized is already the default for S3StorageReader; demonstration purposes only.
 reader_constructor = S3ReaderConstructor.dcp_optimized()
 s3_storage_reader = S3StorageReader(
     region=REGION,
````
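As an aside on the reader types described in the README hunk above, the range-based reader's small-read/large-read buffering rule can be sketched in plain Python. Everything below (`FakeStore`, `RangeBasedSketch`) is an illustrative stand-in, not the connector's actual classes:

```python
class FakeStore:
    """Stands in for S3: counts how many byte-range requests are issued."""

    def __init__(self, data: bytes):
        self.data = data
        self.requests = 0

    def get_range(self, start: int, end: int) -> bytes:
        self.requests += 1
        return self.data[start:end]


class RangeBasedSketch:
    """Small reads go through a buffer; large reads bypass it entirely."""

    def __init__(self, store: FakeStore, buffer_size: int):
        self.store = store
        self.buffer_size = buffer_size
        self.buf_start = self.buf_end = 0
        self.buf = b""

    def read(self, start: int, size: int) -> bytes:
        end = start + size
        if size >= self.buffer_size:
            # Large read: direct transfer, no buffering.
            return self.store.get_range(start, end)
        if not (self.buf_start <= start and end <= self.buf_end):
            # Small read missing the buffer: refill one buffer-sized window.
            self.buf_start, self.buf_end = start, start + self.buffer_size
            self.buf = self.store.get_range(self.buf_start, self.buf_end)
        return self.buf[start - self.buf_start : end - self.buf_start]


store = FakeStore(bytes(range(256)))
reader = RangeBasedSketch(store, buffer_size=64)
reader.read(0, 8)    # small read: fills the buffer (one request)
reader.read(8, 8)    # small read served from the buffer (no new request)
reader.read(0, 128)  # large read: bypasses the buffer (one more request)
assert store.requests == 2
```

This is why the README recommends the range-based reader for sparse partial reads: adjacent small reads collapse into one S3 request per buffer window.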

docs/TROUBLESHOOTING.md

Lines changed: 24 additions & 0 deletions
````diff
@@ -0,0 +1,24 @@
+# Troubleshooting
+
+If `s3torchconnector` is not working as expected, please check [Github issues](https://github.com/awslabs/s3-connector-for-pytorch/issues) to see if your issue has already been addressed. If not, feel free to [create a GitHub issue](https://github.com/awslabs/s3-connector-for-pytorch/issues/new/choose) with all the details.
+
+For debug logging for mountpoint-s3-client and CRT logs, please refer to [Enabling Debug Logging](https://github.com/awslabs/s3-connector-for-pytorch/blob/main/DEVELOPMENT.md#enabling-debug-logging) section in the DEVELOPMENT doc.
+
+### DCPOptimizedS3Reader Errors
+
+`S3StorageReader` uses `DCPOptimizedS3Reader` (created with `S3ReaderConstructor.dcp_optimized()`) by default (v1.5.0+) for improved performance. See [PR #378](https://github.com/awslabs/s3-connector-for-pytorch/pull/378) for more details about the reader.
+
+If you encounter errors with the default reader, please [submit a GitHub issue](https://github.com/awslabs/s3-connector-for-pytorch/issues) describing your use case. We'd like to understand your scenario and potentially extend `DCPOptimizedS3Reader` to support it, so you can benefit from the performance improvements.
+
+For unsupported or non-DCP access patterns, use the generic reader:
+
+```py
+from s3torchconnector import S3ReaderConstructor
+from s3torchconnector.dcp import S3StorageReader
+
+storage_reader = S3StorageReader(
+    region=REGION,
+    path=CHECKPOINT_URI,
+    reader_constructor=S3ReaderConstructor.default()
+)
+```
````

s3torchconnector/pyproject.toml

Lines changed: 1 addition & 0 deletions
```diff
@@ -68,6 +68,7 @@ dcp = [
 dcp-test = [
     "s3torchconnector[dcp]",
     "pytest",
+    "zstandard",
 ]

 [tool.setuptools.packages]
```

s3torchconnector/src/s3torchconnector/dcp/s3_file_system.py

Lines changed: 15 additions & 4 deletions
```diff
@@ -325,7 +325,15 @@ def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:

 class S3StorageReader(FileSystemReader):
-    """S3 implementation of PyTorch's FileSystemReader with configurable reader strategies."""
+    """S3 implementation of PyTorch's FileSystemReader with configurable reader strategies.
+
+    By default, uses DCPOptimizedS3Reader for improved checkpoint loading performance.
+    For unsupported or non-DCP access patterns, please use the generic reader:
+        storage_reader = S3StorageReader(
+            region, path,
+            reader_constructor=S3ReaderConstructor.default()
+        )
+    """

     def __init__(
         self,
@@ -343,11 +351,14 @@ def __init__(
             region (str): The AWS region for S3.
             path (Union[str, os.PathLike]): The S3 path to read checkpoints from.
             s3client_config (Optional[S3ClientConfig]): Optional S3ClientConfig with parameters for S3 client.
-            reader_constructor (Optional[S3ReaderConstructorProtocol]): Optional partial(S3Reader) created using S3ReaderConstructor
-                e.g. S3ReaderConstructor.sequential() or S3ReaderConstructor.range_based()
+            reader_constructor (Optional[S3ReaderConstructorProtocol]): Reader constructor created using
+                S3ReaderConstructor. Defaults to ``S3ReaderConstructor.dcp_optimized()`` for best performance.
+                Use ``S3ReaderConstructor.sequential()`` for unsupported/non-DCP access patterns.
         """
         super().__init__(path)
-        self._reader_constructor = reader_constructor or S3ReaderConstructor.default()
+        self._reader_constructor = (
+            reader_constructor or S3ReaderConstructor.dcp_optimized()
+        )
         self.fs: S3FileSystem = S3FileSystem(  # type: ignore[assignment] # since we overrode self.fs: FileSystem
             region,
             s3client_config=s3client_config,
```
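The `reader_constructor or S3ReaderConstructor.dcp_optimized()` line in this hunk is a standard default-selection pattern: a falsy (None) argument falls back to a freshly built default constructor. A self-contained sketch with stand-in classes (none of these names are the real s3torchconnector types):

```python
from functools import partial


class SequentialReader:
    """Stand-in for the generic sequential reader."""

    def __init__(self, bucket, key):
        self.bucket, self.key = bucket, key


class DCPOptimizedReader:
    """Stand-in for the DCP-optimized reader."""

    def __init__(self, bucket, key):
        self.bucket, self.key = bucket, key


def default():
    return partial(SequentialReader)


def dcp_optimized():
    return partial(DCPOptimizedReader)


class StorageReaderSketch:
    def __init__(self, reader_constructor=None):
        # v1.5.0 behavior: with no explicit constructor, fall back to
        # the DCP-optimized one instead of the sequential default.
        self._reader_constructor = reader_constructor or dcp_optimized()

    def open(self, bucket, key):
        return self._reader_constructor(bucket, key)


assert isinstance(StorageReaderSketch().open("b", "k"), DCPOptimizedReader)
assert isinstance(StorageReaderSketch(default()).open("b", "k"), SequentialReader)
```

Passing `S3ReaderConstructor.default()` explicitly, as the new docstring advises, restores the pre-1.5.0 behavior in exactly this way.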

s3torchconnector/src/s3torchconnector/s3reader/constructor.py

Lines changed: 16 additions & 6 deletions
```diff
@@ -14,7 +14,12 @@
 )
 from .sequential import SequentialS3Reader
 from .ranged import RangedS3Reader
-from .dcp_optimized import DCPOptimizedS3Reader, ItemRange, DEFAULT_MAX_GAP_SIZE
+from .dcp_optimized import (
+    DCPOptimizedS3Reader,
+    ItemRange,
+    DEFAULT_MAX_GAP_SIZE,
+    FALLBACK_GUIDANCE,
+)

 if TYPE_CHECKING:
     from torch.distributed.checkpoint.planner import ReadItem
@@ -115,7 +120,8 @@ def __call__(self, bucket: str, key: str, get_object_info, get_stream) -> S3Read
         # Error for other files; warn users in case they override prepare_local_plan behavior
         raise ValueError(
-            f"No ranges found for {s3_uri}. Make sure range injection is used in S3StorageReader.prepare_local_plan."
+            f"No ranges found for {s3_uri}. Make sure range injection is used in "
+            f"'S3StorageReader.prepare_local_plan'.\n{FALLBACK_GUIDANCE}"
         )
@@ -135,7 +141,9 @@ class S3ReaderConstructor:
     @staticmethod
     def sequential() -> S3ReaderConstructorProtocol:
-        """Creates a constructor for sequential readers
+        """Creates a constructor for sequential (generic) readers.
+
+        This reader is the generic reader that supports all access patterns.

         Returns:
             S3ReaderConstructorProtocol: Partial constructor for SequentialS3Reader
@@ -158,8 +166,8 @@ def range_based(buffer_size: Optional[int] = None) -> S3ReaderConstructorProtoco
         Returns:
             S3ReaderConstructorProtocol: Partial constructor for RangedS3Reader

-        Range-based reader performs byte-range requests to read specific portions of S3 objects without
-        downloading the entire file.
+        Range-based reader performs byte-range requests for each read/readinto call
+        to read specific portions of S3 objects without downloading the entire file.

         Buffer size affects read performance:
@@ -233,7 +241,9 @@ def dcp_optimized(
     @staticmethod
     def default() -> S3ReaderConstructorProtocol:
-        """Creates default reader constructor (sequential)
+        """Creates the default generic reader constructor.
+
+        This creates a sequential (generic) reader that supports all access patterns.

         Returns:
             S3ReaderConstructorProtocol: Partial constructor for SequentialS3Reader
```

s3torchconnector/src/s3torchconnector/s3reader/dcp_optimized.py

Lines changed: 16 additions & 5 deletions
```diff
@@ -47,6 +47,13 @@
 FIND_ITEM_ERROR_PREFIX = (
     "DCPOptimizedS3Reader only supports sequentially accessing provided ranges: "
 )
+FALLBACK_GUIDANCE = (
+    "If this error is encountered with the default DCP reader (S3ReaderConstructor.dcp_optimized()) "
+    "added in s3torchconnector v1.5.0, please refer to the troubleshooting doc "
+    "(https://github.com/awslabs/s3-connector-for-pytorch/blob/main/docs/TROUBLESHOOTING.md#dcpoptimizeds3reader-errors)."
+    "\nFor unsupported or non-DCP access patterns, use the generic reader: "
+    "S3StorageReader(region, path, reader_constructor=S3ReaderConstructor.default())"
+)


 @dataclass
@@ -399,7 +406,7 @@ def _find_item_for_range(self, start: int, end: int) -> ItemRange:
         if start < item.end or self._current_item_buffer is None:
             raise ValueError(
                 f"{FIND_ITEM_ERROR_PREFIX}Range {start}-{end} not contained in "
-                f"current item {item.start}-{item.end}"
+                f"current item {item.start}-{item.end}.\n{FALLBACK_GUIDANCE}"
             )

         # Advance to next item
@@ -409,7 +416,7 @@ def _find_item_for_range(self, start: int, end: int) -> ItemRange:
         except StopIteration:
             raise ValueError(
                 f"{FIND_ITEM_ERROR_PREFIX}Range {start}-{end} not contained in last item "
-                f"with range {prev_item.start}-{prev_item.end}"
+                f"with range {prev_item.start}-{prev_item.end}.\n{FALLBACK_GUIDANCE}"
             )

         # Check if requested range is within next item
@@ -419,7 +426,7 @@ def _find_item_for_range(self, start: int, end: int) -> ItemRange:
             raise ValueError(
                 f"{FIND_ITEM_ERROR_PREFIX}Range {start}-{end} not contained in "
                 f"current item {prev_item.start}-{prev_item.end} nor the "
-                f"next item {item.start}-{item.end}."
+                f"next item {item.start}-{item.end}.\n{FALLBACK_GUIDANCE}"
             )

     def _get_stream_for_item(self, item: ItemRange) -> GetObjectStream:
@@ -647,11 +654,15 @@ def read(self, size: Optional[int] = None) -> bytes:
             S3Exception: An error occurred accessing S3.
         """
         if size is None:
-            raise ValueError("Size cannot be None; full read not supported")
+            raise ValueError(
+                f"Size cannot be None; full read not supported.\n{FALLBACK_GUIDANCE}"
+            )
         if not isinstance(size, int):
             raise TypeError(f"argument should be integer or None, not {type(size)!r}")
         if size < 0:
-            raise ValueError("Size cannot be negative; full read not supported")
+            raise ValueError(
+                f"Size cannot be negative; full read not supported.\n{FALLBACK_GUIDANCE}"
+            )
         if size == 0:
             return b""
```
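The `DEFAULT_MAX_GAP_SIZE` constant imported alongside `FALLBACK_GUIDANCE` hints at the range-coalescing step the DCP-optimized reader relies on: nearby byte ranges are merged so fewer, larger GET requests are issued, at the cost of fetching at most `max_gap` wasted bytes per merge. A minimal illustrative version (not the library's internals):

```python
def coalesce_ranges(ranges, max_gap):
    """Merge sorted (start, end) byte ranges whose gap is at most max_gap.

    Fewer, larger GET requests generally beat many small ones on S3,
    as long as the wasted bytes in each gap stay bounded by max_gap.
    """
    merged = []
    for start, end in sorted(ranges):
        if merged and start - merged[-1][1] <= max_gap:
            # Close enough to the previous range: extend it instead of
            # issuing a separate request.
            merged[-1][1] = max(merged[-1][1], end)
        else:
            merged.append([start, end])
    return [tuple(r) for r in merged]


# Three tensor ranges with small gaps collapse into one request;
# the distant range stays separate.
assert coalesce_ranges([(0, 10), (12, 20), (22, 30), (1000, 1100)], max_gap=5) == [
    (0, 30),
    (1000, 1100),
]
```

Coalescing also explains the sequential-access errors above: once ranges are merged into forward-only streams, a read that jumps backward has no buffer left to serve it.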

s3torchconnector/tst/e2e/dcp/test_e2e_s3_storage_reader.py

Lines changed: 95 additions & 0 deletions
```diff
@@ -183,3 +183,98 @@ def track_get_object_stream(self, bucket, key, start=None, end=None):
     print(
         f"{filter_name} load, {coalesce}: {len(stream_calls)} streams, {len(filtered_keys)} tensors"
     )
+
+
+@pytest.mark.parametrize("model", [SIMPLE_MODEL, LARGER_MODEL])
+@pytest.mark.parametrize(
+    "reader_constructor_name,reader_constructor",
+    [
+        ("sequential", S3ReaderConstructor.sequential()),
+        ("range_based", S3ReaderConstructor.range_based()),
+        ("dcp_optimized", S3ReaderConstructor.dcp_optimized()),
+    ],
+)
+def test_zstd_compression_partial_load(
+    checkpoint_directory, model, reader_constructor_name, reader_constructor
+):
+    """Test ZStandard compression with partial load works for all readers.
+
+    Tests compatibility with PyTorch DCP's transform_from() which decompresses
+    incoming stream data when _extensions=[ZStandard()] is used on S3StorageWriter,
+    especially testing that it retains sequential access pattern for dcp_optimized reader.
+    """
+
+    # TODO Python 3.8 uses PyTorch 2.4 and does not have ZStandard; remove conditional import/skip after deprecating Python 3.8.
+    try:
+        from torch.distributed.checkpoint._extension import ZStandard
+    except ImportError:
+        pytest.skip("ZStandard extension not available in this PyTorch version")
+
+    region = checkpoint_directory.region
+    s3_uri = checkpoint_directory.s3_uri
+
+    state_dict = model.state_dict()
+    all_keys = list(state_dict.keys())
+
+    # Save with ZStandard compression
+    writer = S3StorageWriter(
+        region=region,
+        path=s3_uri,
+        overwrite=True,
+        _extensions=[ZStandard()],
+    )
+    dcp.save(state_dict, storage_writer=writer)
+
+    # Partial load - only weight tensors
+    keys_to_load = [k for k in all_keys if k.endswith(".weight")]
+    assert keys_to_load, "No weight keys found in model"
+    loaded = {k: torch.empty_like(state_dict[k]) for k in keys_to_load}
+
+    # Track read positions for dcp_optimized
+    read_calls = []
+    original_read = DCPOptimizedS3Reader.read
+    original_readinto = DCPOptimizedS3Reader.readinto
+
+    def track_reads(self, size=None):
+        if not self.key.endswith(".metadata"):
+            read_calls.append(("read", self._position, size, self.key))
+            print(f"read: pos={self._position}, size={size}, key={self.key}")
+        return original_read(self, size)
+
+    def track_readinto(self, buf):
+        if not self.key.endswith(".metadata"):
+            read_calls.append(("readinto", self._position, len(buf), self.key))
+            print(f"readinto: pos={self._position}, size={len(buf)}, key={self.key}")
+        return original_readinto(self, buf)
+
+    # Load with position tracking (only affects dcp_optimized)
+    with (
+        patch.object(DCPOptimizedS3Reader, "read", track_reads),
+        patch.object(DCPOptimizedS3Reader, "readinto", track_readinto),
+    ):
+        reader = S3StorageReader(
+            region=region,
+            path=s3_uri,
+            reader_constructor=reader_constructor,
+        )
+        dcp.load(loaded, storage_reader=reader)
+
+    # Verify loaded tensors match
+    for key in keys_to_load:
+        assert torch.equal(loaded[key], state_dict[key]), f"Mismatch for {key}"
+
+    # Print summary and verify sequential access for dcp_optimized
+    # This helps to manually verify sequential access is still enforced even with
+    # zstandard transform on each tensor for dcp_optimized reader to work.
+    if reader_constructor_name == "dcp_optimized" and read_calls:
+        read_positions = [call[1] for call in read_calls]
+        assert read_positions == sorted(
+            read_positions
+        ), "Read positions should be in ascending order"
+
+        print(f"\n{reader_constructor_name}: {len(keys_to_load)} tensors loaded")
+        print(f"  Total calls: {len(read_calls)}")
+        print(f"  read: {sum(1 for c in read_calls if c[0] == 'read')}")
+        print(f"  readinto: {sum(1 for c in read_calls if c[0] == 'readinto')}")
+    else:
+        print(f"{reader_constructor_name}: {len(keys_to_load)} tensors loaded")
```
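The read-tracking trick in the test above works because `unittest.mock.patch.object` replaces the *unbound* method on the class, so the replacement function receives `self` like the original and can delegate to a saved reference. A standalone miniature of that pattern:

```python
from unittest.mock import patch


class Reader:
    def read(self, size):
        return b"x" * size


calls = []
original_read = Reader.read  # keep the unbound original so the tracker can delegate


def track_reads(self, size):
    calls.append(size)  # record the argument before delegating
    return original_read(self, size)


with patch.object(Reader, "read", track_reads):
    data = Reader().read(4)  # routed through the tracker inside the context

assert data == b"xxxx"
assert calls == [4]
assert Reader().read(2) == b"xx"  # original method restored on context exit
assert calls == [4]               # tracker no longer records
```

Saving `original_read` before patching is essential; calling `self.read` inside the tracker would recurse into the patched method instead.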

s3torchconnector/tst/unit/dcp/test_s3_storage_reader.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -10,6 +10,7 @@
 from s3torchconnector.dcp import S3StorageReader
 from s3torchconnector.s3reader import S3ReaderConstructor, ItemRange
+from s3torchconnector.s3reader.constructor import DCPOptimizedConstructor

 TEST_REGION = "eu-east-1"
 TEST_PATH = "s3://test-bucket/test-checkpoint/"
@@ -34,6 +35,12 @@ def load_plan_with_offsets(draw):
     return LoadPlan(items), storage_data


+def test_s3storage_reader_default_uses_dcp_optimized():
+    """Verify S3StorageReader without explicit constructor uses dcp_optimized."""
+    reader = S3StorageReader(region=TEST_REGION, path=TEST_PATH)
+    assert isinstance(reader._reader_constructor, DCPOptimizedConstructor)
+
+
 def test_s3storage_reader_prepare_local_plan_empty():
     """Test prepare_local_plan handles empty plans."""
     s3_storage_reader = S3StorageReader(TEST_REGION, TEST_PATH)
```
