Skip to content

Commit 13fb31b

Browse files
TalIfarganTal Ifargan
andauthored
provide profile as part of S3ClientConfig (#341)
provide profile as part of S3ClientConfig --------- Co-authored-by: Tal Ifargan <[email protected]>
1 parent 56d632d commit 13fb31b

File tree

5 files changed

+60
-2
lines changed

5 files changed

+60
-2
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ that role.
4141
- Set credentials in the AWS credentials profile file on the local system, located at: `~/.aws/credentials`
4242
on Unix or macOS.
4343
- Set the `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` environment variables.
44+
- Pass the name of the desired profile that you have configured in `~/.aws/config` and `~/.aws/credentials` to the `S3ClientConfig` object.
4445

4546
### Examples
4647

s3torchconnector/src/s3torchconnector/_s3client/_s3client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ def _client_builder(self) -> MountpointS3Client:
136136
return MountpointS3Client(
137137
region=self._region,
138138
endpoint=self._endpoint,
139+
profile=self._s3client_config.profile,
139140
user_agent_prefix=self._user_agent_prefix,
140141
throughput_target_gbps=self._s3client_config.throughput_target_gbps,
141142
part_size=self._s3client_config.part_size,

s3torchconnector/src/s3torchconnector/_s3client/s3client_config.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# // SPDX-License-Identifier: BSD
33
from dataclasses import dataclass
4+
from typing import Optional
45

56

67
@dataclass(frozen=True)
@@ -15,12 +16,15 @@ class S3ClientConfig:
1516
(max number of parts per upload is 10,000, minimum upload part size is 5 MiB).
1617
Part size must have values between 5MiB and 5GiB.
1718
8MiB by default (may change in future).
19+
unsigned(bool): Set to true to disable signing S3 requests.
1820
force_path_style(bool): forceful path style addressing for S3 client.
19-
max_attempts(int): amount of retry attempts for retrieable errors
21+
max_attempts(int): amount of retry attempts for retrieable errors.
22+
profile(str): Profile name to use for S3 authentication.
2023
"""
2124

2225
throughput_target_gbps: float = 10.0
2326
part_size: int = 8 * 1024 * 1024
2427
unsigned: bool = False
2528
force_path_style: bool = False
2629
max_attempts: int = 10
30+
profile: Optional[str] = None

s3torchconnector/tst/e2e/test_s3_client.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# // SPDX-License-Identifier: BSD
33

4+
import os
5+
import tempfile
46
import pytest
57
from s3torchconnectorclient import S3Exception
68

7-
from s3torchconnector._s3client import S3Client
9+
from s3torchconnector._s3client import S3Client, S3ClientConfig
810

911
HELLO_WORLD_DATA = b"Hello, World!\n"
12+
TEST_PROFILE_NAME = "test-profile"
1013

1114

1215
def test_no_access_objects_without_profile(empty_directory):
@@ -24,3 +27,46 @@ def test_no_access_objects_without_profile(empty_directory):
2427
filename,
2528
)
2629
put_stream.write(HELLO_WORLD_DATA)
30+
31+
32+
def test_access_objects_with_profile(empty_directory):
33+
if empty_directory.profile_bucket is None:
34+
pytest.skip("No profile bucket configured")
35+
36+
try:
37+
tmp_file = tempfile.NamedTemporaryFile()
38+
tmp_file.write(
39+
f"""[profile default]
40+
aws_access_key_id = {os.getenv("AWS_ACCESS_KEY_ID")}
41+
aws_secret_access_key = {os.getenv("AWS_SECRET_ACCESS_KEY")}
42+
aws_session_token = {os.getenv("AWS_SESSION_TOKEN")}
43+
44+
[profile {TEST_PROFILE_NAME}]
45+
role_arn = {empty_directory.profile_arn}
46+
region = {empty_directory.region}
47+
source_profile = default""".encode()
48+
)
49+
tmp_file.flush()
50+
os.environ["AWS_CONFIG_FILE"] = tmp_file.name
51+
52+
client = S3Client(
53+
empty_directory.region,
54+
s3client_config=S3ClientConfig(profile=TEST_PROFILE_NAME),
55+
)
56+
filename = f"{empty_directory.prefix}hello_world.txt"
57+
58+
put_stream = client.put_object(
59+
empty_directory.profile_bucket,
60+
filename,
61+
)
62+
63+
put_stream.write(HELLO_WORLD_DATA)
64+
put_stream.close()
65+
66+
get_stream = client.get_object(
67+
empty_directory.profile_bucket,
68+
filename,
69+
)
70+
assert b"".join(get_stream) == HELLO_WORLD_DATA
71+
finally:
72+
os.environ["AWS_CONFIG_FILE"] = ""

s3torchconnector/tst/unit/test_s3_client_config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,19 @@ def test_default():
1313
assert config.throughput_target_gbps == 10.0
1414
assert config.force_path_style is False
1515
assert config.max_attempts == 10
16+
assert config.profile is None
1617

1718

1819
def test_enable_force_path_style():
1920
config = S3ClientConfig(force_path_style=True)
2021
assert config.force_path_style is True
2122

2223

24+
def test_change_profile():
25+
config = S3ClientConfig(profile="test_profile")
26+
assert config.profile == "test_profile"
27+
28+
2329
@given(part_size=integers(min_value=5 * MiB, max_value=5 * GiB))
2430
def test_part_size_setup(part_size: int):
2531
config = S3ClientConfig(part_size=part_size)

0 commit comments

Comments
 (0)