Skip to content
8 changes: 7 additions & 1 deletion mlos_bench/mlos_bench/services/remote/azure/azure_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Union

import azure.core.credentials as azure_cred
import azure.identity as azure_id
from azure.keyvault.secrets import SecretClient
from pytz import UTC
Expand All @@ -20,7 +21,7 @@
_LOG = logging.getLogger(__name__)


class AzureAuthService(Service, SupportsAuth):
class AzureAuthService(Service, SupportsAuth[azure_cred.TokenCredential]):
"""Helper methods to get access to Azure services."""

_REQ_INTERVAL = 300 # = 5 min
Expand Down Expand Up @@ -56,6 +57,7 @@ def __init__(
[
self.get_access_token,
self.get_auth_headers,
self.get_credential,
],
),
)
Expand Down Expand Up @@ -133,3 +135,7 @@ def get_access_token(self) -> str:
def get_auth_headers(self) -> dict:
"""Get the authorization part of HTTP headers for REST API calls."""
return {"Authorization": "Bearer " + self.get_access_token()}

def get_credential(self) -> azure_cred.TokenCredential:
"""Return the Azure SDK credential object."""
return self._cred
14 changes: 10 additions & 4 deletions mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import os
from typing import Any, Callable, Dict, List, Optional, Set, Union

import azure.core.credentials as azure_cred
from azure.core.exceptions import ResourceNotFoundError
from azure.storage.fileshare import ShareClient

Expand Down Expand Up @@ -60,20 +61,25 @@ def __init__(
"storageFileShareName",
},
)
assert self._parent is not None and isinstance(
self._parent, SupportsAuth
), "Authorization service not provided. Include service-auth.jsonc?"
self._auth_service: SupportsAuth[azure_cred.TokenCredential] = self._parent
self._share_client: Optional[ShareClient] = None

def _get_share_client(self) -> ShareClient:
"""Get the Azure file share client object."""
if self._share_client is None:
assert self._parent is not None and isinstance(
self._parent, SupportsAuth
), "Authorization service not provided. Include service-auth.jsonc?"
credential = self._auth_service.get_credential()
assert isinstance(
credential, azure_cred.TokenCredential
), f"Expected a TokenCredential, but got {type(credential)} instead."
Copy link
Contributor

@bpkroth bpkroth Aug 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is still a late runtime error that I was trying to turn into a config load error in #819

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright, I made a new change to #819 to handle that now.
From what I can tell the rest of this one is good, but it'd be good to get confirmation that DefaultCredential doesn't need refreshed, especially with the SP part.

self._share_client = ShareClient.from_share_url(
self._SHARE_URL.format(
account_name=self.config["storageAccountName"],
fs_name=self.config["storageFileShareName"],
),
credential=self._parent.get_access_token(),
credential=credential,
token_intent="backup",
)
return self._share_client
Expand Down
16 changes: 14 additions & 2 deletions mlos_bench/mlos_bench/services/types/authenticator_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
#
"""Protocol interface for authentication for the cloud services."""

from typing import Protocol, runtime_checkable
from typing import Protocol, TypeVar, runtime_checkable

T_co = TypeVar("T_co", covariant=True)


@runtime_checkable
class SupportsAuth(Protocol):
class SupportsAuth(Protocol[T_co]):
"""Protocol interface for authentication for the cloud services."""

def get_access_token(self) -> str:
Expand All @@ -30,3 +32,13 @@ def get_auth_headers(self) -> dict:
access_header : dict
HTTP header containing the access token.
"""

def get_credential(self) -> T_co:
"""
Get the credential object for cloud services.

Returns
-------
credential : T
Cloud-specific credential object.
"""
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,28 @@ def test_load_service_config_examples(
config_path: str,
) -> None:
"""Tests loading a config example."""
parent: Service = config_loader_service
config = config_loader_service.load_config(config_path, ConfigSchema.SERVICE)
# Add other services that require a SupportsAuth parent service as necessary.
requires_auth_service_parent = {
"AzureFileShareService",
}
config_class_name = str(config.get("class", "MISSING CLASS")).rsplit(".", maxsplit=1)[-1]
if config_class_name in requires_auth_service_parent:
# AzureFileShareService requires an auth service to be loaded as well.
auth_service_config = config_loader_service.load_config(
"services/remote/mock/mock_auth_service.jsonc",
ConfigSchema.SERVICE,
)
auth_service = config_loader_service.build_service(
config=auth_service_config,
parent=config_loader_service,
)
parent = auth_service
# Make an instance of the class based on the config.
service_inst = config_loader_service.build_service(
config=config,
parent=config_loader_service,
parent=parent,
)
assert service_inst is not None
assert isinstance(service_inst, Service)
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
_LOG = logging.getLogger(__name__)


class MockAuthService(Service, SupportsAuth):
class MockAuthService(Service, SupportsAuth[str]):
"""A collection Service functions for mocking authentication ops."""

def __init__(
Expand All @@ -32,6 +32,7 @@ def __init__(
[
self.get_access_token,
self.get_auth_headers,
self.get_credential,
],
),
)
Expand All @@ -41,3 +42,6 @@ def get_access_token(self) -> str:

def get_auth_headers(self) -> dict:
return {"Authorization": "Bearer " + self.get_access_token()}

def get_credential(self) -> str:
return "MOCK CREDENTIAL"