Skip to content

Commit 8a94efc

Browse files
authored
Ensure HeadObject is not explicitly called when seeking to the end of a file when dataset is created with from_prefix (#145)
Co-authored-by: Simon Beal <[email protected]>
1 parent dea0865 commit 8a94efc

File tree

9 files changed, +86 -48 lines changed

s3torchconnector/src/s3torchconnector/_s3_bucket_iterable.py

Lines changed: 13 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -11,8 +11,8 @@
1111
ListObjectStream,
1212
)
1313

14+
from ._s3bucket_key_data import S3BucketKeyData
1415
from ._s3client import S3Client
15-
from . import S3Reader
1616

1717

1818
class S3BucketIterable:
@@ -21,7 +21,7 @@ def __init__(self, client: S3Client, bucket: str, prefix: str):
2121
self._bucket = bucket
2222
self._prefix = prefix
2323

24-
def __iter__(self):
24+
def __iter__(self) -> Iterator[S3BucketKeyData]:
2525
# This allows us to iterate multiple times by re-creating the `_list_stream`
2626
return iter(S3BucketIterator(self._client, self._bucket, self._prefix))
2727

@@ -32,10 +32,9 @@ def __init__(self, client: S3Client, bucket: str, prefix: str):
3232
self._bucket = bucket
3333
self._list_stream = _PickleableListObjectStream(client, bucket, prefix)
3434

35-
def __iter__(self) -> Iterator[S3Reader]:
36-
return map(
37-
partial(self._client.from_bucket_and_object_info, self._bucket),
38-
chain.from_iterable(map(_extract_object_info, self._list_stream)),
35+
def __iter__(self) -> Iterator[S3BucketKeyData]:
36+
return chain.from_iterable(
37+
map(partial(_extract_list_results, self._bucket), self._list_stream)
3938
)
4039

4140

@@ -66,5 +65,11 @@ def __setstate__(self, state):
6665
self._list_stream = ListObjectStream._from_state(**state)
6766

6867

69-
def _extract_object_info(list_result: ListObjectResult) -> List[ObjectInfo]:
70-
return list_result.object_info
68+
def _extract_list_results(
69+
bucket: str, list_result: ListObjectResult
70+
) -> Iterator[S3BucketKeyData]:
71+
return map(partial(_extract_object_info, bucket), list_result.object_info)
72+
73+
74+
def _extract_object_info(bucket: str, object_info: ObjectInfo) -> S3BucketKeyData:
75+
return S3BucketKeyData(bucket=bucket, key=object_info.key, object_info=object_info)

s3torchconnector/src/s3torchconnector/_s3bucket_key.py renamed to s3torchconnector/src/s3torchconnector/_s3bucket_key_data.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,13 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# // SPDX-License-Identifier: BSD
3-
from typing import NamedTuple
3+
from typing import NamedTuple, Optional
44

5+
from s3torchconnectorclient._mountpoint_s3_client import ObjectInfo
56

6-
class S3BucketKey(NamedTuple):
7+
8+
class S3BucketKeyData(NamedTuple):
79
"""Read-only information about object stored in S3."""
810

911
bucket: str
1012
key: str
13+
object_info: Optional[ObjectInfo] = None

s3torchconnector/src/s3torchconnector/_s3client/_s3client.py

Lines changed: 10 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -57,12 +57,19 @@ def _client_builder(self) -> MountpointS3Client:
5757
user_agent_prefix=user_agent_prefix,
5858
)
5959

60-
def get_object(self, bucket: str, key: str) -> S3Reader:
61-
log.debug(f"GetObject s3://{bucket}/{key}")
60+
def get_object(
61+
self, bucket: str, key: str, *, object_info: Optional[ObjectInfo] = None
62+
) -> S3Reader:
63+
log.debug(f"GetObject s3://{bucket}/{key}, {object_info is None=}")
64+
if object_info is None:
65+
get_object_info = partial(self.head_object, bucket, key)
66+
else:
67+
get_object_info = partial(_identity, object_info)
68+
6269
return S3Reader(
6370
bucket,
6471
key,
65-
get_object_info=partial(self.head_object, bucket, key),
72+
get_object_info=get_object_info,
6673
get_stream=partial(self._get_object_stream, bucket, key),
6774
)
6875

