Skip to content

Commit b82b0a6

Browse files
authored
Expose S3ClientConfig (#191)
* Expose S3ClientConfig We expose the following configuration flags with performance impact: throughput_target_gbps(float): Throughput target in Gigabits per second (Gbps). part_size(int): Size, in bytes, of parts that files will be downloaded or uploaded in.
1 parent 2da3601 commit b82b0a6

File tree

14 files changed

+195
-19
lines changed

14 files changed

+195
-19
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
## Unreleased
2+
3+
### New features
4+
* Expose a new class, S3ClientConfig, with `throughput_target_gbps` and `part_size` parameters of the inner S3 client.
5+
16
## v1.2.1 (March 14, 2024)
27

38
### Breaking changes

s3torchbenchmarking/src/s3torchbenchmarking/benchmark_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ def throughput(self):
7575

7676

7777
class ExperimentResultJsonEncoder(JSONEncoder):
78-
7978
def default(self, o: Any) -> Any:
8079
if isinstance(o, ExperimentResult):
8180
o: ExperimentResult = o

s3torchconnector/src/s3torchconnector/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from .s3map_dataset import S3MapDataset
1212
from .s3checkpoint import S3Checkpoint
1313
from ._version import __version__
14+
from ._s3client import S3ClientConfig
1415

1516
__all__ = [
1617
"S3IterableDataset",
@@ -19,5 +20,6 @@
1920
"S3Reader",
2021
"S3Writer",
2122
"S3Exception",
23+
"S3ClientConfig",
2224
"__version__",
2325
]
Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# // SPDX-License-Identifier: BSD
33

4+
from .s3client_config import S3ClientConfig
45
from ._s3client import S3Client
56
from ._mock_s3client import MockS3Client
67

7-
__all__ = ["S3Client", "MockS3Client"]
8+
__all__ = [
9+
"S3ClientConfig",
10+
"S3Client",
11+
"MockS3Client",
12+
]

s3torchconnector/src/s3torchconnector/_s3client/_mock_s3client.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from . import S3Client
1212
from .._user_agent import UserAgent
13+
from .s3client_config import S3ClientConfig
1314

1415
"""
1516
_mock_s3client.py
@@ -22,14 +23,19 @@ def __init__(
2223
self,
2324
region: str,
2425
bucket: str,
25-
part_size: int = 8 * 1024 * 1024,
2626
user_agent: Optional[UserAgent] = None,
27+
s3client_config: Optional[S3ClientConfig] = None,
2728
):
28-
super().__init__(region, user_agent=user_agent)
29+
super().__init__(
30+
region,
31+
user_agent=user_agent,
32+
s3client_config=s3client_config,
33+
)
2934
self._mock_client = MockMountpointS3Client(
3035
region,
3136
bucket,
32-
part_size=part_size,
37+
throughput_target_gbps=self.s3client_config.throughput_target_gbps,
38+
part_size=self.s3client_config.part_size,
3339
user_agent_prefix=self.user_agent_prefix,
3440
)
3541

s3torchconnector/src/s3torchconnector/_s3client/_s3client.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Optional, Any
88

99
from s3torchconnector import S3Reader, S3Writer
10+
from .s3client_config import S3ClientConfig
1011

1112
from s3torchconnectorclient._mountpoint_s3_client import (
1213
MountpointS3Client,
@@ -35,15 +36,18 @@ class S3Client:
3536
def __init__(
3637
self,
3738
region: str,
39+
*,
3840
endpoint: Optional[str] = None,
3941
user_agent: Optional[UserAgent] = None,
42+
s3client_config: Optional[S3ClientConfig] = None,
4043
):
4144
self._region = region
4245
self._endpoint = endpoint
4346
self._real_client: Optional[MountpointS3Client] = None
4447
self._client_pid: Optional[int] = None
4548
user_agent = user_agent or UserAgent()
4649
self._user_agent_prefix = user_agent.prefix
50+
self._s3client_config = s3client_config or S3ClientConfig()
4751

4852
@property
4953
def _client(self) -> MountpointS3Client:
@@ -58,6 +62,10 @@ def _client(self) -> MountpointS3Client:
5862
def region(self) -> str:
5963
return self._region
6064

65+
@property
66+
def s3client_config(self) -> S3ClientConfig:
67+
return self._s3client_config
68+
6169
@property
6270
def user_agent_prefix(self) -> str:
6371
return self._user_agent_prefix
@@ -67,6 +75,8 @@ def _client_builder(self) -> MountpointS3Client:
6775
region=self._region,
6876
endpoint=self._endpoint,
6977
user_agent_prefix=self._user_agent_prefix,
78+
throughput_target_gbps=self._s3client_config.throughput_target_gbps,
79+
part_size=self._s3client_config.part_size,
7080
)
7181

7282
def get_object(
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# // SPDX-License-Identifier: BSD
3+
from dataclasses import dataclass
4+
5+
6+
@dataclass(frozen=True)
7+
class S3ClientConfig:
8+
"""A dataclass exposing configurable parameters for the S3 client.
9+
10+
Args:
11+
throughput_target_gbps(float): Throughput target in Gigabits per second (Gbps) that we are trying to reach.
12+
10.0 Gbps by default (may change in future).
13+
part_size(int): Size, in bytes, of parts that files will be downloaded or uploaded in.
14+
Note: for saving checkpoints, the inner client will adjust the part size to meet the service limits.
15+
(max number of parts per upload is 10,000, minimum upload part size is 5 MiB).
16+
Part size must have values between 5MiB and 5GiB.
17+
8MB by default (may change in future).
18+
"""
19+
20+
throughput_target_gbps: float = 10.0
21+
part_size: int = 8 * 1024 * 1024

s3torchconnector/src/s3torchconnector/lightning/s3_lightning_checkpoint.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,26 @@
88

99
from lightning.pytorch.plugins.io import CheckpointIO
1010

11-
from .._s3client import S3Client
11+
from .._s3client import S3Client, S3ClientConfig
1212
from .._s3dataset_common import parse_s3_uri
1313
from .._user_agent import UserAgent
1414

1515

1616
class S3LightningCheckpoint(CheckpointIO):
1717
"""A checkpoint manager for S3 using the :class:`CheckpointIO` interface."""
1818

19-
def __init__(self, region: str):
19+
def __init__(
20+
self,
21+
region: str,
22+
s3client_config: Optional[S3ClientConfig] = None,
23+
):
2024
self.region = region
2125
user_agent = UserAgent(["lightning", lightning.__version__])
22-
self._client = S3Client(region, user_agent=user_agent)
26+
self._client = S3Client(
27+
region,
28+
user_agent=user_agent,
29+
s3client_config=s3client_config,
30+
)
2331

2432
def save_checkpoint(
2533
self,

s3torchconnector/src/s3torchconnector/s3checkpoint.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Optional
44

55
from ._s3dataset_common import parse_s3_uri
6-
from ._s3client import S3Client
6+
from ._s3client import S3Client, S3ClientConfig
77
from . import S3Reader, S3Writer
88

99

@@ -17,10 +17,17 @@ class S3Checkpoint:
1717
torch.load, and torch.save.
1818
"""
1919

20-
def __init__(self, region: str, endpoint: Optional[str] = None):
20+
def __init__(
21+
self,
22+
region: str,
23+
endpoint: Optional[str] = None,
24+
s3client_config: Optional[S3ClientConfig] = None,
25+
):
2126
self.region = region
2227
self.endpoint = endpoint
23-
self._client = S3Client(region, endpoint)
28+
self._client = S3Client(
29+
region, endpoint=endpoint, s3client_config=s3client_config
30+
)
2431

2532
def reader(self, s3_uri: str) -> S3Reader:
2633
"""Creates an S3Reader from a given s3_uri.

s3torchconnector/src/s3torchconnector/s3iterable_dataset.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from . import S3Reader
1010
from ._s3bucket_key_data import S3BucketKeyData
11-
from ._s3client import S3Client
11+
from ._s3client import S3Client, S3ClientConfig
1212
from ._s3dataset_common import (
1313
identity,
1414
get_objects_from_uris,
@@ -31,11 +31,13 @@ def __init__(
3131
get_dataset_objects: Callable[[S3Client], Iterable[S3BucketKeyData]],
3232
endpoint: Optional[str] = None,
3333
transform: Callable[[S3Reader], Any] = identity,
34+
s3client_config: Optional[S3ClientConfig] = None,
3435
):
3536
self._get_dataset_objects = get_dataset_objects
3637
self._transform = transform
3738
self._region = region
3839
self._endpoint = endpoint
40+
self._s3client_config = s3client_config
3941
self._client = None
4042

4143
@property
@@ -54,6 +56,7 @@ def from_objects(
5456
region: str,
5557
endpoint: Optional[str] = None,
5658
transform: Callable[[S3Reader], Any] = identity,
59+
s3client_config: Optional[S3ClientConfig] = None,
5760
):
5861
"""Returns an instance of S3IterableDataset using the S3 URI(s) provided.
5962
@@ -62,6 +65,7 @@ def from_objects(
6265
region(str): AWS region of the S3 bucket where the objects are stored.
6366
endpoint(str): AWS endpoint of the S3 bucket where the objects are stored.
6467
transform: Optional callable which is used to transform an S3Reader into the desired type.
68+
s3client_config: Optional S3ClientConfig with parameters for S3 client.
6569
6670
Returns:
6771
S3IterableDataset: An IterableStyle dataset created from S3 objects.
@@ -75,6 +79,7 @@ def from_objects(
7579
partial(get_objects_from_uris, object_uris),
7680
endpoint,
7781
transform=transform,
82+
s3client_config=s3client_config,
7883
)
7984

8085
@classmethod
@@ -85,6 +90,7 @@ def from_prefix(
8590
region: str,
8691
endpoint: Optional[str] = None,
8792
transform: Callable[[S3Reader], Any] = identity,
93+
s3client_config: Optional[S3ClientConfig] = None,
8894
):
8995
"""Returns an instance of S3IterableDataset using the S3 URI provided.
9096
@@ -93,6 +99,7 @@ def from_prefix(
9399
region(str): AWS region of the S3 bucket where the objects are stored.
94100
endpoint(str): AWS endpoint of the S3 bucket where the objects are stored.
95101
transform: Optional callable which is used to transform an S3Reader into the desired type.
102+
s3client_config: Optional S3ClientConfig with parameters for S3 client.
96103
97104
Returns:
98105
S3IterableDataset: An IterableStyle dataset created from S3 objects.
@@ -106,11 +113,16 @@ def from_prefix(
106113
partial(get_objects_from_prefix, s3_uri),
107114
endpoint,
108115
transform=transform,
116+
s3client_config=s3client_config,
109117
)
110118

111119
def _get_client(self):
112120
if self._client is None:
113-
self._client = S3Client(self.region, self.endpoint)
121+
self._client = S3Client(
122+
self.region,
123+
endpoint=self.endpoint,
124+
s3client_config=self._s3client_config,
125+
)
114126
return self._client
115127

116128
def _get_transformed_object(self, bucket_key: S3BucketKeyData) -> Any:

0 commit comments

Comments
 (0)