Skip to content

Commit 27f553d

Browse files
authored
fix: azure connection check (#503)
1 parent f927a15 commit 27f553d

File tree

3 files changed

+62
-68
lines changed

3 files changed

+62
-68
lines changed

packages/ragbits-core/CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
## Unreleased
44
- Add new fusion strategies for the hybrid vector store: RRF and DBSF (#413)
5-
65
- move sources from ragbits-document-search to ragbits-core (#496)
6+
- adding connection check to Azure get_blob_service (#502)
77

88
## 0.13.0 (2025-04-02)
99
- Make the score in VectorStoreResult consistent (always bigger is better)

packages/ragbits-core/src/ragbits/core/sources/azure.py

Lines changed: 46 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,18 @@
22
from collections.abc import Sequence
33
from contextlib import suppress
44
from pathlib import Path
5-
from typing import ClassVar, Optional
5+
from typing import ClassVar
66
from urllib.parse import urlparse
77

88
from ragbits.core.audit import trace, traceable
9+
from ragbits.core.sources.base import Source, get_local_storage_dir
10+
from ragbits.core.sources.exceptions import SourceConnectionError, SourceNotFoundError
11+
from ragbits.core.utils.decorators import requires_dependencies
912

1013
with suppress(ImportError):
1114
from azure.core.exceptions import ResourceNotFoundError
1215
from azure.identity import DefaultAzureCredential
13-
from azure.storage.blob import BlobServiceClient
14-
15-
from ragbits.core.sources.base import Source, get_local_storage_dir
16-
from ragbits.core.sources.exceptions import SourceConnectionError, SourceNotFoundError
17-
from ragbits.core.utils.decorators import requires_dependencies
16+
from azure.storage.blob import BlobServiceClient, ExponentialRetry
1817

1918

2019
class AzureBlobStorageSource(Source):
@@ -26,7 +25,6 @@ class AzureBlobStorageSource(Source):
2625
account_name: str
2726
container_name: str
2827
blob_name: str
29-
_blob_service: Optional["BlobServiceClient"] = None
3028

3129
@property
3230
def id(self) -> str:
@@ -35,49 +33,6 @@ def id(self) -> str:
3533
"""
3634
return f"azure://{self.account_name}/{self.container_name}/{self.blob_name}"
3735

38-
@classmethod
39-
@requires_dependencies(["azure.storage.blob", "azure.identity"], "azure")
40-
async def _get_blob_service(cls, account_name: str) -> "BlobServiceClient":
41-
"""
42-
Returns an authenticated BlobServiceClient instance.
43-
44-
Priority:
45-
1. DefaultAzureCredential (if account_name is set and authentication succeeds).
46-
2. Connection string (if authentication with DefaultAzureCredential fails).
47-
48-
If neither method works, an error is raised.
49-
50-
Args:
51-
account_name: The name of the Azure Blob Storage account.
52-
53-
Returns:
54-
BlobServiceClient: The authenticated Blob Storage client.
55-
56-
Raises:
57-
ValueError: If the authentication fails.
58-
"""
59-
try:
60-
credential = DefaultAzureCredential()
61-
account_url = f"https://{account_name}.blob.core.windows.net"
62-
cls._blob_service = BlobServiceClient(account_url=account_url, credential=credential)
63-
return cls._blob_service
64-
except Exception as e:
65-
print(f"Warning: Failed to authenticate using DefaultAzureCredential. \nError: {str(e)}")
66-
67-
connection_string = os.getenv("AZURE_STORAGE_CONNECTION_STRING")
68-
if connection_string:
69-
try:
70-
cls._blob_service = BlobServiceClient.from_connection_string(conn_str=connection_string)
71-
return cls._blob_service
72-
except Exception as e:
73-
raise ValueError("Failed to authenticate using connection string.") from e
74-
75-
# If neither method works, raise an error
76-
raise ValueError(
77-
"No authentication method available. "
78-
"Provide an account_name for identity-based authentication or a connection string."
79-
)
80-
8136
@requires_dependencies(["azure.storage.blob", "azure.core.exceptions"], "azure")
8237
async def fetch(self) -> Path:
8338
"""
@@ -95,7 +50,7 @@ async def fetch(self) -> Path:
9550
path = container_local_dir / self.blob_name
9651
with trace(account_name=self.account_name, container=self.container_name, blob=self.blob_name) as outputs:
9752
try:
98-
blob_service = await self._get_blob_service(account_name=self.account_name)
53+
blob_service = self._get_blob_service(self.account_name)
9954
blob_client = blob_service.get_blob_client(container=self.container_name, blob=self.blob_name)
10055
Path(path).parent.mkdir(parents=True, exist_ok=True)
10156
stream = blob_client.download_blob()
@@ -174,12 +129,11 @@ async def list_sources(
174129
List of source objects.
175130
176131
Raises:
177-
ImportError: If the required 'azure-storage-blob' package is not installed
178132
SourceConnectionError: If there's an error connecting to Azure
179133
"""
180134
with trace(account_name=account_name, container=container, blob_name=blob_name) as outputs:
181-
blob_service = await cls._get_blob_service(account_name=account_name)
182135
try:
136+
blob_service = cls._get_blob_service(account_name)
183137
container_client = blob_service.get_container_client(container)
184138
blobs = container_client.list_blobs(name_starts_with=blob_name)
185139
outputs.results = [
@@ -189,3 +143,42 @@ async def list_sources(
189143
return outputs.results
190144
except Exception as e:
191145
raise SourceConnectionError() from e
146+
147+
@staticmethod
148+
def _get_blob_service(account_name: str) -> "BlobServiceClient":
149+
"""
150+
Returns an authenticated BlobServiceClient instance.
151+
152+
Priority:
153+
1. DefaultAzureCredential.
154+
2. Connection string.
155+
156+
Args:
157+
account_name: The name of the Azure Blob Storage account.
158+
159+
Returns:
160+
The authenticated Blob Storage client.
161+
"""
162+
try:
163+
credential = DefaultAzureCredential()
164+
account_url = f"https://{account_name}.blob.core.windows.net"
165+
blob_service = BlobServiceClient(
166+
account_url=account_url,
167+
credential=credential,
168+
retry_policy=ExponentialRetry(retry_total=0),
169+
)
170+
blob_service.get_account_information()
171+
return blob_service
172+
except Exception as first_exc:
173+
if conn_str := os.getenv("AZURE_STORAGE_CONNECTION_STRING", ""):
174+
try:
175+
service = BlobServiceClient.from_connection_string(
176+
conn_str=conn_str,
177+
retry_policy=ExponentialRetry(retry_total=0),
178+
)
179+
service.get_account_information()
180+
return service
181+
except Exception as second_error:
182+
raise second_error from first_exc
183+
184+
raise first_exc

packages/ragbits-core/tests/unit/sources/test_azure.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pathlib import Path, PosixPath
2-
from unittest.mock import AsyncMock, MagicMock, mock_open, patch
2+
from unittest.mock import ANY, AsyncMock, MagicMock, mock_open, patch
33

44
import pytest
55
from azure.core.exceptions import ResourceNotFoundError
@@ -81,31 +81,28 @@ async def test_from_uri_listing():
8181
)
8282

8383

84-
@pytest.mark.asyncio
85-
async def test_get_blob_service_no_credentials():
86-
"""Test that ValueError is raised when no credentials are set."""
84+
def test_get_blob_service_no_credentials():
85+
"""Test that the exception is propagated when no credentials are set."""
8786
with (
8887
patch.object(DefaultAzureCredential, "__init__", side_effect=Exception("Authentication failed")),
8988
patch("os.getenv", return_value=None),
90-
pytest.raises(ValueError, match="No authentication method available"),
89+
pytest.raises(Exception, match="Authentication failed"),
9190
):
92-
await AzureBlobStorageSource._get_blob_service(account_name=ACCOUNT_NAME)
91+
AzureBlobStorageSource._get_blob_service(account_name=ACCOUNT_NAME)
9392

9493

95-
@pytest.mark.asyncio
96-
async def test_get_blob_service_with_connection_string():
94+
def test_get_blob_service_with_connection_string():
9795
"""Test that the connection string is used when DefaultAzureCredential authentication fails."""
9896
with (
9997
patch.object(DefaultAzureCredential, "__init__", side_effect=Exception("Authentication failed")),
10098
patch("os.getenv", return_value="mock_connection_string"),
10199
patch("azure.storage.blob.BlobServiceClient.from_connection_string") as mock_from_connection_string,
102100
):
103-
await AzureBlobStorageSource._get_blob_service(account_name="account_name")
104-
mock_from_connection_string.assert_called_once_with(conn_str="mock_connection_string")
101+
AzureBlobStorageSource._get_blob_service(account_name="account_name")
102+
mock_from_connection_string.assert_called_once_with(conn_str="mock_connection_string", retry_policy=ANY)
105103

106104

107-
@pytest.mark.asyncio
108-
async def test_get_blob_service_with_default_credentials():
105+
def test_get_blob_service_with_default_credentials():
109106
"""Test that default credentials are used when the account_name and credentials are available."""
110107
account_url = f"https://{ACCOUNT_NAME}.blob.core.windows.net"
111108

@@ -114,10 +111,14 @@ async def test_get_blob_service_with_default_credentials():
114111
patch("ragbits.core.sources.azure.BlobServiceClient") as mock_blob_client,
115112
patch("azure.storage.blob.BlobServiceClient.from_connection_string") as mock_from_connection_string,
116113
):
117-
await AzureBlobStorageSource._get_blob_service(ACCOUNT_NAME)
114+
AzureBlobStorageSource._get_blob_service(ACCOUNT_NAME)
118115

119116
mock_credential.assert_called_once()
120-
mock_blob_client.assert_called_once_with(account_url=account_url, credential=mock_credential.return_value)
117+
mock_blob_client.assert_called_once_with(
118+
account_url=account_url,
119+
credential=mock_credential.return_value,
120+
retry_policy=ANY,
121+
)
121122
mock_from_connection_string.assert_not_called()
122123

123124

0 commit comments

Comments
 (0)