Skip to content

Commit ba35c7f

Browse files
authored
Allow usage of unsigned S3 client (#195)
* Disable signing to make requests without AWS credentials
  Update S3ClientConfig to pass in the configuration for allowing unsigned requests.
* Update instructions
1 parent 06c2312 commit ba35c7f

File tree

11 files changed

+64
-19
lines changed

11 files changed

+64
-19
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,9 @@
1+
## Unreleased
2+
3+
### New features
4+
* Update S3ClientConfig to pass in the configuration for allowing unsigned requests, under boolean flag `unsigned`.
5+
6+
17
## v1.2.2 (March 22, 2024)
28

39
### New features

doc/DEVELOPMENT.md

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -150,6 +150,8 @@ Using S3ClientConfig you can set up the following parameters for the underlying
150150
(max number of parts per upload is 10,000, minimum upload part size is 5 MiB).
151151
Part size must be **between 5MiB and 5GiB**. It is set to **8MiB** by default (may change in the future).
152152

153+
* `unsigned(bool)`: Set to true to disable signing S3 requests.
154+
153155
For example this can be passed in like:
154156
```py
155157
from s3torchconnector import S3MapDataset, S3ClientConfig
@@ -165,6 +167,9 @@ s3_map_dataset = S3MapDataset.from_prefix(DATASET_URI, region=REGION, s3client_c
165167
s3_checkpoint = S3Checkpoint(region=REGION, s3client_config=config)
166168
# Works similarly for Lightning checkpoints.
167169
s3_lightning_checkpoint = S3LightningCheckpoint(region=REGION, s3client_config=config)
170+
171+
# Disable signing to make requests without AWS credentials
172+
s3_client = S3Client(region=REGION, s3client_config=S3ClientConfig(unsigned=True))
168173
```
169174

170175
**When modifying the default values for these flags, we strongly recommend running benchmarks to ensure you are not

s3torchconnector/src/s3torchconnector/_s3client/_mock_s3client.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,7 @@ def __init__(
3737
throughput_target_gbps=self.s3client_config.throughput_target_gbps,
3838
part_size=self.s3client_config.part_size,
3939
user_agent_prefix=self.user_agent_prefix,
40+
unsigned=self.s3client_config.unsigned,
4041
)
4142

4243
def add_object(self, key: str, data: bytes) -> None:

s3torchconnector/src/s3torchconnector/_s3client/_s3client.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -77,6 +77,7 @@ def _client_builder(self) -> MountpointS3Client:
7777
user_agent_prefix=self._user_agent_prefix,
7878
throughput_target_gbps=self._s3client_config.throughput_target_gbps,
7979
part_size=self._s3client_config.part_size,
80+
unsigned=self._s3client_config.unsigned,
8081
)
8182

8283
def get_object(

s3torchconnector/src/s3torchconnector/_s3client/s3client_config.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -19,3 +19,4 @@ class S3ClientConfig:
1919

2020
throughput_target_gbps: float = 10.0
2121
part_size: int = 8 * 1024 * 1024
22+
unsigned: bool = False

s3torchconnector/tst/e2e/test_e2e_s3datasets.py

Lines changed: 16 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
from torch.utils.data.datapipes.datapipe import MapDataPipe
1010
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe
1111

12-
from s3torchconnector import S3IterableDataset, S3MapDataset
12+
from s3torchconnector import S3IterableDataset, S3MapDataset, S3ClientConfig
1313

1414

1515
def test_s3iterable_dataset_images_10_from_prefix(image_directory):
@@ -100,6 +100,21 @@ def test_dataset_unpickled_iterates(image_directory):
100100
assert expected == actual
101101

102102

103+
def test_unsigned_client():
104+
s3_uri = "s3://s3torchconnector-demo/geonet/images/"
105+
region = "us-east-1"
106+
s3_dataset = S3MapDataset.from_prefix(
107+
s3_uri=s3_uri,
108+
region=region,
109+
transform=lambda obj: obj.read(),
110+
s3client_config=S3ClientConfig(unsigned=True),
111+
)
112+
s3_dataloader = _pytorch_dataloader(s3_dataset)
113+
assert s3_dataloader is not None
114+
assert isinstance(s3_dataloader.dataset, S3MapDataset)
115+
assert len(s3_dataloader) >= 1296
116+
117+
103118
def _compare_dataloaders(
104119
local_dataloader: DataLoader, s3_dataloader: DataLoader, expected_batch_count: int
105120
):

s3torchconnector/tst/unit/test_s3_client.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -107,6 +107,7 @@ def test_s3_client_custom_config(part_size: int, throughput_target_gbps: float):
107107
)
108108
assert s3_client._client.part_size == part_size
109109
assert s3_client._client.throughput_target_gbps == throughput_target_gbps
110+
assert s3_client._client.unsigned is False
110111

111112

