22Advanced Stores for connecting to Microsoft Azure data.
33"""
44
5+ import importlib
56import os
67import threading
78import warnings
1112from concurrent .futures .thread import ThreadPoolExecutor
1213from hashlib import sha1
1314from json import dumps
14- from typing import Optional , Union
15+ from typing import Literal , Optional , Union
1516
1617import msgpack # type: ignore
1718from monty .msgpack import default as monty_default
2324 import azure
2425 import azure .storage .blob as azure_blob
2526 from azure .core .exceptions import ResourceExistsError
26- from azure .identity import DefaultAzureCredential
2727 from azure .storage .blob import BlobServiceClient , ContainerClient
28+
29+
2830except (ImportError , ModuleNotFoundError ):
2931 azure_blob = None # type: ignore
3032 ContainerClient = None
3133
3234
3335AZURE_KEY_SANITIZE = {"-" : "_" , "." : "_" }
3436
37+ CredentialType = Literal [
38+ "DefaultAzureCredential" ,
39+ "AzureCliCredential" ,
40+ "ManagedIdentityCredential" ,
41+ ]
42+
43+
44+ def _get_azure_credential (credential_class : str ):
45+ """Import the azure.identity module and return the credential class."""
46+ module_name = "azure.identity"
47+ credential_class = getattr (importlib .import_module (module_name ), credential_class )
48+ return credential_class ()
49+
3550
3651class AzureBlobStore (Store ):
3752 """
@@ -45,6 +60,7 @@ def __init__(
4560 index : Store ,
4661 container_name : str ,
4762 azure_client_info : Optional [Union [str , dict ]] = None ,
63+ credential_type : CredentialType = "DefaultAzureCredential" ,
4864 compress : bool = False ,
4965 sub_dir : Optional [str ] = None ,
5066 workers : int = 1 ,
@@ -69,6 +85,8 @@ def __init__(
6985 BlobServiceClient.
7086 Currently supported keywords:
7187 - connection_string: a connection string for the Azure blob
88+ credential_type: the type of credential to use to authenticate with Azure.
89+ Default is "DefaultAzureCredential".
7290 compress: compress files inserted into the store
7391 sub_dir: (optional) subdirectory of the container to store the data.
7492 When defined, a final "/" will be added if not already present.
@@ -104,6 +122,7 @@ def __init__(
104122 key_sanitize_dict = AZURE_KEY_SANITIZE
105123 self .key_sanitize_dict = key_sanitize_dict
106124 self .create_container = create_container
125+ self .credential_type = credential_type
107126
108127 # Force the key to be the same as the index
109128 assert isinstance (
@@ -351,8 +370,8 @@ def _get_service_client(self):
351370 if not hasattr (self ._thread_local , "container" ):
352371 if isinstance (self .azure_client_info , str ):
353372 # assume it is the account_url and that the connection is passwordless
354- default_credential = DefaultAzureCredential ( )
355- return BlobServiceClient (self .azure_client_info , credential = default_credential )
373+ credentials_ = _get_azure_credential ( self . credential_type )
374+ return BlobServiceClient (self .azure_client_info , credential = credentials_ )
356375
357376 if isinstance (self .azure_client_info , dict ):
358377 connection_string = self .azure_client_info .get ("connection_string" )
0 commit comments