Skip to content
32 changes: 16 additions & 16 deletions mlos_bench/mlos_bench/services/remote/azure/azure_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Union

import azure.identity as azure_id
from azure.core.credentials import TokenCredential
from azure.identity import CertificateCredential, DefaultAzureCredential
from azure.keyvault.secrets import SecretClient
from pytz import UTC

Expand All @@ -20,7 +21,7 @@
_LOG = logging.getLogger(__name__)


class AzureAuthService(Service, SupportsAuth):
class AzureAuthService(Service, SupportsAuth[TokenCredential]):
"""Helper methods to get access to Azure services."""

_REQ_INTERVAL = 300 # = 5 min
Expand Down Expand Up @@ -56,6 +57,7 @@ def __init__(
[
self.get_access_token,
self.get_auth_headers,
self.get_credential,
],
),
)
Expand All @@ -65,10 +67,7 @@ def __init__(

self._access_token = "RENEW *NOW*"
self._token_expiration_ts = datetime.now(UTC) # Typically, some future timestamp.

# Login as the first identity available, usually ourselves or a managed identity
self._cred: Union[azure_id.DefaultAzureCredential, azure_id.CertificateCredential]
self._cred = azure_id.DefaultAzureCredential()
self._cred: Optional[TokenCredential] = None

# Verify info required for SP auth early
if "spClientId" in self.config:
Expand All @@ -82,18 +81,22 @@ def __init__(
},
)

def _init_sp(self) -> None:
def get_credential(self) -> TokenCredential:
"""Return the Azure SDK credential object."""
# Perform this initialization outside of __init__ so that environment loading tests
# don't need to specifically mock keyvault interactions out
if self._cred is not None:
return self._cred

# Already logged in as SP
if isinstance(self._cred, azure_id.CertificateCredential):
return
self._cred = DefaultAzureCredential()
if "spClientId" not in self.config:
return self._cred

sp_client_id = self.config["spClientId"]
keyvault_name = self.config["keyVaultName"]
cert_name = self.config["certName"]
tenant_id = self.config["tenant"]
_LOG.debug("Log in with Azure Service Principal %s", sp_client_id)

