Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ Unreleased
----------

- Added "adlfs" to library's default user agent
- Fix issue where ``AzureBlobFile`` did not respect ``location_mode`` parameter
from parent ``AzureBlobFileSystem`` when using SAS credentials and connecting to
new SDK clients.


2024.12.0
Expand Down
1 change: 1 addition & 0 deletions adlfs/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -2093,6 +2093,7 @@ def connect_client(self):
elif self.fs.sas_token is not None:
self.container_client = _create_aio_blob_service_client(
account_url=self.fs.account_url + self.fs.sas_token,
location_mode=self.fs.location_mode,
).get_container_client(self.container_name)
else:
self.container_client = _create_aio_blob_service_client(
Expand Down
4 changes: 3 additions & 1 deletion adlfs/tests/constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
URL = "http://127.0.0.1:10000"
HOST = "127.0.0.1:10000"
URL = f"http://{HOST}"
ACCOUNT_NAME = "devstoreaccount1"
KEY = "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==" # NOQA
CONN_STR = f"DefaultEndpointsProtocol=http;AccountName={ACCOUNT_NAME};AccountKey={KEY};BlobEndpoint={URL}/{ACCOUNT_NAME};" # NOQA
SAS_TOKEN = "not-a-real-sas-token"
DEFAULT_VERSION_ID = "1970-01-01T00:00:00.0000000Z"
LATEST_VERSION_ID = "2022-01-01T00:00:00.0000000Z"
205 changes: 205 additions & 0 deletions adlfs/tests/test_connect_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
from unittest import mock

import azure.storage.blob
import pytest
from azure.storage.blob.aio import BlobServiceClient as AIOBlobServiceClient

from adlfs import AzureBlobFile, AzureBlobFileSystem
from adlfs.tests.constants import ACCOUNT_NAME, CONN_STR, HOST, KEY, SAS_TOKEN
from adlfs.utils import __version__ as __version__


@pytest.fixture()
def mock_from_connection_string(mocker):
return mocker.patch.object(
AIOBlobServiceClient,
"from_connection_string",
autospec=True,
side_effect=AIOBlobServiceClient.from_connection_string,
)


@pytest.fixture()
def mock_service_client_init(mocker):
return mocker.patch.object(
AIOBlobServiceClient,
"__init__",
autospec=True,
side_effect=AIOBlobServiceClient.__init__,
)


def get_expected_client_init_call(
account_url,
credential=None,
location_mode="primary",
user_agent=f"adlfs/{__version__}",
):
call_kwargs = {
"account_url": account_url,
}
if credential is not None:
call_kwargs["credential"] = credential
if location_mode is not None:
call_kwargs["_location_mode"] = location_mode
if user_agent is not None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the user agent is never none, does there need to be an if statement?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good callout. I mainly did it out of symmetry with the other kwargs and we could leverage that if we ever add the ability to update the user agent in adlfs. But we don't today so we can always set it in the expected call.

call_kwargs["user_agent"] = user_agent
return mock.call(mock.ANY, **call_kwargs)


def get_expected_client_from_connection_string_call(
conn_str=None,
user_agent=f"adlfs/{__version__}",
):
call_kwargs = {
"conn_str": conn_str,
}
if user_agent is not None:
call_kwargs["user_agent"] = user_agent
return mock.call(**call_kwargs)


def assert_client_create_calls(
mock_client_create_method,
expected_create_call,
expected_call_count=1,
):
expected_call_args_list = [expected_create_call for _ in range(expected_call_count)]
assert mock_client_create_method.call_args_list == expected_call_args_list


def ensure_no_api_calls_on_close(file_obj):
# Marks the file-like object as closed to prevent any API calls during an invocation of
# close(), which can occur during garbage collection or direct invocation.
#
# This is important for test cases where we do not want to make an API request whether:
#
# * The test would hang because Azurite is not configured to use SSL and adlfs always sets SSL
# for SDK clients created via their initializer.
#
# * The filesystem is configured to use the secondary location which can by-pass the location
# that Azurite is running.
file_obj.closed = True


@pytest.mark.parametrize(
"fs_kwargs,expected_client_init_call",
[
(
{"account_name": ACCOUNT_NAME, "account_key": KEY},
get_expected_client_init_call(
account_url=f"https://{ACCOUNT_NAME}.blob.core.windows.net",
credential=KEY,
),
),
(
{"account_name": ACCOUNT_NAME, "credential": SAS_TOKEN},
get_expected_client_init_call(
account_url=f"https://{ACCOUNT_NAME}.blob.core.windows.net",
credential=SAS_TOKEN,
),
),
(
{"account_name": ACCOUNT_NAME, "sas_token": SAS_TOKEN},
get_expected_client_init_call(
account_url=f"https://{ACCOUNT_NAME}.blob.core.windows.net?{SAS_TOKEN}",
),
),
# Anonymous connection
(
{"account_name": ACCOUNT_NAME},
get_expected_client_init_call(
account_url=f"https://{ACCOUNT_NAME}.blob.core.windows.net",
location_mode=None,
),
),
# Override host
(
{
"account_name": ACCOUNT_NAME,
"account_host": HOST,
"sas_token": SAS_TOKEN,
},
get_expected_client_init_call(
account_url=f"https://{HOST}?{SAS_TOKEN}",
),
),
# Override location mode
(
{
"account_name": ACCOUNT_NAME,
"account_key": KEY,
"location_mode": "secondary",
},
get_expected_client_init_call(
account_url=f"https://{ACCOUNT_NAME}.blob.core.windows.net",
credential=KEY,
location_mode="secondary",
),
),
(
{
"account_name": ACCOUNT_NAME,
"credential": SAS_TOKEN,
"location_mode": "secondary",
},
get_expected_client_init_call(
account_url=f"https://{ACCOUNT_NAME}.blob.core.windows.net",
credential=SAS_TOKEN,
location_mode="secondary",
),
),
(
{
"account_name": ACCOUNT_NAME,
"sas_token": SAS_TOKEN,
"location_mode": "secondary",
},
get_expected_client_init_call(
account_url=f"https://{ACCOUNT_NAME}.blob.core.windows.net?{SAS_TOKEN}",
location_mode="secondary",
),
),
],
)
def test_connect_initializer(
storage: azure.storage.blob.BlobServiceClient,
mock_service_client_init,
fs_kwargs,
expected_client_init_call,
):
fs = AzureBlobFileSystem(skip_instance_cache=True, **fs_kwargs)
assert_client_create_calls(mock_service_client_init, expected_client_init_call)

f = AzureBlobFile(fs, "data/root/a/file.txt", mode="wb")
f.connect_client()
ensure_no_api_calls_on_close(f)
assert_client_create_calls(
mock_service_client_init,
expected_client_init_call,
expected_call_count=2,
)


def test_connect_connection_str(
storage: azure.storage.blob.BlobServiceClient, mock_from_connection_string
):
fs = AzureBlobFileSystem(
account_name=storage.account_name,
connection_string=CONN_STR,
skip_instance_cache=True,
)
expected_from_connection_str_call = get_expected_client_from_connection_string_call(
conn_str=CONN_STR,
)
assert_client_create_calls(
mock_from_connection_string, expected_from_connection_str_call
)

f = AzureBlobFile(fs, "data/root/a/file.txt", mode="rb")
f.connect_client()
assert_client_create_calls(
mock_from_connection_string,
expected_from_connection_str_call,
expected_call_count=2,
)
83 changes: 0 additions & 83 deletions adlfs/tests/test_user_agent.py

This file was deleted.