112113
@pytest.mark.parametrize(
@@ -130,3 +131,11 @@ def test_s3_client_invalid_part_size_config(part_size: int):
130131
)
131132
# The client is lazily initialized
132133
assert s3_client._client.part_size == part_size
134+
135+
136+
def test_unsigned_s3_client():
137+
s3_client = S3Client(
138+
region=TEST_REGION,
139+
s3client_config=S3ClientConfig(unsigned=True),
140+
)
141+
assert s3_client._client.unsigned is True

s3torchconnectorclient/python/src/s3torchconnectorclient/_mountpoint_s3_client.pyi

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@ class MountpointS3Client:
1010
region: str
1111
part_size: int
1212
profile: Optional[str]
13-
no_sign_request: bool
13+
unsigned: Optional[bool]
1414
user_agent_prefix: str
1515
endpoint: str
1616

@@ -21,7 +21,7 @@ class MountpointS3Client:
2121
throughput_target_gbps: float = 10.0,
2222
part_size: int = 8 * 1024 * 1024,
2323
profile: Optional[str] = None,
24-
no_sign_request: bool = False,
24+
unsigned: Optional[bool] = False,
2525
endpoint: Optional[str] = None,
2626
): ...
2727
def get_object(self, bucket: str, key: str) -> GetObjectStream: ...
@@ -39,6 +39,7 @@ class MockMountpointS3Client:
3939
region: str
4040
part_size: int
4141
user_agent_prefix: str
42+
unsigned: bool
4243

4344
def __init__(
4445
self,
@@ -48,6 +49,7 @@ class MockMountpointS3Client:
4849
throughput_target_gbps: float = 10.0,
4950
part_size: int = 8 * 1024 * 1024,
5051
user_agent_prefix: str = "mock_client",
52+
unsigned: bool = False,
5153
): ...
5254
def create_mocked_client(self) -> MountpointS3Client: ...
5355
def add_object(self, key: str, data: bytes) -> None: ...

s3torchconnectorclient/python/tst/unit/test_mountpoint_s3_client.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -243,7 +243,7 @@ def test_put_object_with_storage_class():
243243
# TODO: Add hypothesis setup after aligning on limits
244244
def test_mountpoint_client_pickles():
245245
expected_profile = None
246-
expected_no_sign_request = False
246+
expected_unsigned = False
247247
expected_region = REGION
248248
expected_part_size = 5 * 2**20
249249
expected_throughput_target_gbps = 3.5
@@ -254,7 +254,7 @@ def test_mountpoint_client_pickles():
254254
part_size=expected_part_size,
255255
throughput_target_gbps=expected_throughput_target_gbps,
256256
profile=expected_profile,
257-
no_sign_request=expected_no_sign_request,
257+
unsigned=expected_unsigned,
258258
)
259259
dumped = pickle.dumps(client)
260260
loaded = pickle.loads(dumped)
@@ -271,7 +271,7 @@ def test_mountpoint_client_pickles():
271271
== expected_throughput_target_gbps
272272
)
273273
assert client.profile == loaded.profile == expected_profile
274-
assert client.no_sign_request == loaded.no_sign_request == expected_no_sign_request
274+
assert client.unsigned == loaded.unsigned == expected_unsigned
275275

276276

277277
@pytest.mark.parametrize(

s3torchconnectorclient/rust/src/mock_client.rs

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -26,18 +26,21 @@ pub struct PyMockClient {
2626
pub(crate) part_size: usize,
2727
#[pyo3(get)]
2828
pub(crate) user_agent_prefix: String,
29+
#[pyo3(get)]
30+
pub(crate) unsigned: bool,
2931
}
3032

3133
#[pymethods]
3234
impl PyMockClient {
3335
#[new]
34-
#[pyo3(signature = (region, bucket, throughput_target_gbps = 10.0, part_size = 8 * 1024 * 1024, user_agent_prefix="mock_client".to_string()))]
36+
#[pyo3(signature = (region, bucket, throughput_target_gbps = 10.0, part_size = 8 * 1024 * 1024, user_agent_prefix="mock_client".to_string(), unsigned=false))]
3537
pub fn new(
3638
region: String,
3739
bucket: String,
3840
throughput_target_gbps: f64,
3941
part_size: usize,
4042
user_agent_prefix: String,
43+
unsigned: bool,
4144
) -> PyMockClient {
4245
let unordered_list_seed: Option<u64> = None;
4346
let config = MockClientConfig { bucket, part_size, unordered_list_seed };
@@ -48,7 +51,8 @@ impl PyMockClient {
4851
region,
4952
throughput_target_gbps,
5053
part_size,
51-
user_agent_prefix
54+
user_agent_prefix,
55+
unsigned
5256
}
5357
}
5458

@@ -59,7 +63,7 @@ impl PyMockClient {
5963
self.throughput_target_gbps,
6064
self.part_size,
6165
None,
62-
false,
66+
self.unsigned,
6367
self.mock_client.clone(),
6468
None,
6569
)

0 commit comments

Comments
 (0)