|
| 1 | +# --------------------------------------------------------- |
| 2 | +# Copyright (c) Microsoft Corporation. All rights reserved. |
| 3 | +# --------------------------------------------------------- |
| 4 | + |
| 5 | +# NOTE: |
| 6 | +# This is a simplified version of the original code from azure-ai-ml: |
| 7 | +# sdk\ml\azure-ai-ml\azure\ai\ml\_azure_environments.py |
| 8 | + |
| 9 | +import asyncio |
| 10 | +import os |
| 11 | + |
| 12 | +from typing import Any, Dict, Final, Mapping, Optional, Sequence, TypedDict |
| 13 | + |
| 14 | +from azure.core import AsyncPipelineClient |
| 15 | +from azure.core.configuration import Configuration |
| 16 | +from azure.core.rest import HttpRequest |
| 17 | +from azure.core.pipeline.policies import ProxyPolicy, AsyncRetryPolicy |
| 18 | + |
| 19 | + |
| 20 | +class AzureEnvironmentMetadata(TypedDict): |
| 21 | + """Configuration for various Azure environments. All endpoints include a trailing slash.""" |
| 22 | + portal_endpoint: str |
| 23 | + """The management portal for the Azure environment (e.g. https://portal.azure.com/)""" |
| 24 | + resource_manager_endpoint: str |
| 25 | + """The API endpoint for Azure control plan (e.g. https://management.azure.com/)""" |
| 26 | + active_directory_endpoint: str |
| 27 | + """The active directory endpoint used for authentication (e.g. https://login.microsoftonline.com/)""" |
| 28 | + aml_resource_endpoint: str |
| 29 | + """The endpoint for Azure Machine Learning resources (e.g. https://ml.azure.com/)""" |
| 30 | + storage_suffix: str |
| 31 | + """The suffix to use for storage endpoint URLs (e.g. core.windows.net)""" |
| 32 | + registry_discovery_endpoint: str |
| 33 | + |
| 34 | + |
| 35 | +_ENV_ARM_CLOUD_METADATA_URL: Final[str] = "ARM_CLOUD_METADATA_URL" |
| 36 | +_ENV_DEFAULT_CLOUD_NAME: Final[str] = "AZUREML_CURRENT_CLOUD" |
| 37 | +_ENV_REGISTRY_DISCOVERY_URL: Final[str] = "REGISTRY_DISCOVERY_ENDPOINT_URL" |
| 38 | +_ENV_REGISTRY_DISCOVERY_REGION: Final[str] = "REGISTRY_DISCOVERY_ENDPOINT_REGION" |
| 39 | +_DEFAULT_REGISTRY_DISCOVERY_REGION: Final[str] = "west" |
| 40 | +_DEFAULT_AZURE_ENV_NAME: Final[str] = "AzureCloud" |
| 41 | + |
| 42 | + |
| 43 | +_ASYNC_LOCK = asyncio.Lock() |
| 44 | +_KNOWN_AZURE_ENVIRONMENTS: Dict[str, AzureEnvironmentMetadata] = { |
| 45 | + _DEFAULT_AZURE_ENV_NAME: { |
| 46 | + "portal_endpoint": "https://portal.azure.com/", |
| 47 | + "resource_manager_endpoint": "https://management.azure.com/", |
| 48 | + "active_directory_endpoint": "https://login.microsoftonline.com/", |
| 49 | + "aml_resource_endpoint": "https://ml.azure.com/", |
| 50 | + "storage_suffix": "core.windows.net", |
| 51 | + "registry_discovery_endpoint": "https://eastus.api.azureml.ms/", |
| 52 | + }, |
| 53 | + "AzureChinaCloud": { |
| 54 | + "portal_endpoint": "https://portal.azure.cn/", |
| 55 | + "resource_manager_endpoint": "https://management.chinacloudapi.cn/", |
| 56 | + "active_directory_endpoint": "https://login.chinacloudapi.cn/", |
| 57 | + "aml_resource_endpoint": "https://ml.azure.cn/", |
| 58 | + "storage_suffix": "core.chinacloudapi.cn", |
| 59 | + "registry_discovery_endpoint": "https://chinaeast2.api.ml.azure.cn/", |
| 60 | + }, |
| 61 | + "AzureUSGovernment": { |
| 62 | + "portal_endpoint": "https://portal.azure.us/", |
| 63 | + "resource_manager_endpoint": "https://management.usgovcloudapi.net/", |
| 64 | + "active_directory_endpoint": "https://login.microsoftonline.us/", |
| 65 | + "aml_resource_endpoint": "https://ml.azure.us/", |
| 66 | + "storage_suffix": "core.usgovcloudapi.net", |
| 67 | + "registry_discovery_endpoint": "https://usgovarizona.api.ml.azure.us/", |
| 68 | + }, |
| 69 | +} |
| 70 | + |
| 71 | + |
| 72 | +class AzureEnvironmentClient: |
| 73 | + DEFAULT_API_VERSION: Final[str] = "2019-05-01" |
| 74 | + DEFAULT_AZURE_CLOUD_NAME: Final[str] = _DEFAULT_AZURE_ENV_NAME |
| 75 | + |
| 76 | + def __init__(self, *, base_url: Optional[str] = None, **kwargs: Any) -> None: |
| 77 | + base_url = base_url if base_url is not None else AzureEnvironmentClient.get_default_metadata_url() |
| 78 | + |
| 79 | + config: Configuration = kwargs.pop("config", Configuration(**kwargs)) |
| 80 | + if config.retry_policy is None: |
| 81 | + config.retry_policy = AsyncRetryPolicy(**kwargs) |
| 82 | + if config.proxy_policy is None and "proxy" in kwargs: |
| 83 | + config.proxy_policy = ProxyPolicy(proxies={"http": kwargs["proxy"], "https": kwargs["proxy"]}) |
| 84 | + |
| 85 | + self._async_client = AsyncPipelineClient(base_url, config=config, **kwargs) |
| 86 | + |
| 87 | + async def get_default_cloud_name_async(self, *, update_cached: bool = True) -> str: |
| 88 | + current_cloud_env = os.getenv(_ENV_DEFAULT_CLOUD_NAME) |
| 89 | + if current_cloud_env is not None: |
| 90 | + return current_cloud_env |
| 91 | + |
| 92 | + arm_metadata_url = os.getenv(_ENV_ARM_CLOUD_METADATA_URL) |
| 93 | + if arm_metadata_url is None: |
| 94 | + return _DEFAULT_AZURE_ENV_NAME |
| 95 | + |
| 96 | + # load clouds from metadata url |
| 97 | + clouds = await self.get_clouds_async(metadata_url=arm_metadata_url, update_cached=update_cached) |
| 98 | + matched = next(filter(lambda t: t[1]["resource_manager_endpoint"] in arm_metadata_url, clouds.items()), None) |
| 99 | + if matched is None: |
| 100 | + return _DEFAULT_AZURE_ENV_NAME |
| 101 | + |
| 102 | + os.environ[_ENV_DEFAULT_CLOUD_NAME] = matched[0] |
| 103 | + return matched[0] |
| 104 | + |
| 105 | + async def get_cloud_async(self, name: str, *, update_cached: bool = True) -> Optional[AzureEnvironmentMetadata]: |
| 106 | + default_endpoint: Optional[str] |
| 107 | + |
| 108 | + def case_insensitive_match(d: Mapping[str, Any], key: str) -> Optional[Any]: |
| 109 | + key = key.strip().lower() |
| 110 | + return next((v for k,v in d.items() if k.strip().lower() == key), None) |
| 111 | + |
| 112 | + async with _ASYNC_LOCK: |
| 113 | + cloud = _KNOWN_AZURE_ENVIRONMENTS.get(name) or case_insensitive_match(_KNOWN_AZURE_ENVIRONMENTS, name) |
| 114 | + if cloud: |
| 115 | + return cloud |
| 116 | + default_endpoint = (_KNOWN_AZURE_ENVIRONMENTS |
| 117 | + .get(_DEFAULT_AZURE_ENV_NAME, {}) |
| 118 | + .get("resource_manager_endpoint")) |
| 119 | + |
| 120 | + metadata_url = self.get_default_metadata_url(default_endpoint) |
| 121 | + clouds = await self.get_clouds_async(metadata_url=metadata_url, update_cached=update_cached) |
| 122 | + cloud_metadata = clouds.get(name) or case_insensitive_match(clouds, name) |
| 123 | + |
| 124 | + return cloud_metadata |
| 125 | + |
| 126 | + async def get_clouds_async( |
| 127 | + self, |
| 128 | + *, |
| 129 | + metadata_url: Optional[str] = None, |
| 130 | + update_cached: bool = True |
| 131 | + ) -> Mapping[str, AzureEnvironmentMetadata]: |
| 132 | + metadata_url = metadata_url or self.get_default_metadata_url() |
| 133 | + |
| 134 | + clouds: Mapping[str, AzureEnvironmentMetadata] |
| 135 | + async with self._async_client.send_request(HttpRequest("GET", metadata_url)) as response: # type: ignore |
| 136 | + response.raise_for_status() |
| 137 | + clouds = await self._parse_cloud_endpoints_async(response.json()) |
| 138 | + |
| 139 | + if update_cached: |
| 140 | + async with _ASYNC_LOCK: |
| 141 | + recursive_update(_KNOWN_AZURE_ENVIRONMENTS, clouds) |
| 142 | + return clouds |
| 143 | + |
| 144 | + async def close(self) -> None: |
| 145 | + await self._async_client.close() |
| 146 | + |
| 147 | + @staticmethod |
| 148 | + def get_default_metadata_url(default_endpoint: Optional[str] = None) -> str: |
| 149 | + default_endpoint = default_endpoint or "https://management.azure.com/" |
| 150 | + metadata_url = os.getenv( |
| 151 | + _ENV_ARM_CLOUD_METADATA_URL, |
| 152 | + f"{default_endpoint}metadata/endpoints?api-version={AzureEnvironmentClient.DEFAULT_API_VERSION}") |
| 153 | + return metadata_url |
| 154 | + |
| 155 | + @staticmethod |
| 156 | + async def _get_registry_discovery_url_async(cloud_name: str, cloud_suffix: str) -> str: |
| 157 | + async with _ASYNC_LOCK: |
| 158 | + discovery_url = _KNOWN_AZURE_ENVIRONMENTS.get(cloud_name, {}).get("registry_discovery_endpoint") |
| 159 | + if discovery_url: |
| 160 | + return discovery_url |
| 161 | + |
| 162 | + discovery_url = os.getenv(_ENV_REGISTRY_DISCOVERY_URL) |
| 163 | + if discovery_url is not None: |
| 164 | + return discovery_url |
| 165 | + |
| 166 | + region = os.getenv(_ENV_REGISTRY_DISCOVERY_REGION, _DEFAULT_REGISTRY_DISCOVERY_REGION) |
| 167 | + return f"https://{cloud_name.lower()}{region}.api.ml.azure.{cloud_suffix}/" |
| 168 | + |
| 169 | + @staticmethod |
| 170 | + async def _parse_cloud_endpoints_async(data: Any) -> Mapping[str, AzureEnvironmentMetadata]: |
| 171 | + # If there is only one cloud, you will get a dict, otherwise a list of dicts |
| 172 | + cloud_data: Sequence[Mapping[str, Any]] = data if not isinstance(data, dict) else [data] |
| 173 | + clouds: Dict[str, AzureEnvironmentMetadata] = {} |
| 174 | + |
| 175 | + def append_trailing_slash(url: str) -> str: |
| 176 | + return url if url.endswith("/") else f"{url}/" |
| 177 | + |
| 178 | + for cloud in cloud_data: |
| 179 | + try: |
| 180 | + name: str = cloud["name"] |
| 181 | + portal_endpoint: str = cloud["portal"] |
| 182 | + cloud_suffix = ".".join(portal_endpoint.split(".")[2:]).replace("/", "") |
| 183 | + discovery_url = await AzureEnvironmentClient._get_registry_discovery_url_async(name, cloud_suffix) |
| 184 | + clouds[name] = { |
| 185 | + "portal_endpoint": append_trailing_slash(portal_endpoint), |
| 186 | + "resource_manager_endpoint": append_trailing_slash(cloud["resourceManager"]), |
| 187 | + "active_directory_endpoint": append_trailing_slash(cloud["authentication"]["loginEndpoint"]), |
| 188 | + "aml_resource_endpoint": append_trailing_slash(f"https://ml.azure.{cloud_suffix}/"), |
| 189 | + "storage_suffix": cloud["suffixes"]["storage"], |
| 190 | + "registry_discovery_endpoint": append_trailing_slash(discovery_url), |
| 191 | + } |
| 192 | + except KeyError: |
| 193 | + continue |
| 194 | + |
| 195 | + return clouds |
| 196 | + |
| 197 | + |
| 198 | +def recursive_update(d: Dict, u: Mapping) -> None: |
| 199 | + """Recursively update a dictionary. |
| 200 | + |
| 201 | + :param Dict d: The dictionary to update. |
| 202 | + :param Mapping u: The mapping to update from. |
| 203 | + """ |
| 204 | + for k, v in u.items(): |
| 205 | + if isinstance(v, Dict) and k in d: |
| 206 | + recursive_update(d[k], v) |
| 207 | + else: |
| 208 | + d[k] = v |
0 commit comments