# Get a client for fetching cert info
keyvault_secrets_client = SecretClient(
Expand All @@ -108,23 +111,20 @@ def _init_sp(self) -> None:
cert_bytes = b64decode(secret.value)

# Reauthenticate as the service principal.
self._cred = azure_id.CertificateCredential(
self._cred = CertificateCredential(
tenant_id=tenant_id,
client_id=sp_client_id,
certificate_data=cert_bytes,
)
return self._cred

def get_access_token(self) -> str:
"""Get the access token from Azure CLI, if expired."""
# Ensure we are logged as the Service Principal, if provided
if "spClientId" in self.config:
self._init_sp()

ts_diff = (self._token_expiration_ts - datetime.now(UTC)).total_seconds()
_LOG.debug("Time to renew the token: %.2f sec.", ts_diff)
if ts_diff < self._req_interval:
_LOG.debug("Request new accessToken")
res = self._cred.get_token("https://management.azure.com/.default")
res = self.get_credential().get_token("https://management.azure.com/.default")
self._token_expiration_ts = datetime.fromtimestamp(res.expires_on, tz=UTC)
self._access_token = res.token
_LOG.info("Got new accessToken. Expiration time: %s", self._token_expiration_ts)
Expand Down
14 changes: 10 additions & 4 deletions mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import os
from typing import Any, Callable, Dict, List, Optional, Set, Union

from azure.core.credentials import TokenCredential
from azure.core.exceptions import ResourceNotFoundError
from azure.storage.fileshare import ShareClient

Expand Down Expand Up @@ -60,20 +61,25 @@ def __init__(
"storageFileShareName",
},
)
assert self._parent is not None and isinstance(
self._parent, SupportsAuth
), "Authorization service not provided. Include service-auth.jsonc?"
self._auth_service: SupportsAuth[TokenCredential] = self._parent
self._share_client: Optional[ShareClient] = None

def _get_share_client(self) -> ShareClient:
"""Get the Azure file share client object."""
if self._share_client is None:
assert self._parent is not None and isinstance(
self._parent, SupportsAuth
), "Authorization service not provided. Include service-auth.jsonc?"
credential = self._auth_service.get_credential()
assert isinstance(
credential, TokenCredential
), f"Expected a TokenCredential, but got {type(credential)} instead."
Copy link
Contributor

@bpkroth bpkroth Aug 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is still a late runtime error that I was trying to turn into a config load error in #819

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright, I made a new change to #819 to handle that now.
From what I can tell, the rest of this one is good, but it'd be good to get confirmation that DefaultAzureCredential doesn't need to be refreshed, especially with the service principal (SP) part.

self._share_client = ShareClient.from_share_url(
self._SHARE_URL.format(
account_name=self.config["storageAccountName"],
fs_name=self.config["storageFileShareName"],
),
credential=self._parent.get_access_token(),
credential=credential,
token_intent="backup",
)
return self._share_client
Expand Down
16 changes: 14 additions & 2 deletions mlos_bench/mlos_bench/services/types/authenticator_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
#
"""Protocol interface for authentication for the cloud services."""

from typing import Protocol, runtime_checkable
from typing import Protocol, TypeVar, runtime_checkable

T_co = TypeVar("T_co", covariant=True)


@runtime_checkable
class SupportsAuth(Protocol):
class SupportsAuth(Protocol[T_co]):
"""Protocol interface for authentication for the cloud services."""

def get_access_token(self) -> str:
Expand All @@ -30,3 +32,13 @@ def get_auth_headers(self) -> dict:
access_header : dict
HTTP header containing the access token.
"""

def get_credential(self) -> T_co:
    """
    Get the credential object for cloud services.

    The concrete credential type is selected by the implementing class
    through the protocol's ``T_co`` type parameter.

    Returns
    -------
    credential : T_co
        Cloud-specific credential object.
    """
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,28 @@ def test_load_service_config_examples(
config_path: str,
) -> None:
"""Tests loading a config example."""
parent: Service = config_loader_service
config = config_loader_service.load_config(config_path, ConfigSchema.SERVICE)
# Add other services that require a SupportsAuth parent service as necessary.
requires_auth_service_parent = {
"AzureFileShareService",
}
config_class_name = str(config.get("class", "MISSING CLASS")).rsplit(".", maxsplit=1)[-1]
if config_class_name in requires_auth_service_parent:
# AzureFileShareService requires an auth service to be loaded as well.
auth_service_config = config_loader_service.load_config(
"services/remote/mock/mock_auth_service.jsonc",
ConfigSchema.SERVICE,
)
auth_service = config_loader_service.build_service(
config=auth_service_config,
parent=config_loader_service,
)
parent = auth_service
# Make an instance of the class based on the config.
service_inst = config_loader_service.build_service(
config=config,
parent=config_loader_service,
parent=parent,
)
assert service_inst is not None
assert isinstance(service_inst, Service)
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
_LOG = logging.getLogger(__name__)


class MockAuthService(Service, SupportsAuth):
class MockAuthService(Service, SupportsAuth[str]):
"""A collection Service functions for mocking authentication ops."""

def __init__(
Expand All @@ -32,6 +32,7 @@ def __init__(
[
self.get_access_token,
self.get_auth_headers,
self.get_credential,
],
),
)
Expand All @@ -41,3 +42,6 @@ def get_access_token(self) -> str:

def get_auth_headers(self) -> dict:
    """
    Build the HTTP authorization header from the current access token.

    Returns
    -------
    access_header : dict
        Single-entry dict mapping "Authorization" to the bearer token string.
    """
    bearer_value = "Bearer " + self.get_access_token()
    return {"Authorization": bearer_value}

def get_credential(self) -> str:
    """
    Return a fixed placeholder string in place of a real credential object.

    Returns
    -------
    credential : str
        The constant mock credential value.
    """
    mock_credential = "MOCK CREDENTIAL"
    return mock_credential