Skip to content

Commit 6f2b9a8

Browse files
authored
Refactor UserAgent setup for extensibility (#161)
* Refactor UserAgent setup for extensibility Update MockClient, store just User-Agent prefix string and add assertions for MountpointClient prefix. Add hypothesis test for starts with package/version.
1 parent 7ac85ef commit 6f2b9a8

File tree

8 files changed

+131
-13
lines changed

8 files changed

+131
-13
lines changed

CHANGELOG.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
### New features
44

5-
### Bug Fixes
5+
### Bug Fixes / Improvements
66
* Fix deadlock when enabling CRT debug logs. Removed former experimental method _enable_debug_logging().
7-
7+
* Refactor User-Agent setup for extensibility.
88

99
## v1.1.4 (February 26, 2024)
1010

s3torchconnector/src/s3torchconnector/_s3client/_mock_s3client.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
)
88

99
from . import S3Client
10+
from .._user_agent import UserAgent
1011

1112
"""
1213
_mock_s3client.py
@@ -15,9 +16,20 @@
1516

1617

1718
class MockS3Client(S3Client):
18-
def __init__(self, region: str, bucket: str, part_size: int = 8 * 1024 * 1024):
19-
super().__init__(region)
20-
self._mock_client = MockMountpointS3Client(region, bucket, part_size=part_size)
19+
def __init__(
20+
self,
21+
region: str,
22+
bucket: str,
23+
part_size: int = 8 * 1024 * 1024,
24+
user_agent: UserAgent = None,
25+
):
26+
super().__init__(region, user_agent=user_agent)
27+
self._mock_client = MockMountpointS3Client(
28+
region,
29+
bucket,
30+
part_size=part_size,
31+
user_agent_prefix=self.user_agent_prefix,
32+
)
2133

2234
def add_object(self, key: str, data: bytes) -> None:
2335
self._mock_client.add_object(key, data)

s3torchconnector/src/s3torchconnector/_s3client/_s3client.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from typing import Optional, Any
88

99
from s3torchconnector import S3Reader, S3Writer
10-
from s3torchconnector._version import user_agent_prefix
1110

1211
from s3torchconnectorclient._mountpoint_s3_client import (
1312
MountpointS3Client,
@@ -16,6 +15,7 @@
1615
GetObjectStream,
1716
)
1817

18+
from s3torchconnector._user_agent import UserAgent
1919

2020
"""
2121
_s3client.py
@@ -32,11 +32,13 @@ def _identity(obj: Any) -> Any:
3232

3333

3434
class S3Client:
35-
def __init__(self, region: str, endpoint: str = None):
35+
def __init__(self, region: str, endpoint: str = None, user_agent: UserAgent = None):
3636
self._region = region
3737
self._endpoint = endpoint
3838
self._real_client = None
3939
self._client_pid = None
40+
user_agent = user_agent or UserAgent()
41+
self._user_agent_prefix = user_agent.prefix
4042

4143
@property
4244
def _client(self) -> MountpointS3Client:
@@ -50,11 +52,15 @@ def _client(self) -> MountpointS3Client:
5052
def region(self) -> str:
5153
return self._region
5254

55+
@property
56+
def user_agent_prefix(self) -> str:
57+
return self._user_agent_prefix
58+
5359
def _client_builder(self) -> MountpointS3Client:
5460
return MountpointS3Client(
5561
region=self._region,
5662
endpoint=self._endpoint,
57-
user_agent_prefix=user_agent_prefix,
63+
user_agent_prefix=self._user_agent_prefix,
5864
)
5965

