Skip to content

Commit b8fc89e

Browse files
committed
allow different credentials
1 parent 9589881 commit b8fc89e

File tree

1 file changed

+23
-4
lines changed

1 file changed

+23
-4
lines changed

src/maggma/stores/azure.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Advanced Stores for connecting to Microsoft Azure data.
33
"""
44

5+
import importlib
56
import os
67
import threading
78
import warnings
@@ -11,7 +12,7 @@
1112
from concurrent.futures.thread import ThreadPoolExecutor
1213
from hashlib import sha1
1314
from json import dumps
14-
from typing import Optional, Union
15+
from typing import Literal, Optional, Union
1516

1617
import msgpack # type: ignore
1718
from monty.msgpack import default as monty_default
@@ -23,15 +24,29 @@
2324
import azure
2425
import azure.storage.blob as azure_blob
2526
from azure.core.exceptions import ResourceExistsError
26-
from azure.identity import DefaultAzureCredential
2727
from azure.storage.blob import BlobServiceClient, ContainerClient
28+
29+
2830
except (ImportError, ModuleNotFoundError):
2931
azure_blob = None # type: ignore
3032
ContainerClient = None
3133

3234

3335
AZURE_KEY_SANITIZE = {"-": "_", ".": "_"}
3436

37+
CredentialType = Literal[
38+
"DefaultAzureCredential",
39+
"AzureCliCredential",
40+
"ManagedIdentityCredential",
41+
]
42+
43+
44+
def _get_azure_credential(credential_class: str):
45+
"""Import the azure.identity module and return the credential class."""
46+
module_name = "azure.identity"
47+
credential_class = getattr(importlib.import_module(module_name), credential_class)
48+
return credential_class()
49+
3550

3651
class AzureBlobStore(Store):
3752
"""
@@ -45,6 +60,7 @@ def __init__(
4560
index: Store,
4661
container_name: str,
4762
azure_client_info: Optional[Union[str, dict]] = None,
63+
credential_type: CredentialType = "DefaultAzureCredential",
4864
compress: bool = False,
4965
sub_dir: Optional[str] = None,
5066
workers: int = 1,
@@ -69,6 +85,8 @@ def __init__(
6985
BlobServiceClient.
7086
Currently supported keywords:
7187
- connection_string: a connection string for the Azure blob
88+
credential_type: the type of credential to use to authenticate with Azure.
89+
Default is "DefaultAzureCredential".
7290
compress: compress files inserted into the store
7391
sub_dir: (optional) subdirectory of the container to store the data.
7492
When defined, a final "/" will be added if not already present.
@@ -104,6 +122,7 @@ def __init__(
104122
key_sanitize_dict = AZURE_KEY_SANITIZE
105123
self.key_sanitize_dict = key_sanitize_dict
106124
self.create_container = create_container
125+
self.credential_type = credential_type
107126

108127
# Force the key to be the same as the index
109128
assert isinstance(
@@ -351,8 +370,8 @@ def _get_service_client(self):
351370
if not hasattr(self._thread_local, "container"):
352371
if isinstance(self.azure_client_info, str):
353372
# assume it is the account_url and that the connection is passwordless
354-
default_credential = DefaultAzureCredential()
355-
return BlobServiceClient(self.azure_client_info, credential=default_credential)
373+
credentials_ = _get_azure_credential(self.credential_type)
374+
return BlobServiceClient(self.azure_client_info, credential=credentials_)
356375

357376
if isinstance(self.azure_client_info, dict):
358377
connection_string = self.azure_client_info.get("connection_string")

0 commit comments

Comments
 (0)