|
1 | 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
2 | 2 | # // SPDX-License-Identifier: BSD |
3 | 3 | import logging |
4 | | -from unittest.mock import MagicMock |
5 | | - |
6 | 4 | import pytest |
7 | 5 |
|
| 6 | +from hypothesis import given |
| 7 | +from hypothesis.strategies import lists, text |
| 8 | +from unittest.mock import MagicMock |
| 9 | + |
| 10 | +from s3torchconnector._user_agent import UserAgent |
| 11 | +from s3torchconnector._version import __version__ |
8 | 12 | from s3torchconnector._s3client import S3Client, MockS3Client |
9 | 13 |
|
10 | 14 | TEST_BUCKET = "test-bucket" |
@@ -48,3 +52,33 @@ def test_list_objects_log(s3_client: S3Client, caplog): |
48 | 52 | with caplog.at_level(logging.DEBUG): |
49 | 53 | s3_client.list_objects(TEST_BUCKET, TEST_KEY) |
50 | 54 | assert f"ListObjects {S3_URI}" in caplog.messages |
| 55 | + |
| 56 | + |
| 57 | +def test_s3_client_default_user_agent(): |
| 58 | + s3_client = S3Client(region=TEST_REGION) |
| 59 | + expected_user_agent = f"s3torchconnector/{__version__}" |
| 60 | + assert s3_client.user_agent_prefix == expected_user_agent |
| 61 | + assert s3_client._client.user_agent_prefix == expected_user_agent |
| 62 | + |
| 63 | + |
| 64 | +def test_s3_client_custom_user_agent(): |
| 65 | + s3_client = S3Client( |
| 66 | + region=TEST_REGION, user_agent=UserAgent(["component/version", "metadata"]) |
| 67 | + ) |
| 68 | + expected_user_agent = ( |
| 69 | + f"s3torchconnector/{__version__} (component/version; metadata)" |
| 70 | + ) |
| 71 | + assert s3_client.user_agent_prefix == expected_user_agent |
| 72 | + assert s3_client._client.user_agent_prefix == expected_user_agent |
| 73 | + |
| 74 | + |
| 75 | +@given(lists(text())) |
| 76 | +def test_user_agent_always_starts_with_package_version(comments): |
| 77 | + s3_client = S3Client(region=TEST_REGION, user_agent=UserAgent(comments)) |
| 78 | + expected_prefix = f"s3torchconnector/{__version__}" |
| 79 | + assert s3_client.user_agent_prefix.startswith(expected_prefix) |
| 80 | + assert s3_client._client.user_agent_prefix.startswith(expected_prefix) |
| 81 | + comments_str = "; ".join(filter(None, comments)) |
| 82 | + if comments_str: |
| 83 | + assert comments_str in s3_client.user_agent_prefix |
| 84 | + assert comments_str in s3_client._client.user_agent_prefix |
0 commit comments