Skip to content

Commit 83e913b

Browse files
authored
Add back support for Entra ID auth with Prompty files (Azure#40838)
- Adds an async credential provider that supports multi-cloud as per the Promptflow Prompty implementation - Adds a client to query Azure cloud endpoint metadata which is a streamlined version of the code from Promptflow, and azure-ai-ml - Moves TokenScope to be shared and adds entries for Cognitive Services, and AI ML. Updates the rest of the code to use this moved enum - Adds support for the `AI_EVALS_USE_PF_PROMPTY` environment variable which when set to `true` will go back to using the legacy Promptflow Prompty implementation
1 parent 3487e95 commit 83e913b

File tree

21 files changed

+417
-118
lines changed

21 files changed

+417
-118
lines changed

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_azure/_clients.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@
88
from urllib.parse import quote
99
from json.decoder import JSONDecodeError
1010

11-
from azure.core.credentials import TokenCredential, AzureSasCredential
11+
from azure.core.credentials import TokenCredential, AzureSasCredential, AccessToken
1212
from azure.core.rest import HttpResponse
1313
from azure.ai.evaluation._exceptions import ErrorBlame, ErrorCategory, ErrorTarget, EvaluationException
1414
from azure.ai.evaluation._http_utils import HttpPipeline, get_http_client
1515
from azure.ai.evaluation._azure._token_manager import AzureMLTokenManager
16-
from azure.ai.evaluation.simulator._model_tools._identity_manager import TokenScope
16+
from azure.ai.evaluation._constants import TokenScope
1717
from ._models import BlobStoreInfo, Workspace
1818

1919

@@ -61,7 +61,7 @@ def __init__(
6161
self._token_manager: Optional[AzureMLTokenManager] = None
6262
self._credential: Optional[TokenCredential] = credential
6363

64-
def get_token(self) -> str:
64+
def get_token(self) -> AccessToken:
6565
return self._get_token_manager().get_token()
6666

6767
def get_credential(self) -> TokenCredential:
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
# ---------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# ---------------------------------------------------------
4+
5+
# NOTE:
6+
# This is a simplified version of the original code from azure-ai-ml:
7+
# sdk\ml\azure-ai-ml\azure\ai\ml\_azure_environments.py
8+
9+
import asyncio
10+
import os
11+
12+
from typing import Any, Dict, Final, Mapping, Optional, Sequence, TypedDict
13+
14+
from azure.core import AsyncPipelineClient
15+
from azure.core.configuration import Configuration
16+
from azure.core.rest import HttpRequest
17+
from azure.core.pipeline.policies import ProxyPolicy, AsyncRetryPolicy
18+
19+
20+
class AzureEnvironmentMetadata(TypedDict):
21+
"""Configuration for various Azure environments. All endpoints include a trailing slash."""
22+
portal_endpoint: str
23+
"""The management portal for the Azure environment (e.g. https://portal.azure.com/)"""
24+
resource_manager_endpoint: str
25+
"""The API endpoint for Azure control plan (e.g. https://management.azure.com/)"""
26+
active_directory_endpoint: str
27+
"""The active directory endpoint used for authentication (e.g. https://login.microsoftonline.com/)"""
28+
aml_resource_endpoint: str
29+
"""The endpoint for Azure Machine Learning resources (e.g. https://ml.azure.com/)"""
30+
storage_suffix: str
31+
"""The suffix to use for storage endpoint URLs (e.g. core.windows.net)"""
32+
registry_discovery_endpoint: str
33+
34+
35+
_ENV_ARM_CLOUD_METADATA_URL: Final[str] = "ARM_CLOUD_METADATA_URL"
36+
_ENV_DEFAULT_CLOUD_NAME: Final[str] = "AZUREML_CURRENT_CLOUD"
37+
_ENV_REGISTRY_DISCOVERY_URL: Final[str] = "REGISTRY_DISCOVERY_ENDPOINT_URL"
38+
_ENV_REGISTRY_DISCOVERY_REGION: Final[str] = "REGISTRY_DISCOVERY_ENDPOINT_REGION"
39+
_DEFAULT_REGISTRY_DISCOVERY_REGION: Final[str] = "west"
40+
_DEFAULT_AZURE_ENV_NAME: Final[str] = "AzureCloud"
41+
42+
43+
_ASYNC_LOCK = asyncio.Lock()
44+
_KNOWN_AZURE_ENVIRONMENTS: Dict[str, AzureEnvironmentMetadata] = {
45+
_DEFAULT_AZURE_ENV_NAME: {
46+
"portal_endpoint": "https://portal.azure.com/",
47+
"resource_manager_endpoint": "https://management.azure.com/",
48+
"active_directory_endpoint": "https://login.microsoftonline.com/",
49+
"aml_resource_endpoint": "https://ml.azure.com/",
50+
"storage_suffix": "core.windows.net",
51+
"registry_discovery_endpoint": "https://eastus.api.azureml.ms/",
52+
},
53+
"AzureChinaCloud": {
54+
"portal_endpoint": "https://portal.azure.cn/",
55+
"resource_manager_endpoint": "https://management.chinacloudapi.cn/",
56+
"active_directory_endpoint": "https://login.chinacloudapi.cn/",
57+
"aml_resource_endpoint": "https://ml.azure.cn/",
58+
"storage_suffix": "core.chinacloudapi.cn",
59+
"registry_discovery_endpoint": "https://chinaeast2.api.ml.azure.cn/",
60+
},
61+
"AzureUSGovernment": {
62+
"portal_endpoint": "https://portal.azure.us/",
63+
"resource_manager_endpoint": "https://management.usgovcloudapi.net/",
64+
"active_directory_endpoint": "https://login.microsoftonline.us/",
65+
"aml_resource_endpoint": "https://ml.azure.us/",
66+
"storage_suffix": "core.usgovcloudapi.net",
67+
"registry_discovery_endpoint": "https://usgovarizona.api.ml.azure.us/",
68+
},
69+
}
70+
71+
72+
class AzureEnvironmentClient:
73+
DEFAULT_API_VERSION: Final[str] = "2019-05-01"
74+
DEFAULT_AZURE_CLOUD_NAME: Final[str] = _DEFAULT_AZURE_ENV_NAME
75+
76+
def __init__(self, *, base_url: Optional[str] = None, **kwargs: Any) -> None:
77+
base_url = base_url if base_url is not None else AzureEnvironmentClient.get_default_metadata_url()
78+
79+
config: Configuration = kwargs.pop("config", Configuration(**kwargs))
80+
if config.retry_policy is None:
81+
config.retry_policy = AsyncRetryPolicy(**kwargs)
82+
if config.proxy_policy is None and "proxy" in kwargs:
83+
config.proxy_policy = ProxyPolicy(proxies={"http": kwargs["proxy"], "https": kwargs["proxy"]})
84+
85+
self._async_client = AsyncPipelineClient(base_url, config=config, **kwargs)
86+
87+
async def get_default_cloud_name_async(self, *, update_cached: bool = True) -> str:
88+
current_cloud_env = os.getenv(_ENV_DEFAULT_CLOUD_NAME)
89+
if current_cloud_env is not None:
90+
return current_cloud_env
91+
92+
arm_metadata_url = os.getenv(_ENV_ARM_CLOUD_METADATA_URL)
93+
if arm_metadata_url is None:
94+
return _DEFAULT_AZURE_ENV_NAME
95+
96+
# load clouds from metadata url
97+
clouds = await self.get_clouds_async(metadata_url=arm_metadata_url, update_cached=update_cached)
98+
matched = next(filter(lambda t: t[1]["resource_manager_endpoint"] in arm_metadata_url, clouds.items()), None)
99+
if matched is None:
100+
return _DEFAULT_AZURE_ENV_NAME
101+
102+
os.environ[_ENV_DEFAULT_CLOUD_NAME] = matched[0]
103+
return matched[0]
104+
105+
async def get_cloud_async(self, name: str, *, update_cached: bool = True) -> Optional[AzureEnvironmentMetadata]:
106+
default_endpoint: Optional[str]
107+
108+
def case_insensitive_match(d: Mapping[str, Any], key: str) -> Optional[Any]:
109+
key = key.strip().lower()
110+
return next((v for k,v in d.items() if k.strip().lower() == key), None)
111+
112+
async with _ASYNC_LOCK:
113+
cloud = _KNOWN_AZURE_ENVIRONMENTS.get(name) or case_insensitive_match(_KNOWN_AZURE_ENVIRONMENTS, name)
114+
if cloud:
115+
return cloud
116+
default_endpoint = (_KNOWN_AZURE_ENVIRONMENTS
117+
.get(_DEFAULT_AZURE_ENV_NAME, {})
118+
.get("resource_manager_endpoint"))
119+
120+
metadata_url = self.get_default_metadata_url(default_endpoint)
121+
clouds = await self.get_clouds_async(metadata_url=metadata_url, update_cached=update_cached)
122+
cloud_metadata = clouds.get(name) or case_insensitive_match(clouds, name)
123+
124+
return cloud_metadata
125+
126+
async def get_clouds_async(
127+
self,
128+
*,
129+
metadata_url: Optional[str] = None,
130+
update_cached: bool = True
131+
) -> Mapping[str, AzureEnvironmentMetadata]:
132+
metadata_url = metadata_url or self.get_default_metadata_url()
133+
134+
clouds: Mapping[str, AzureEnvironmentMetadata]
135+
async with self._async_client.send_request(HttpRequest("GET", metadata_url)) as response: # type: ignore
136+
response.raise_for_status()
137+
clouds = await self._parse_cloud_endpoints_async(response.json())
138+
139+
if update_cached:
140+
async with _ASYNC_LOCK:
141+
recursive_update(_KNOWN_AZURE_ENVIRONMENTS, clouds)
142+
return clouds
143+
144+
async def close(self) -> None:
145+
await self._async_client.close()
146+
147+
@staticmethod
148+
def get_default_metadata_url(default_endpoint: Optional[str] = None) -> str:
149+
default_endpoint = default_endpoint or "https://management.azure.com/"
150+
metadata_url = os.getenv(
151+
_ENV_ARM_CLOUD_METADATA_URL,
152+
f"{default_endpoint}metadata/endpoints?api-version={AzureEnvironmentClient.DEFAULT_API_VERSION}")
153+
return metadata_url
154+
155+
@staticmethod
156+
async def _get_registry_discovery_url_async(cloud_name: str, cloud_suffix: str) -> str:
157+
async with _ASYNC_LOCK:
158+
discovery_url = _KNOWN_AZURE_ENVIRONMENTS.get(cloud_name, {}).get("registry_discovery_endpoint")
159+
if discovery_url:
160+
return discovery_url
161+
162+
discovery_url = os.getenv(_ENV_REGISTRY_DISCOVERY_URL)
163+
if discovery_url is not None:
164+
return discovery_url
165+
166+
region = os.getenv(_ENV_REGISTRY_DISCOVERY_REGION, _DEFAULT_REGISTRY_DISCOVERY_REGION)
167+
return f"https://{cloud_name.lower()}{region}.api.ml.azure.{cloud_suffix}/"
168+
169+
@staticmethod
170+
async def _parse_cloud_endpoints_async(data: Any) -> Mapping[str, AzureEnvironmentMetadata]:
171+
# If there is only one cloud, you will get a dict, otherwise a list of dicts
172+
cloud_data: Sequence[Mapping[str, Any]] = data if not isinstance(data, dict) else [data]
173+
clouds: Dict[str, AzureEnvironmentMetadata] = {}
174+
175+
def append_trailing_slash(url: str) -> str:
176+
return url if url.endswith("/") else f"{url}/"
177+
178+
for cloud in cloud_data:
179+
try:
180+
name: str = cloud["name"]
181+
portal_endpoint: str = cloud["portal"]
182+
cloud_suffix = ".".join(portal_endpoint.split(".")[2:]).replace("/", "")
183+
discovery_url = await AzureEnvironmentClient._get_registry_discovery_url_async(name, cloud_suffix)
184+
clouds[name] = {
185+
"portal_endpoint": append_trailing_slash(portal_endpoint),
186+
"resource_manager_endpoint": append_trailing_slash(cloud["resourceManager"]),
187+
"active_directory_endpoint": append_trailing_slash(cloud["authentication"]["loginEndpoint"]),
188+
"aml_resource_endpoint": append_trailing_slash(f"https://ml.azure.{cloud_suffix}/"),
189+
"storage_suffix": cloud["suffixes"]["storage"],
190+
"registry_discovery_endpoint": append_trailing_slash(discovery_url),
191+
}
192+
except KeyError:
193+
continue
194+
195+
return clouds
196+
197+
198+
def recursive_update(d: Dict, u: Mapping) -> None:
199+
"""Recursively update a dictionary.
200+
201+
:param Dict d: The dictionary to update.
202+
:param Mapping u: The mapping to update from.
203+
"""
204+
for k, v in u.items():
205+
if isinstance(v, Dict) and k in d:
206+
recursive_update(d[k], v)
207+
else:
208+
d[k] = v

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_azure/_token_manager.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import logging
66
import time
77
import inspect
8-
from typing import cast, Optional, Union, Any
8+
from typing import cast, Optional, Union
99

1010
from azure.core.credentials import TokenCredential, AccessToken
1111
from azure.identity import AzureCliCredential, DefaultAzureCredential, ManagedIdentityCredential
@@ -71,7 +71,7 @@ def get_aad_credential(self) -> Union[DefaultAzureCredential, ManagedIdentityCre
7171
# Fall back to using the parent implementation
7272
return super().get_aad_credential()
7373

74-
def get_token(self, *scopes: str, claims: Union[str, None] = None, tenant_id: Union[str, None] = None, enable_cae: bool = False, **kwargs: Any) -> AccessToken:
74+
def get_token(self) -> AccessToken:
7575
"""Get the API token. If the token is not available or has expired, refresh the token.
7676
7777
:return: API token
@@ -82,9 +82,9 @@ def get_token(self, *scopes: str, claims: Union[str, None] = None, tenant_id: Un
8282
access_token = credential.get_token(self.token_scope)
8383
self._update_token(access_token)
8484

85-
return self.token # check for none is hidden in the _token_needs_update method
85+
return cast(AccessToken, self.token) # check for none is hidden in the _token_needs_update method
8686

87-
async def get_token_async(self) -> str:
87+
async def get_token_async(self) -> AccessToken:
8888
"""Get the API token asynchronously. If the token is not available or has expired, refresh it.
8989
9090
:return: API token
@@ -99,7 +99,7 @@ async def get_token_async(self) -> str:
9999
access_token = get_token_method
100100
self._update_token(access_token)
101101

102-
return cast(str, self.token) # check for none is hidden in the _token_needs_update method
102+
return cast(AccessToken, self.token) # check for none is hidden in the _token_needs_update method
103103

104104
def _token_needs_update(self) -> bool:
105105
current_time = time.time()

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_constants.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,12 @@ class _AggregationType(enum.Enum):
8080
SUM = "sum"
8181
CUSTOM = "custom"
8282

83-
class TokenScope(enum.Enum):
83+
class TokenScope(str, enum.Enum):
8484
"""Defines the scope of the token used to access Azure resources."""
8585

8686
DEFAULT_AZURE_MANAGEMENT = "https://management.azure.com/.default"
87-
COGNITIVE_SERVICES = "https://cognitiveservices.azure.com/.default"
87+
COGNITIVE_SERVICES_MANAGEMENT = "https://cognitiveservices.azure.com/.default"
88+
AZURE_ML = "https://ml.azure.com/.default"
8889

8990

9091
DEFAULT_EVALUATION_RESULTS_FILE_NAME = "evaluation_results.json"
@@ -112,4 +113,4 @@ class TokenScope(enum.Enum):
112113

113114
AOAI_COLUMN_NAME = "aoai"
114115
DEFAULT_OAI_EVAL_RUN_NAME = "AI_SDK_EVAL_RUN"
115-
DEFAULT_AOAI_API_VERSION = "2025-04-01-preview" # Unfortunately relying on preview version for now.
116+
DEFAULT_AOAI_API_VERSION = "2025-04-01-preview" # Unfortunately relying on preview version for now.

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluate/_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def _log_metrics_and_instance_results_onedp(
141141
from azure.ai.evaluation._common import EvaluationServiceOneDPClient, EvaluationUpload
142142

143143
credentials = AzureMLTokenManager(
144-
TokenScope.COGNITIVE_SERVICES.value, LOGGER, credential=kwargs.get("credential")
144+
TokenScope.COGNITIVE_SERVICES_MANAGEMENT.value, LOGGER, credential=kwargs.get("credential")
145145
)
146146
client = EvaluationServiceOneDPClient(
147147
endpoint=project_url,

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_common/_base_prompty_eval.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77
import os
88
from typing import Dict, TypeVar, Union
99

10-
from azure.ai.evaluation._legacy.prompty import AsyncPrompty
10+
if os.getenv("AI_EVALS_USE_PF_PROMPTY", "false").lower() == "true":
11+
from promptflow.core._flow import AsyncPrompty
12+
else:
13+
from azure.ai.evaluation._legacy.prompty import AsyncPrompty
1114
from typing_extensions import override
1215

1316
from azure.ai.evaluation._common.constants import PROMPT_BASED_REASON_EVALUATORS

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_legacy/_adapters/_flows.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77

88
try:
9-
from promptflow._sdk.entities._flows import AsyncPrompty as _AsyncPrompty
9+
from promptflow.core._flow import AsyncPrompty as _AsyncPrompty
1010
from promptflow._sdk.entities._flows import FlexFlow as _FlexFlow
1111
from promptflow._sdk.entities._flows.dag import Flow as _Flow
1212
except ImportError:

0 commit comments

Comments
 (0)