Skip to content

Commit 11a245d

Browse files
joshharrinJosh Harrington
andauthored
[ml] simplify non-public cloud setup (#36889)
* reduce env var setup for non-public cloud * return direct * fix azureml domain * run black * unintended add * fix registry url test, non-public cloud has "api.ml.azure.{suffix}" * linting --------- Co-authored-by: Josh Harrington <[email protected]>
1 parent 0f1a71b commit 11a245d

File tree

2 files changed

+37
-26
lines changed

2 files changed

+37
-26
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/_azure_environments.py

Lines changed: 36 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -80,44 +80,56 @@ def _get_cloud(cloud: str) -> Dict[str, str]:
8080

8181
def _get_default_cloud_name() -> str:
8282
"""
83-
:return: Configured cloud, defaults to 'AzureCloud'
83+
:return: Configured cloud, defaults to 'AzureCloud' if
84+
AZUREML_CURRENT_CLOUD and ARM_CLOUD_METADATA_URL are not set to dynamically retrieve cloud info.
85+
AZUREML_CURRENT_CLOUD is also set by the SDK based on ARM_CLOUD_METADATA_URL.
8486
:rtype: str
8587
"""
86-
return os.getenv(AZUREML_CLOUD_ENV_NAME, AzureEnvironments.ENV_DEFAULT)
87-
88-
89-
def _get_cloud_details(cloud: Optional[str] = AzureEnvironments.ENV_DEFAULT) -> Dict[str, str]:
88+
current_cloud_env = os.getenv(AZUREML_CLOUD_ENV_NAME)
89+
if current_cloud_env is not None:
90+
return current_cloud_env
91+
arm_metadata_url = os.getenv(ArmConstants.METADATA_URL_ENV_NAME)
92+
if arm_metadata_url is not None:
93+
clouds = _get_clouds_by_metadata_url(arm_metadata_url) # prefer ARM metadata url when set
94+
for cloud_name in clouds: # pylint: disable=consider-using-dict-items
95+
if clouds[cloud_name][EndpointURLS.RESOURCE_MANAGER_ENDPOINT] in arm_metadata_url:
96+
_set_cloud(cloud_name)
97+
return cloud_name
98+
return AzureEnvironments.ENV_DEFAULT
99+
100+
101+
def _get_cloud_details(cloud_name: Optional[str] = None) -> Dict[str, str]:
90102
"""Returns a Cloud endpoints object for the specified Azure Cloud.
91103
92-
:param cloud: cloud name
93-
:type cloud: str
104+
:param cloud_name: cloud name
105+
:type cloud_name: str
94106
:return: azure environment endpoint.
95107
:rtype: Dict[str, str]
96108
"""
97-
if cloud is None:
109+
if cloud_name is None:
110+
cloud_name = _get_default_cloud_name()
98111
module_logger.debug(
99112
"Using the default cloud configuration: '%s'.",
100-
AzureEnvironments.ENV_DEFAULT,
113+
cloud_name,
101114
)
102-
cloud = _get_default_cloud_name()
103-
return _get_cloud(cloud)
115+
return _get_cloud(cloud_name)
104116

105117

106-
def _set_cloud(cloud: str = AzureEnvironments.ENV_DEFAULT):
118+
def _set_cloud(cloud_name: Optional[str] = None):
107119
"""Sets the current cloud.
108120
109-
:param cloud: cloud name
110-
:type cloud: str
121+
:param cloud_name: cloud name
122+
:type cloud_name: str
111123
"""
112-
if cloud is not None:
124+
if cloud_name is not None:
113125
try:
114-
_get_cloud(cloud)
126+
_get_cloud(cloud_name)
115127
except Exception as e:
116-
msg = 'Unknown cloud environment supplied: "{0}".'.format(cloud)
128+
msg = 'Unknown cloud environment supplied: "{0}".'.format(cloud_name)
117129
raise MlException(message=msg, no_personal_data_message=msg) from e
118130
else:
119-
cloud = _get_default_cloud_name()
120-
os.environ[AZUREML_CLOUD_ENV_NAME] = cloud
131+
cloud_name = _get_default_cloud_name()
132+
os.environ[AZUREML_CLOUD_ENV_NAME] = cloud_name
121133

122134

123135
def _get_base_url_from_metadata(cloud_name: Optional[str] = None, is_local_mfe: bool = False) -> str:
@@ -253,16 +265,15 @@ def _get_registry_discovery_url(cloud: dict, cloud_suffix: str = "") -> str:
253265
"""
254266
cloud_name = cloud["name"]
255267
if cloud_name in _environments:
256-
return _environments[cloud_name].registry_url # type: ignore[attr-defined]
257-
268+
return _environments[cloud_name][EndpointURLS.REGISTRY_DISCOVERY_ENDPOINT]
269+
registry_discovery_from_env = os.getenv(ArmConstants.REGISTRY_ENV_URL)
270+
if registry_discovery_from_env is not None:
271+
return registry_discovery_from_env
258272
registry_discovery_region = os.environ.get(
259273
ArmConstants.REGISTRY_DISCOVERY_REGION_ENV_NAME,
260274
ArmConstants.REGISTRY_DISCOVERY_DEFAULT_REGION,
261275
)
262-
registry_discovery_region_default = "https://{}{}.api.azureml.{}/".format(
263-
cloud_name.lower(), registry_discovery_region, cloud_suffix
264-
)
265-
return os.environ.get(ArmConstants.REGISTRY_ENV_URL, registry_discovery_region_default)
276+
return f"https://{cloud_name.lower()}{registry_discovery_region}.api.ml.azure.{cloud_suffix}/"
266277

267278

268279
def _get_clouds_by_metadata_url(metadata_url: str) -> Dict[str, Dict[str, str]]:

sdk/ml/azure-ai-ml/tests/internal_utils/unittests/test_cloud_environments.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def test_metadata_registry_endpoint(self, mock_get):
150150
cloud_details = _get_cloud_details("TEST_ENV2")
151151
assert (
152152
cloud_details.get(EndpointURLS.REGISTRY_DISCOVERY_ENDPOINT)
153-
== "https://test_env2west.api.azureml.windows.net/"
153+
== "https://test_env2west.api.ml.azure.windows.net/"
154154
)
155155

156156
@mock.patch.dict(os.environ, {}, clear=True)

0 commit comments

Comments
 (0)