Skip to content

Commit 39c5524

Browse files
authored
Added support for passing S3 endpoint URL (#124)
* Added support for passing S3 endpoint as an argument to dataset construction API to access dataset stored in AWS S3 compatible storage systems * Make the endpoint argument optional by setting empty string as default value to S3IterableDataset constructor
1 parent b1e90ad commit 39c5524

File tree

13 files changed

+135
-13
lines changed

13 files changed

+135
-13
lines changed

s3torchconnector/src/s3torchconnector/_s3client/_s3client.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,9 @@ def _identity(obj: Any) -> Any:
3232

3333

3434
class S3Client:
35-
def __init__(self, region: str):
35+
def __init__(self, region: str, endpoint: str = None):
3636
self._region = region
37+
self._endpoint = endpoint
3738
self._real_client = None
3839
self._client_pid = None
3940

@@ -51,7 +52,9 @@ def region(self) -> str:
5152

5253
def _client_builder(self) -> MountpointS3Client:
5354
return MountpointS3Client(
54-
region=self._region, user_agent_prefix=user_agent_prefix
55+
region=self._region,
56+
endpoint=self._endpoint,
57+
user_agent_prefix=user_agent_prefix,
5558
)
5659

5760
def get_object(self, bucket: str, key: str) -> S3Reader:

s3torchconnector/src/s3torchconnector/s3checkpoint.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@ class S3Checkpoint:
1616
torch.load, and torch.save.
1717
"""
1818

19-
def __init__(self, region: str):
19+
def __init__(self, region: str, endpoint: str = None):
2020
self.region = region
21-
self._client = S3Client(region)
21+
self.endpoint = endpoint
22+
self._client = S3Client(region, endpoint)
2223

2324
def reader(self, s3_uri: str) -> S3Reader:
2425
"""Creates an S3Reader from a given s3_uri.

s3torchconnector/src/s3torchconnector/s3iterable_dataset.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,30 +29,38 @@ def __init__(
2929
self,
3030
region: str,
3131
get_dataset_objects: Callable[[S3Client], Iterable[S3BucketKey]],
32+
endpoint: str = None,
3233
transform: Callable[[S3Reader], Any] = identity,
3334
):
3435
self._get_dataset_objects = get_dataset_objects
3536
self._transform = transform
3637
self._region = region
38+
self._endpoint = endpoint
3739
self._client = None
3840

3941
@property
4042
def region(self):
4143
return self._region
4244

45+
@property
46+
def endpoint(self):
47+
return self._endpoint
48+
4349
@classmethod
4450
def from_objects(
4551
cls,
4652
object_uris: Union[str, Iterable[str]],
4753
*,
4854
region: str,
55+
endpoint: str = None,
4956
transform: Callable[[S3Reader], Any] = identity,
5057
):
5158
"""Returns an instance of S3IterableDataset using the S3 URI(s) provided.
5259
5360
Args:
5461
object_uris(str | Iterable[str]): S3 URI of the object(s) desired.
5562
region(str): AWS region of the S3 bucket where the objects are stored.
63+
endpoint(str): AWS endpoint of the S3 bucket where the objects are stored.
5664
transform: Optional callable which is used to transform an S3Reader into the desired type.
5765
5866
Returns:
@@ -63,7 +71,10 @@ def from_objects(
6371
"""
6472
log.info(f"Building {cls.__name__} from_objects")
6573
return cls(
66-
region, partial(get_objects_from_uris, object_uris), transform=transform
74+
region,
75+
partial(get_objects_from_uris, object_uris),
76+
endpoint,
77+
transform=transform,
6778
)
6879

6980
@classmethod
@@ -72,13 +83,15 @@ def from_prefix(
7283
s3_uri: str,
7384
*,
7485
region: str,
86+
endpoint: str = None,
7587
transform: Callable[[S3Reader], Any] = identity,
7688
):
7789
"""Returns an instance of S3IterableDataset using the S3 URI provided.
7890
7991
Args:
8092
s3_uri(str): An S3 URI (prefix) of the object(s) desired. Objects matching the prefix will be included in the returned dataset.
8193
region(str): AWS region of the S3 bucket where the objects are stored.
94+
endpoint(str): AWS endpoint of the S3 bucket where the objects are stored.
8295
transform: Optional callable which is used to transform an S3Reader into the desired type.
8396
8497
Returns:
@@ -89,12 +102,15 @@ def from_prefix(
89102
"""
90103
log.info(f"Building {cls.__name__} from_prefix {s3_uri=}")
91104
return cls(
92-
region, partial(get_objects_from_prefix, s3_uri), transform=transform
105+
region,
106+
partial(get_objects_from_prefix, s3_uri),
107+
endpoint,
108+
transform=transform,
93109
)
94110

95111
def _get_client(self):
96112
if self._client is None:
97-
self._client = S3Client(self.region)
113+
self._client = S3Client(self.region, self.endpoint)
98114
return self._client
99115

100116
def _get_transformed_object(self, bucket_key: S3BucketKey) -> Any:

s3torchconnector/src/s3torchconnector/s3map_dataset.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,18 +30,24 @@ def __init__(
3030
self,
3131
region: str,
3232
get_dataset_objects: Callable[[S3Client], Iterable[S3BucketKey]],
33+
endpoint: str = None,
3334
transform: Callable[[S3Reader], Any] = identity,
3435
):
3536
self._get_dataset_objects = get_dataset_objects
3637
self._transform = transform
3738
self._region = region
39+
self._endpoint = endpoint
3840
self._client = None
3941
self._bucket_key_pairs = None
4042

4143
@property
4244
def region(self):
4345
return self._region
4446

47+
@property
48+
def endpoint(self):
49+
return self._endpoint
50+
4551
@property
4652
def _dataset_bucket_key_pairs(self) -> List[S3BucketKey]:
4753
if self._bucket_key_pairs is None:
@@ -54,13 +60,15 @@ def from_objects(
5460
object_uris: Union[str, Iterable[str]],
5561
*,
5662
region: str,
63+
endpoint: str = None,
5764
transform: Callable[[S3Reader], Any] = identity,
5865
):
5966
"""Returns an instance of S3MapDataset using the S3 URI(s) provided.
6067
6168
Args:
6269
object_uris(str | Iterable[str]): S3 URI of the object(s) desired.
6370
region(str): AWS region of the S3 bucket where the objects are stored.
71+
endpoint(str): AWS endpoint of the S3 bucket where the objects are stored.
6472
transform: Optional callable which is used to transform an S3Reader into the desired type.
6573
6674
Returns:
@@ -71,7 +79,10 @@ def from_objects(
7179
"""
7280
log.info(f"Building {cls.__name__} from_objects")
7381
return cls(
74-
region, partial(get_objects_from_uris, object_uris), transform=transform
82+
region,
83+
partial(get_objects_from_uris, object_uris),
84+
endpoint,
85+
transform=transform,
7586
)
7687

7788
@classmethod
@@ -80,13 +91,15 @@ def from_prefix(
8091
s3_uri: str,
8192
*,
8293
region: str,
94+
endpoint: str = None,
8395
transform: Callable[[S3Reader], Any] = identity,
8496
):
8597
"""Returns an instance of S3MapDataset using the S3 URI provided.
8698
8799
Args:
88100
s3_uri(str): An S3 URI (prefix) of the object(s) desired. Objects matching the prefix will be included in the returned dataset.
89101
region(str): AWS region of the S3 bucket where the objects are stored.
102+
endpoint(str): AWS endpoint of the S3 bucket where the objects are stored.
90103
transform: Optional callable which is used to transform an S3Reader into the desired type.
91104
92105
Returns:
@@ -97,12 +110,15 @@ def from_prefix(
97110
"""
98111
log.info(f"Building {cls.__name__} from_prefix {s3_uri=}")
99112
return cls(
100-
region, partial(get_objects_from_prefix, s3_uri), transform=transform
113+
region,
114+
partial(get_objects_from_prefix, s3_uri),
115+
endpoint,
116+
transform=transform,
101117
)
102118

103119
def _get_client(self):
104120
if self._client is None:
105-
self._client = S3Client(self.region)
121+
self._client = S3Client(self.region, self.endpoint)
106122
return self._client
107123

108124
def _get_object(self, i: int) -> S3Reader:

s3torchconnector/tst/unit/test_checkpointing.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
TEST_BUCKET = "test-bucket"
3838
TEST_KEY = "test-key"
3939
TEST_REGION = "us-east-1"
40-
40+
TEST_ENDPOINT = "https://s3.us-east-1.amazonaws.com"
4141

4242
scalars = (
4343
none()
@@ -147,6 +147,12 @@ def test_general_checkpointing_untyped_storage_loads_no_modern_pytorch_format(
147147
)
148148

149149

150+
def test_checkpoint_creation_with_region_and_endpoint():
151+
checkpoint = S3Checkpoint(TEST_REGION, endpoint=TEST_ENDPOINT)
152+
assert isinstance(checkpoint, S3Checkpoint)
153+
assert checkpoint.endpoint == TEST_ENDPOINT
154+
155+
150156
def test_checkpoint_seek_logging(caplog):
151157
checkpoint = S3Checkpoint(TEST_REGION)
152158

s3torchconnector/tst/unit/test_s3dataset_common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
TEST_KEY = "test-key"
2727
TEST_REGION = "us-east-1"
2828
S3_PREFIX = f"s3://{TEST_BUCKET}"
29+
TEST_ENDPOINT = "https://s3.us-east-1.amazonaws.com"
2930

3031

3132
@pytest.mark.parametrize(

s3torchconnector/tst/unit/test_s3iterable_dataset.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
_create_mock_client_with_dummy_objects,
1313
S3_PREFIX,
1414
TEST_REGION,
15+
TEST_ENDPOINT,
1516
)
1617

1718

@@ -191,6 +192,14 @@ def test_iteration_multiple_times(
191192
_verify_dataset(dataset, expected_keys, lambda data: data._object_info is not None)
192193

193194

195+
def test_dataset_creation_from_prefix_with_region_and_endpoint():
196+
dataset = S3IterableDataset.from_prefix(
197+
S3_PREFIX, region=TEST_REGION, endpoint=TEST_ENDPOINT
198+
)
199+
assert isinstance(dataset, S3IterableDataset)
200+
assert dataset.endpoint == TEST_ENDPOINT
201+
202+
194203
def _verify_dataset(
195204
dataset: S3IterableDataset,
196205
expected_keys: Sequence[str],

s3torchconnector/tst/unit/test_s3mapdataset.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
TEST_REGION,
1313
_create_mock_client_with_dummy_objects,
1414
S3_PREFIX,
15+
TEST_ENDPOINT,
1516
)
1617

1718

@@ -178,6 +179,14 @@ def test_call_len_twice(keys: Sequence[str], length: int):
178179
assert len(dataset) == length
179180

180181

182+
def test_dataset_creation_from_prefix_with_region_and_endpoint():
183+
dataset = S3MapDataset.from_prefix(
184+
S3_PREFIX, region=TEST_REGION, endpoint=TEST_ENDPOINT
185+
)
186+
assert isinstance(dataset, S3MapDataset)
187+
assert dataset.endpoint == TEST_ENDPOINT
188+
189+
181190
def verify_item(dataset: S3MapDataset, index: int, expected_key: str):
182191
data = dataset[index]
183192

s3torchconnectorclient/python/src/s3torchconnectorclient/_mountpoint_s3_client.pyi

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ from typing import List, Optional
77

88
class MountpointS3Client:
99
region: str
10+
endpoint: str
1011
user_agent_prefix: str
1112
throughput_target_gbps: float
1213
part_size: int
@@ -16,6 +17,7 @@ class MountpointS3Client:
1617
def __init__(
1718
self,
1819
region: str,
20+
endpoint: str = "",
1921
user_agent_prefix: str = "",
2022
throughput_target_gbps: float = 10.0,
2123
part_size: int = 8 * 1024 * 1024,
@@ -37,6 +39,7 @@ class MockMountpointS3Client:
3739
self,
3840
region: str,
3941
bucket: str,
42+
endpoint: str = "",
4043
throughput_target_gbps: float = 10.0,
4144
part_size: int = 8 * 1024 * 1024,
4245
): ...

s3torchconnectorclient/python/tst/integration/test_mountpoint_s3_integration.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,20 @@ def test_get_object(sample_directory):
3636
assert full_data == HELLO_WORLD_DATA
3737

3838

39+
def test_get_object_with_endpoint(sample_directory):
40+
client = MountpointS3Client(
41+
sample_directory.region,
42+
TEST_USER_AGENT_PREFIX,
43+
endpoint="http://s3.amazonaws.com",
44+
)
45+
stream = client.get_object(
46+
sample_directory.bucket, f"{sample_directory.prefix}hello_world.txt"
47+
)
48+
49+
full_data = b"".join(stream)
50+
assert full_data == HELLO_WORLD_DATA
51+
52+
3953
def test_get_object_with_unpickled_client(sample_directory):
4054
original_client = MountpointS3Client(
4155
sample_directory.region, TEST_USER_AGENT_PREFIX

0 commit comments

Comments
 (0)