Skip to content

Commit 4ed05b8

Browse files
authored
Merge pull request #1020 from jmmshn/jmmshn/cli
[Feature] Allow Different Azure Authentication Methods
2 parents 9589881 + 5f8f988 commit 4ed05b8

File tree

2 files changed

+61
-4
lines changed

2 files changed

+61
-4
lines changed

src/maggma/stores/azure.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Advanced Stores for connecting to Microsoft Azure data.
33
"""
44

5+
import importlib
56
import os
67
import threading
78
import warnings
@@ -11,7 +12,7 @@
1112
from concurrent.futures.thread import ThreadPoolExecutor
1213
from hashlib import sha1
1314
from json import dumps
14-
from typing import Optional, Union
15+
from typing import Literal, Optional, Union
1516

1617
import msgpack # type: ignore
1718
from monty.msgpack import default as monty_default
@@ -23,15 +24,34 @@
2324
import azure
2425
import azure.storage.blob as azure_blob
2526
from azure.core.exceptions import ResourceExistsError
26-
from azure.identity import DefaultAzureCredential
2727
from azure.storage.blob import BlobServiceClient, ContainerClient
28+
29+
2830
except (ImportError, ModuleNotFoundError):
2931
azure_blob = None # type: ignore
3032
ContainerClient = None
3133

3234

3335
AZURE_KEY_SANITIZE = {"-": "_", ".": "_"}
3436

37+
CredentialType = Literal[
38+
"DefaultAzureCredential",
39+
"AzureCliCredential",
40+
"ManagedIdentityCredential",
41+
]
42+
43+
44+
def _get_azure_credential(credential_class):
45+
"""Import the azure.identity module and return the credential class.
46+
47+
If the credential_class is a class, return an instance of it.
48+
If the credential_class is a string, import the module first
49+
"""
50+
if isinstance(credential_class, str):
51+
module_name = "azure.identity"
52+
credential_class = getattr(importlib.import_module(module_name), credential_class)
53+
return credential_class()
54+
3555

3656
class AzureBlobStore(Store):
3757
"""
@@ -45,6 +65,7 @@ def __init__(
4565
index: Store,
4666
container_name: str,
4767
azure_client_info: Optional[Union[str, dict]] = None,
68+
credential_type: CredentialType = "DefaultAzureCredential",
4869
compress: bool = False,
4970
sub_dir: Optional[str] = None,
5071
workers: int = 1,
@@ -69,6 +90,10 @@ def __init__(
6990
BlobServiceClient.
7091
Currently supported keywords:
7192
- connection_string: a connection string for the Azure blob
93+
credential_type: the type of credential to use to authenticate with Azure.
94+
Default is "DefaultAzureCredential". For serializable stores, provide
95+
a string representation of the credential class. Otherwises, you may
96+
provide the class itself.
7297
compress: compress files inserted into the store
7398
sub_dir: (optional) subdirectory of the container to store the data.
7499
When defined, a final "/" will be added if not already present.
@@ -104,6 +129,7 @@ def __init__(
104129
key_sanitize_dict = AZURE_KEY_SANITIZE
105130
self.key_sanitize_dict = key_sanitize_dict
106131
self.create_container = create_container
132+
self.credential_type = credential_type
107133

108134
# Force the key to be the same as the index
109135
assert isinstance(
@@ -351,8 +377,8 @@ def _get_service_client(self):
351377
if not hasattr(self._thread_local, "container"):
352378
if isinstance(self.azure_client_info, str):
353379
# assume it is the account_url and that the connection is passwordless
354-
default_credential = DefaultAzureCredential()
355-
return BlobServiceClient(self.azure_client_info, credential=default_credential)
380+
credentials_ = _get_azure_credential(self.credential_type)
381+
return BlobServiceClient(self.azure_client_info, credential=credentials_)
356382

357383
if isinstance(self.azure_client_info, dict):
358384
connection_string = self.azure_client_info.get("connection_string")

tests/stores/test_azure.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,3 +420,34 @@ def test_no_login():
420420

421421
with pytest.raises(RuntimeError, match=r".*Could not instantiate BlobServiceClient.*"):
422422
store.connect()
423+
424+
425+
def test_credential_type_valid():
426+
credential_type = "DefaultAzureCredential"
427+
index = MemoryStore("index")
428+
store = AzureBlobStore(
429+
index,
430+
AZURITE_CONTAINER_NAME,
431+
azure_client_info="client_url",
432+
credential_type=credential_type,
433+
)
434+
assert store.credential_type == credential_type
435+
436+
# tricks the store into thinking you already
437+
# provided the blob service client so it skips
438+
# the connection checks. We are only testing that
439+
# the credential import works properly
440+
store.service = True
441+
store.connect()
442+
443+
from azure.identity import DefaultAzureCredential
444+
445+
credential_type = DefaultAzureCredential
446+
index = MemoryStore("index")
447+
store = AzureBlobStore(
448+
index,
449+
AZURITE_CONTAINER_NAME,
450+
azure_client_info="client_url",
451+
credential_type=credential_type,
452+
)
453+
assert not isinstance(store.credential_type, str)

0 commit comments

Comments
 (0)