@@ -86,14 +93,3 @@ def list_objects(
8693
def head_object(self, bucket: str, key: str) -> ObjectInfo:
8794
log.debug(f"HeadObject s3://{bucket}/{key}")
8895
return self._client.head_object(bucket, key)
89-
90-
def from_bucket_and_object_info(
91-
self, bucket: str, object_info: ObjectInfo
92-
) -> S3Reader:
93-
log.debug(f"GetObjectWithInfo s3://{bucket}/{object_info.key}")
94-
return S3Reader(
95-
bucket,
96-
object_info.key,
97-
get_object_info=partial(_identity, object_info),
98-
get_stream=partial(self._get_object_stream, bucket, object_info.key),
99-
)

s3torchconnector/src/s3torchconnector/_s3dataset_common.py

Lines changed: 5 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
from ._s3_bucket_iterable import S3BucketIterable
77
from ._s3client import S3Client
88
from . import S3Reader
9-
from ._s3bucket_key import S3BucketKey
9+
from ._s3bucket_key_data import S3BucketKeyData
1010

1111
"""
1212
_s3dataset_common.py
@@ -38,18 +38,15 @@ def parse_s3_uri(uri: str) -> Tuple[str, str]:
3838

3939
def get_objects_from_uris(
4040
object_uris: Union[str, Iterable[str]], client: S3Client
41-
) -> Iterable[S3BucketKey]:
41+
) -> Iterable[S3BucketKeyData]:
4242
if isinstance(object_uris, str):
4343
object_uris = [object_uris]
4444
# TODO: We should be consistent with URIs parsing. Revise if we want to do this upfront or lazily.
4545
bucket_key_pairs = [parse_s3_uri(uri) for uri in object_uris]
4646

47-
return (S3BucketKey(bucket, key) for bucket, key in bucket_key_pairs)
47+
return (S3BucketKeyData(bucket, key) for bucket, key in bucket_key_pairs)
4848

4949

50-
def get_objects_from_prefix(s3_uri: str, client: S3Client) -> Iterable[S3BucketKey]:
50+
def get_objects_from_prefix(s3_uri: str, client: S3Client) -> Iterable[S3BucketKeyData]:
5151
bucket, prefix = parse_s3_uri(s3_uri)
52-
s3objects = S3BucketIterable(client, bucket, prefix)
53-
bucket_key_pairs = (S3BucketKey(obj.bucket, obj.key) for obj in s3objects)
54-
55-
return bucket_key_pairs
52+
return iter(S3BucketIterable(client, bucket, prefix))

s3torchconnector/src/s3torchconnector/s3iterable_dataset.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
import torch.utils.data
88

99
from . import S3Reader
10-
from ._s3bucket_key import S3BucketKey
10+
from ._s3bucket_key_data import S3BucketKeyData
1111
from ._s3client import S3Client
1212
from ._s3dataset_common import (
1313
identity,
@@ -28,7 +28,7 @@ class S3IterableDataset(torch.utils.data.IterableDataset):
2828
def __init__(
2929
self,
3030
region: str,
31-
get_dataset_objects: Callable[[S3Client], Iterable[S3BucketKey]],
31+
get_dataset_objects: Callable[[S3Client], Iterable[S3BucketKeyData]],
3232
endpoint: str = None,
3333
transform: Callable[[S3Reader], Any] = identity,
3434
):
@@ -113,9 +113,11 @@ def _get_client(self):
113113
self._client = S3Client(self.region, self.endpoint)
114114
return self._client
115115

116-
def _get_transformed_object(self, bucket_key: S3BucketKey) -> Any:
116+
def _get_transformed_object(self, bucket_key: S3BucketKeyData) -> Any:
117117
return self._transform(
118-
self._get_client().get_object(bucket_key.bucket, bucket_key.key)
118+
self._get_client().get_object(
119+
bucket_key.bucket, bucket_key.key, object_info=bucket_key.object_info
120+
)
119121
)
120122

121123
def __iter__(self) -> Iterator[Any]:

s3torchconnector/src/s3torchconnector/s3map_dataset.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
import logging
66

77
import torch.utils.data
8-
from s3torchconnector._s3bucket_key import S3BucketKey
8+
from s3torchconnector._s3bucket_key_data import S3BucketKeyData
99

1010
from ._s3client import S3Client
1111
from . import S3Reader
@@ -29,7 +29,7 @@ class S3MapDataset(torch.utils.data.Dataset):
2929
def __init__(
3030
self,
3131
region: str,
32-
get_dataset_objects: Callable[[S3Client], Iterable[S3BucketKey]],
32+
get_dataset_objects: Callable[[S3Client], Iterable[S3BucketKeyData]],
3333
endpoint: str = None,
3434
transform: Callable[[S3Reader], Any] = identity,
3535
):
@@ -49,7 +49,7 @@ def endpoint(self):
4949
return self._endpoint
5050

5151
@property
52-
def _dataset_bucket_key_pairs(self) -> List[S3BucketKey]:
52+
def _dataset_bucket_key_pairs(self) -> List[S3BucketKeyData]:
5353
if self._bucket_key_pairs is None:
5454
self._bucket_key_pairs = list(self._get_dataset_objects(self._get_client()))
5555
return self._bucket_key_pairs
@@ -123,7 +123,9 @@ def _get_client(self):
123123

124124
def _get_object(self, i: int) -> S3Reader:
125125
bucket_key = self._dataset_bucket_key_pairs[i]
126-
return self._get_client().get_object(bucket_key.bucket, bucket_key.key)
126+
return self._get_client().get_object(
127+
bucket_key.bucket, bucket_key.key, object_info=bucket_key.object_info
128+
)
127129

128130
def __getitem__(self, i: int) -> Any:
129131
return self._transform(self._get_object(i))

s3torchconnector/tst/unit/test_s3_client.py

Lines changed: 5 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,11 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# // SPDX-License-Identifier: BSD
33
import logging
4+
from unittest.mock import MagicMock
45

56
import pytest
67

78
from s3torchconnector._s3client import S3Client, MockS3Client
8-
from s3torchconnectorclient._mountpoint_s3_client import ObjectInfo
99

1010
TEST_BUCKET = "test-bucket"
1111
TEST_KEY = "test-key"
@@ -23,16 +23,13 @@ def s3_client() -> S3Client:
2323
def test_get_object_log(s3_client: S3Client, caplog):
2424
with caplog.at_level(logging.DEBUG):
2525
s3_client.get_object(TEST_BUCKET, TEST_KEY)
26-
assert f"GetObject {S3_URI}" in caplog.messages
26+
assert f"GetObject {S3_URI}, object_info is None=True" in caplog.messages
2727

2828

29-
def test_get_object_info_log(s3_client: S3Client, caplog):
29+
def test_get_object_log_with_info(s3_client: S3Client, caplog):
3030
with caplog.at_level(logging.DEBUG):
31-
s3_client.from_bucket_and_object_info(
32-
TEST_BUCKET,
33-
ObjectInfo(TEST_KEY, "", 0, 0, None, None),
34-
)
35-
assert f"GetObjectWithInfo {S3_URI}" in caplog.messages
31+
s3_client.get_object(TEST_BUCKET, TEST_KEY, object_info=MagicMock())
32+
assert f"GetObject {S3_URI}, object_info is None=False" in caplog.messages
3633

3734

3835
def test_head_object_log(s3_client: S3Client, caplog):

s3torchconnector/tst/unit/test_s3iterable_dataset.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,14 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# // SPDX-License-Identifier: BSD
33
import logging
4+
from io import SEEK_END
45
from typing import Iterable, Callable, Sequence, Any
6+
from unittest.mock import patch
57

68
import pytest
79

810
from s3torchconnector import S3IterableDataset, S3Reader
11+
from s3torchconnector._s3client import MockS3Client
912

1013
from test_s3dataset_common import (
1114
TEST_BUCKET,
@@ -200,6 +203,21 @@ def test_dataset_creation_from_prefix_with_region_and_endpoint():
200203
assert dataset.endpoint == TEST_ENDPOINT
201204

202205

206+
def test_from_prefix_seek_no_head():
207+
dataset = S3IterableDataset.from_prefix(S3_PREFIX, region=TEST_REGION)
208+
209+
# use mock client for unit testing
210+
client = _create_mock_client_with_dummy_objects(TEST_BUCKET, ["foo"])
211+
dataset._client = client
212+
213+
with patch.object(
214+
MockS3Client, "head_object", wraps=client.head_object
215+
) as head_object:
216+
s3_object = next(iter(dataset))
217+
s3_object.seek(0, SEEK_END)
218+
head_object.assert_not_called()
219+
220+
203221
def _verify_dataset(
204222
dataset: S3IterableDataset,
205223
expected_keys: Sequence[str],

s3torchconnector/tst/unit/test_s3mapdataset.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,14 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# // SPDX-License-Identifier: BSD
33
import logging
4+
from io import SEEK_END
45
from typing import Sequence, Callable, Any
6+
from unittest.mock import patch
57

68
import pytest
79

810
from s3torchconnector import S3MapDataset, S3Reader
11+
from s3torchconnector._s3client import MockS3Client
912

1013
from test_s3dataset_common import (
1114
TEST_BUCKET,
@@ -159,6 +162,21 @@ def test_transform_from_objects(
159162
assert list(dataset) == [expected]
160163

161164

165+
def test_from_prefix_seek_no_head():
166+
dataset = S3MapDataset.from_prefix(S3_PREFIX, region=TEST_REGION)
167+
168+
# use mock client for unit testing
169+
client = _create_mock_client_with_dummy_objects(TEST_BUCKET, ["foo"])
170+
dataset._client = client
171+
172+
with patch.object(
173+
MockS3Client, "head_object", wraps=client.head_object
174+
) as head_object:
175+
s3_object = next(iter(dataset))
176+
s3_object.seek(0, SEEK_END)
177+
head_object.assert_not_called()
178+
179+
162180
@pytest.mark.parametrize(
163181
"keys, length",
164182
[

0 commit comments

Comments (0)