diff --git a/CHANGELOG.md b/CHANGELOG.md index 887f008f..38febc6e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,9 @@ Unreleased ---------- - Added "adlfs" to library's default user agent +- Fix issue where ``AzureBlobFile`` did not respect ``location_mode`` parameter + from parent ``AzureBlobFileSystem`` when using SAS credentials and connecting to + new SDK clients. 2024.12.0 diff --git a/adlfs/spec.py b/adlfs/spec.py index 92d56515..4c443db1 100644 --- a/adlfs/spec.py +++ b/adlfs/spec.py @@ -2093,6 +2093,7 @@ def connect_client(self): elif self.fs.sas_token is not None: self.container_client = _create_aio_blob_service_client( account_url=self.fs.account_url + self.fs.sas_token, + location_mode=self.fs.location_mode, ).get_container_client(self.container_name) else: self.container_client = _create_aio_blob_service_client( diff --git a/adlfs/tests/constants.py b/adlfs/tests/constants.py index 8f4cf4a3..efc78d38 100644 --- a/adlfs/tests/constants.py +++ b/adlfs/tests/constants.py @@ -1,6 +1,8 @@ -URL = "http://127.0.0.1:10000" +HOST = "127.0.0.1:10000" +URL = f"http://{HOST}" ACCOUNT_NAME = "devstoreaccount1" KEY = "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==" # NOQA CONN_STR = f"DefaultEndpointsProtocol=http;AccountName={ACCOUNT_NAME};AccountKey={KEY};BlobEndpoint={URL}/{ACCOUNT_NAME};" # NOQA +SAS_TOKEN = "not-a-real-sas-token" DEFAULT_VERSION_ID = "1970-01-01T00:00:00.0000000Z" LATEST_VERSION_ID = "2022-01-01T00:00:00.0000000Z" diff --git a/adlfs/tests/test_connect_client.py b/adlfs/tests/test_connect_client.py new file mode 100644 index 00000000..49876ffa --- /dev/null +++ b/adlfs/tests/test_connect_client.py @@ -0,0 +1,197 @@ +from unittest import mock + +import azure.storage.blob +import pytest +from azure.storage.blob.aio import BlobServiceClient as AIOBlobServiceClient + +from adlfs import AzureBlobFile, AzureBlobFileSystem +from adlfs.tests.constants import ACCOUNT_NAME, CONN_STR, HOST, KEY, SAS_TOKEN +from adlfs.utils import __version__ as __version__ + + +@pytest.fixture() +def mock_from_connection_string(mocker): + return mocker.patch.object( + AIOBlobServiceClient, + "from_connection_string", + autospec=True, + side_effect=AIOBlobServiceClient.from_connection_string, + ) + + +@pytest.fixture() +def mock_service_client_init(mocker): + return mocker.patch.object( + AIOBlobServiceClient, + "__init__", + autospec=True, + side_effect=AIOBlobServiceClient.__init__, + ) + + +def get_expected_client_init_call( + account_url, + credential=None, + location_mode="primary", +): + call_kwargs = { + "account_url": account_url, + "user_agent": f"adlfs/{__version__}", + } + if credential is not None: + call_kwargs["credential"] = credential + if location_mode is not None: + call_kwargs["_location_mode"] = location_mode + return mock.call(mock.ANY, **call_kwargs) + + +def get_expected_client_from_connection_string_call( + conn_str, +): + return mock.call(conn_str=conn_str, user_agent=f"adlfs/{__version__}") + + +def assert_client_create_calls( + mock_client_create_method, + expected_create_call, + expected_call_count=1, +): + expected_call_args_list = [expected_create_call for _ in range(expected_call_count)] + assert mock_client_create_method.call_args_list == expected_call_args_list + + +def ensure_no_api_calls_on_close(file_obj): + # Marks the file-like object as closed to prevent any API calls during an invocation of + # close(), which can occur during garbage collection or direct invocation. + # + # This is important for test cases where we do not want to make an API request whether: + # + # * The test would hang because Azurite is not configured to use SSL and adlfs always sets SSL + # for SDK clients created via their initializer. + # + # * The filesystem is configured to use the secondary location which can by-pass the location + # that Azurite is running. + file_obj.closed = True + + +@pytest.mark.parametrize( + "fs_kwargs,expected_client_init_call", + [ + ( + {"account_name": ACCOUNT_NAME, "account_key": KEY}, + get_expected_client_init_call( + account_url=f"https://{ACCOUNT_NAME}.blob.core.windows.net", + credential=KEY, + ), + ), + ( + {"account_name": ACCOUNT_NAME, "credential": SAS_TOKEN}, + get_expected_client_init_call( + account_url=f"https://{ACCOUNT_NAME}.blob.core.windows.net", + credential=SAS_TOKEN, + ), + ), + ( + {"account_name": ACCOUNT_NAME, "sas_token": SAS_TOKEN}, + get_expected_client_init_call( + account_url=f"https://{ACCOUNT_NAME}.blob.core.windows.net?{SAS_TOKEN}", + ), + ), + # Anonymous connection + ( + {"account_name": ACCOUNT_NAME}, + get_expected_client_init_call( + account_url=f"https://{ACCOUNT_NAME}.blob.core.windows.net", + location_mode=None, + ), + ), + # Override host + ( + { + "account_name": ACCOUNT_NAME, + "account_host": HOST, + "sas_token": SAS_TOKEN, + }, + get_expected_client_init_call( + account_url=f"https://{HOST}?{SAS_TOKEN}", + ), + ), + # Override location mode + ( + { + "account_name": ACCOUNT_NAME, + "account_key": KEY, + "location_mode": "secondary", + }, + get_expected_client_init_call( + account_url=f"https://{ACCOUNT_NAME}.blob.core.windows.net", + credential=KEY, + location_mode="secondary", + ), + ), + ( + { + "account_name": ACCOUNT_NAME, + "credential": SAS_TOKEN, + "location_mode": "secondary", + }, + get_expected_client_init_call( + account_url=f"https://{ACCOUNT_NAME}.blob.core.windows.net", + credential=SAS_TOKEN, + location_mode="secondary", + ), + ), + ( + { + "account_name": ACCOUNT_NAME, + "sas_token": SAS_TOKEN, + "location_mode": "secondary", + }, + get_expected_client_init_call( + account_url=f"https://{ACCOUNT_NAME}.blob.core.windows.net?{SAS_TOKEN}", + location_mode="secondary", + ), + ), + ], +) +def test_connect_initializer( + storage: azure.storage.blob.BlobServiceClient, + mock_service_client_init, + fs_kwargs, + expected_client_init_call, +): + fs = AzureBlobFileSystem(skip_instance_cache=True, **fs_kwargs) + assert_client_create_calls(mock_service_client_init, expected_client_init_call) + + f = AzureBlobFile(fs, "data/root/a/file.txt", mode="wb") + f.connect_client() + ensure_no_api_calls_on_close(f) + assert_client_create_calls( + mock_service_client_init, + expected_client_init_call, + expected_call_count=2, + ) + + +def test_connect_connection_str( + storage: azure.storage.blob.BlobServiceClient, mock_from_connection_string +): + fs = AzureBlobFileSystem( + account_name=storage.account_name, + connection_string=CONN_STR, + skip_instance_cache=True, + ) + expected_from_connection_str_call = get_expected_client_from_connection_string_call( + conn_str=CONN_STR, + ) + assert_client_create_calls( + mock_from_connection_string, expected_from_connection_str_call + ) + + f = AzureBlobFile(fs, "data/root/a/file.txt", mode="rb") + f.connect_client() + assert_client_create_calls( + mock_from_connection_string, + expected_from_connection_str_call, + expected_call_count=2, + ) diff --git a/adlfs/tests/test_user_agent.py b/adlfs/tests/test_user_agent.py deleted file mode 100644 index 2be43ef5..00000000 --- a/adlfs/tests/test_user_agent.py +++ /dev/null @@ -1,83 +0,0 @@ -import azure.storage.blob -import pytest -from azure.storage.blob.aio import BlobServiceClient as AIOBlobServiceClient - -from adlfs import AzureBlobFile, AzureBlobFileSystem -from adlfs.tests.constants import CONN_STR, KEY -from adlfs.utils import __version__ as __version__ - - -def assert_sets_adlfs_user_agent(mock_client_create_method, expected_call_count=1): - assert len(mock_client_create_method.call_args_list) == expected_call_count - for call in mock_client_create_method.call_args_list: - assert "user_agent" in call.kwargs - assert call.kwargs["user_agent"] == f"adlfs/{__version__}" - - -@pytest.fixture() -def mock_from_connection_string(mocker): - return mocker.patch.object( - AIOBlobServiceClient, - "from_connection_string", - autospec=True, - side_effect=AIOBlobServiceClient.from_connection_string, - ) - - -@pytest.fixture() -def mock_service_client_init(mocker): - return mocker.patch.object( - AIOBlobServiceClient, - "__init__", - autospec=True, - side_effect=AIOBlobServiceClient.__init__, - ) - - -def test_user_agent_blob_file_connection_str( - storage: azure.storage.blob.BlobServiceClient, mock_from_connection_string -): - fs = AzureBlobFileSystem( - account_name=storage.account_name, - connection_string=CONN_STR, - skip_instance_cache=True, - ) - f = AzureBlobFile(fs, "data/root/a/file.txt", mode="rb") - f.connect_client() - assert_sets_adlfs_user_agent(mock_from_connection_string, expected_call_count=2) - - -def test_user_agent_blob_file_initializer( - storage: azure.storage.blob.BlobServiceClient, mock_service_client_init -): - fs = AzureBlobFileSystem( - account_name=storage.account_name, - account_key=KEY, - skip_instance_cache=True, - ) - f = AzureBlobFile(fs, "data/root/a/file.txt", mode="wb") - f.connect_client() - assert_sets_adlfs_user_agent(mock_service_client_init, expected_call_count=2) - # Makes sure no API calls are made that would cause the test to hang because of ssl connection issues - f.closed = True - - -def test_user_agent_connection_str( - storage: azure.storage.blob.BlobServiceClient, mock_from_connection_string -): - AzureBlobFileSystem( - account_name=storage.account_name, - connection_string=CONN_STR, - skip_instance_cache=True, - ) - assert_sets_adlfs_user_agent(mock_from_connection_string) - - -def test_user_agent_initializer( - storage: azure.storage.blob.BlobServiceClient, mock_service_client_init -): - AzureBlobFileSystem( - account_name=storage.account_name, - skip_instance_cache=True, - ) - assert_sets_adlfs_user_agent(mock_service_client_init)