Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 30 additions & 4 deletions src/maggma/stores/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Advanced Stores for connecting to Microsoft Azure data.
"""

import importlib
import os
import threading
import warnings
Expand All @@ -11,7 +12,7 @@
from concurrent.futures.thread import ThreadPoolExecutor
from hashlib import sha1
from json import dumps
from typing import Optional, Union
from typing import Literal, Optional, Union

import msgpack # type: ignore
from monty.msgpack import default as monty_default
Expand All @@ -23,15 +24,34 @@
import azure
import azure.storage.blob as azure_blob
from azure.core.exceptions import ResourceExistsError
from azure.identity import DefaultAzureCredential
from azure.storage.blob import BlobServiceClient, ContainerClient


except (ImportError, ModuleNotFoundError):
azure_blob = None # type: ignore
ContainerClient = None


AZURE_KEY_SANITIZE = {"-": "_", ".": "_"}

# Accepted values for ``credential_type``: either the name of a credential
# class that ``_get_azure_credential`` looks up in ``azure.identity``, or a
# credential class itself (instantiated with no arguments). Only the string
# form is serializable.
CredentialType = Union[
    Literal[
        "DefaultAzureCredential",
        "AzureCliCredential",
        "ManagedIdentityCredential",
    ],
    type,
]


def _get_azure_credential(credential_class):
"""Import the azure.identity module and return the credential class.

If the credential_class is a class, return an instance of it.
If the credential_class is a string, import the module first
"""
if isinstance(credential_class, str):
module_name = "azure.identity"
credential_class = getattr(importlib.import_module(module_name), credential_class)
return credential_class()


class AzureBlobStore(Store):
"""
Expand All @@ -45,6 +65,7 @@ def __init__(
index: Store,
container_name: str,
azure_client_info: Optional[Union[str, dict]] = None,
credential_type: CredentialType = "DefaultAzureCredential",
compress: bool = False,
sub_dir: Optional[str] = None,
workers: int = 1,
Expand All @@ -69,6 +90,10 @@ def __init__(
BlobServiceClient.
Currently supported keywords:
- connection_string: a connection string for the Azure blob
credential_type: the type of credential to use to authenticate with Azure.
Default is "DefaultAzureCredential". For serializable stores, provide
a string representation of the credential class. Otherwise, you may
provide the class itself.
compress: compress files inserted into the store
sub_dir: (optional) subdirectory of the container to store the data.
When defined, a final "/" will be added if not already present.
Expand Down Expand Up @@ -104,6 +129,7 @@ def __init__(
key_sanitize_dict = AZURE_KEY_SANITIZE
self.key_sanitize_dict = key_sanitize_dict
self.create_container = create_container
self.credential_type = credential_type

# Force the key to be the same as the index
assert isinstance(
Expand Down Expand Up @@ -351,8 +377,8 @@ def _get_service_client(self):
if not hasattr(self._thread_local, "container"):
if isinstance(self.azure_client_info, str):
# assume it is the account_url and that the connection is passwordless
default_credential = DefaultAzureCredential()
return BlobServiceClient(self.azure_client_info, credential=default_credential)
credentials_ = _get_azure_credential(self.credential_type)
return BlobServiceClient(self.azure_client_info, credential=credentials_)

if isinstance(self.azure_client_info, dict):
connection_string = self.azure_client_info.get("connection_string")
Expand Down
31 changes: 31 additions & 0 deletions tests/stores/test_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,3 +420,34 @@ def test_no_login():

with pytest.raises(RuntimeError, match=r".*Could not instantiate BlobServiceClient.*"):
store.connect()


def test_credential_type_valid():
    """Check credential_type handling for a class name and for a class.

    With a string name, the store must keep the value unchanged and
    ``connect()`` must run. With a class, the stored value must not be a
    string.
    """
    name_of_credential = "DefaultAzureCredential"
    store_from_name = AzureBlobStore(
        MemoryStore("index"),
        AZURITE_CONTAINER_NAME,
        azure_client_info="client_url",
        credential_type=name_of_credential,
    )
    assert store_from_name.credential_type == name_of_credential

    # Setting ``service`` makes the store believe the blob service client
    # was already provided, so the connection checks are skipped. Only the
    # credential import is being exercised here.
    store_from_name.service = True
    store_from_name.connect()

    from azure.identity import DefaultAzureCredential

    store_from_class = AzureBlobStore(
        MemoryStore("index"),
        AZURITE_CONTAINER_NAME,
        azure_client_info="client_url",
        credential_type=DefaultAzureCredential,
    )
    assert not isinstance(store_from_class.credential_type, str)
Loading