diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py b/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py index e674a943256..dccb5740ce2 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py @@ -9,7 +9,8 @@ from datetime import datetime from typing import Any, Callable, Dict, List, Optional, Union -import azure.identity as azure_id +from azure.core.credentials import TokenCredential +from azure.identity import CertificateCredential, DefaultAzureCredential from azure.keyvault.secrets import SecretClient from pytz import UTC @@ -20,7 +21,7 @@ _LOG = logging.getLogger(__name__) -class AzureAuthService(Service, SupportsAuth): +class AzureAuthService(Service, SupportsAuth[TokenCredential]): """Helper methods to get access to Azure services.""" _REQ_INTERVAL = 300 # = 5 min @@ -56,6 +57,7 @@ def __init__( [ self.get_access_token, self.get_auth_headers, + self.get_credential, ], ), ) @@ -65,10 +67,7 @@ def __init__( self._access_token = "RENEW *NOW*" self._token_expiration_ts = datetime.now(UTC) # Typically, some future timestamp. - - # Login as the first identity available, usually ourselves or a managed identity - self._cred: Union[azure_id.DefaultAzureCredential, azure_id.CertificateCredential] - self._cred = azure_id.DefaultAzureCredential() + self._cred: Optional[TokenCredential] = None # Verify info required for SP auth early if "spClientId" in self.config: @@ -82,18 +81,22 @@ def __init__( }, ) - def _init_sp(self) -> None: + def get_credential(self) -> TokenCredential: + """Return the Azure SDK credential object.""" # Perform this initialization outside of __init__ so that environment loading tests # don't need to specifically mock keyvault interactions out + if self._cred is not None: + return self._cred - # Already logged in as SP - if isinstance(self._cred, azure_id.CertificateCredential): - return + self._cred = DefaultAzureCredential() + if "spClientId" not in self.config: + return self._cred sp_client_id = self.config["spClientId"] keyvault_name = self.config["keyVaultName"] cert_name = self.config["certName"] tenant_id = self.config["tenant"] + _LOG.debug("Log in with Azure Service Principal %s", sp_client_id) # Get a client for fetching cert info keyvault_secrets_client = SecretClient( @@ -108,23 +111,20 @@ def _init_sp(self) -> None: cert_bytes = b64decode(secret.value) # Reauthenticate as the service principal. - self._cred = azure_id.CertificateCredential( + self._cred = CertificateCredential( tenant_id=tenant_id, client_id=sp_client_id, certificate_data=cert_bytes, ) + return self._cred def get_access_token(self) -> str: """Get the access token from Azure CLI, if expired.""" - # Ensure we are logged as the Service Principal, if provided - if "spClientId" in self.config: - self._init_sp() - ts_diff = (self._token_expiration_ts - datetime.now(UTC)).total_seconds() _LOG.debug("Time to renew the token: %.2f sec.", ts_diff) if ts_diff < self._req_interval: _LOG.debug("Request new accessToken") - res = self._cred.get_token("https://management.azure.com/.default") + res = self.get_credential().get_token("https://management.azure.com/.default") self._token_expiration_ts = datetime.fromtimestamp(res.expires_on, tz=UTC) self._access_token = res.token _LOG.info("Got new accessToken. Expiration time: %s", self._token_expiration_ts) diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py b/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py index 5ff1b638a32..6fa447da225 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py @@ -8,6 +8,7 @@ import os from typing import Any, Callable, Dict, List, Optional, Set, Union +from azure.core.credentials import TokenCredential from azure.core.exceptions import ResourceNotFoundError from azure.storage.fileshare import ShareClient @@ -60,20 +61,25 @@ def __init__( "storageFileShareName", }, ) + assert self._parent is not None and isinstance( + self._parent, SupportsAuth + ), "Authorization service not provided. Include service-auth.jsonc?" + self._auth_service: SupportsAuth[TokenCredential] = self._parent self._share_client: Optional[ShareClient] = None def _get_share_client(self) -> ShareClient: """Get the Azure file share client object.""" if self._share_client is None: - assert self._parent is not None and isinstance( - self._parent, SupportsAuth - ), "Authorization service not provided. Include service-auth.jsonc?" + credential = self._auth_service.get_credential() + assert isinstance( + credential, TokenCredential + ), f"Expected a TokenCredential, but got {type(credential)} instead." self._share_client = ShareClient.from_share_url( self._SHARE_URL.format( account_name=self.config["storageAccountName"], fs_name=self.config["storageFileShareName"], ), - credential=self._parent.get_access_token(), + credential=credential, token_intent="backup", ) return self._share_client diff --git a/mlos_bench/mlos_bench/services/types/authenticator_type.py b/mlos_bench/mlos_bench/services/types/authenticator_type.py index 6f99dd6bce3..b01c30d42de 100644 --- a/mlos_bench/mlos_bench/services/types/authenticator_type.py +++ b/mlos_bench/mlos_bench/services/types/authenticator_type.py @@ -4,11 +4,13 @@ # """Protocol interface for authentication for the cloud services.""" -from typing import Protocol, runtime_checkable +from typing import Protocol, TypeVar, runtime_checkable + +T_co = TypeVar("T_co", covariant=True) @runtime_checkable -class SupportsAuth(Protocol): +class SupportsAuth(Protocol[T_co]): """Protocol interface for authentication for the cloud services.""" def get_access_token(self) -> str: @@ -30,3 +32,13 @@ def get_auth_headers(self) -> dict: access_header : dict HTTP header containing the access token. """ + + def get_credential(self) -> T_co: + """ + Get the credential object for cloud services. + + Returns + ------- + credential : T + Cloud-specific credential object. + """ diff --git a/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py b/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py index 55453270808..e010fd140b9 100644 --- a/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py @@ -48,11 +48,28 @@ def test_load_service_config_examples( config_path: str, ) -> None: """Tests loading a config example.""" + parent: Service = config_loader_service config = config_loader_service.load_config(config_path, ConfigSchema.SERVICE) + # Add other services that require a SupportsAuth parent service as necessary. + requires_auth_service_parent = { + "AzureFileShareService", + } + config_class_name = str(config.get("class", "MISSING CLASS")).rsplit(".", maxsplit=1)[-1] + if config_class_name in requires_auth_service_parent: + # AzureFileShareService requires an auth service to be loaded as well. + auth_service_config = config_loader_service.load_config( + "services/remote/mock/mock_auth_service.jsonc", + ConfigSchema.SERVICE, + ) + auth_service = config_loader_service.build_service( + config=auth_service_config, + parent=config_loader_service, + ) + parent = auth_service # Make an instance of the class based on the config. service_inst = config_loader_service.build_service( config=config, - parent=config_loader_service, + parent=parent, ) assert service_inst is not None assert isinstance(service_inst, Service) diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_auth_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_auth_service.py index 482f9ee2a9b..b1228217a52 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_auth_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_auth_service.py @@ -13,7 +13,7 @@ _LOG = logging.getLogger(__name__) -class MockAuthService(Service, SupportsAuth): +class MockAuthService(Service, SupportsAuth[str]): """A collection Service functions for mocking authentication ops.""" def __init__( @@ -32,6 +32,7 @@ def __init__( [ self.get_access_token, self.get_auth_headers, + self.get_credential, ], ), ) @@ -41,3 +42,6 @@ def get_access_token(self) -> str: def get_auth_headers(self) -> dict: return {"Authorization": "Bearer " + self.get_access_token()} + + def get_credential(self) -> str: + return "MOCK CREDENTIAL"