6066
def get_object(
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# // SPDX-License-Identifier: BSD
3+
from typing import List
4+
5+
from ._version import __version__
6+
7+
# https://www.rfc-editor.org/rfc/rfc9110#name-user-agent
8+
9+
10+
class UserAgent:
11+
def __init__(self, comments: List[str] = None):
12+
if comments is not None and not isinstance(comments, list):
13+
raise ValueError("Argument comments must be a List[str]")
14+
self._user_agent_prefix = f"{__package__}/{__version__}"
15+
self._comments = comments or []
16+
17+
@property
18+
def prefix(self):
19+
comments_str = "; ".join(filter(None, self._comments))
20+
if comments_str:
21+
return f"{self._user_agent_prefix} ({comments_str})"
22+
return self._user_agent_prefix

s3torchconnector/src/s3torchconnector/_version.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,3 @@
55

66
# __package__ is 's3torchconnector'
77
__version__ = importlib.metadata.version(__package__)
8-
user_agent_prefix = f"{__package__}/{__version__}"

s3torchconnector/tst/unit/test_s3_client.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# // SPDX-License-Identifier: BSD
33
import logging
4-
from unittest.mock import MagicMock
5-
64
import pytest
75

6+
from hypothesis import given
7+
from hypothesis.strategies import lists, text
8+
from unittest.mock import MagicMock
9+
10+
from s3torchconnector._user_agent import UserAgent
11+
from s3torchconnector._version import __version__
812
from s3torchconnector._s3client import S3Client, MockS3Client
913

1014
TEST_BUCKET = "test-bucket"
@@ -48,3 +52,33 @@ def test_list_objects_log(s3_client: S3Client, caplog):
4852
with caplog.at_level(logging.DEBUG):
4953
s3_client.list_objects(TEST_BUCKET, TEST_KEY)
5054
assert f"ListObjects {S3_URI}" in caplog.messages
55+
56+
57+
def test_s3_client_default_user_agent():
58+
s3_client = S3Client(region=TEST_REGION)
59+
expected_user_agent = f"s3torchconnector/{__version__}"
60+
assert s3_client.user_agent_prefix == expected_user_agent
61+
assert s3_client._client.user_agent_prefix == expected_user_agent
62+
63+
64+
def test_s3_client_custom_user_agent():
65+
s3_client = S3Client(
66+
region=TEST_REGION, user_agent=UserAgent(["component/version", "metadata"])
67+
)
68+
expected_user_agent = (
69+
f"s3torchconnector/{__version__} (component/version; metadata)"
70+
)
71+
assert s3_client.user_agent_prefix == expected_user_agent
72+
assert s3_client._client.user_agent_prefix == expected_user_agent
73+
74+
75+
@given(lists(text()))
76+
def test_user_agent_always_starts_with_package_version(comments):
77+
s3_client = S3Client(region=TEST_REGION, user_agent=UserAgent(comments))
78+
expected_prefix = f"s3torchconnector/{__version__}"
79+
assert s3_client.user_agent_prefix.startswith(expected_prefix)
80+
assert s3_client._client.user_agent_prefix.startswith(expected_prefix)
81+
comments_str = "; ".join(filter(None, comments))
82+
if comments_str:
83+
assert comments_str in s3_client.user_agent_prefix
84+
assert comments_str in s3_client._client.user_agent_prefix
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# // SPDX-License-Identifier: BSD
3+
from __future__ import annotations
4+
5+
from typing import List
6+
7+
import pytest
8+
9+
from s3torchconnector._version import __version__
10+
from s3torchconnector._user_agent import UserAgent
11+
12+
DEFAULT_PREFIX = f"s3torchconnector/{__version__}"
13+
14+
15+
@pytest.mark.parametrize(
16+
"comments, expected_prefix",
17+
[
18+
(None, DEFAULT_PREFIX),
19+
([], DEFAULT_PREFIX),
20+
([""], DEFAULT_PREFIX),
21+
(["", ""], DEFAULT_PREFIX),
22+
(
23+
["component/version", "metadata"],
24+
f"{DEFAULT_PREFIX} (component/version; metadata)",
25+
),
26+
],
27+
)
28+
def test_user_agent_creation(comments: List[str] | None, expected_prefix: str):
29+
user_agent = UserAgent(comments)
30+
assert user_agent.prefix == expected_prefix
31+
32+
33+
def test_default_user_agent_creation():
34+
user_agent = UserAgent()
35+
assert user_agent.prefix == DEFAULT_PREFIX
36+
37+
38+
@pytest.mark.parametrize("invalid_comment", [0, "string"])
39+
def test_invalid_comments_argument(invalid_comment):
40+
with pytest.raises(ValueError, match="Argument comments must be a List\[str\]"):
41+
UserAgent(invalid_comment)

s3torchconnectorclient/rust/src/mock_client.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,20 @@ pub struct PyMockClient {
2424
pub(crate) region: String,
2525
#[pyo3(get)]
2626
pub(crate) part_size: usize,
27+
#[pyo3(get)]
28+
pub(crate) user_agent_prefix: String,
2729
}
2830

2931
#[pymethods]
3032
impl PyMockClient {
3133
#[new]
32-
#[pyo3(signature = (region, bucket, throughput_target_gbps = 10.0, part_size = 8 * 1024 * 1024))]
34+
#[pyo3(signature = (region, bucket, throughput_target_gbps = 10.0, part_size = 8 * 1024 * 1024, user_agent_prefix="mock_client".to_string()))]
3335
pub fn new(
3436
region: String,
3537
bucket: String,
3638
throughput_target_gbps: f64,
3739
part_size: usize,
40+
user_agent_prefix: String,
3841
) -> PyMockClient {
3942
let unordered_list_seed: Option<u64> = None;
4043
let config = MockClientConfig { bucket, part_size, unordered_list_seed };
@@ -45,13 +48,14 @@ impl PyMockClient {
4548
region,
4649
throughput_target_gbps,
4750
part_size,
51+
user_agent_prefix
4852
}
4953
}
5054

5155
fn create_mocked_client(&self) -> MountpointS3Client {
5256
MountpointS3Client::new(
5357
self.region.clone(),
54-
"mock-client".to_string(),
58+
self.user_agent_prefix.to_string(),
5559
self.throughput_target_gbps,
5660
self.part_size,
5761
None,

0 commit comments

Comments
 (0)