Skip to content

Commit 287459c

Browse files
nagkumar91Nagkumar Arkalgudkdestin
authored
Bugfix/token expired (Azure#37763)
* Update task_query_response.prompty remove required keys * Update task_simulate.prompty * Update task_query_response.prompty * Update task_simulate.prompty * Add a async method to refresh token * Update changelog * Update _identity_manager.py * Fix lint issue * adding an abstractmethod to help mypy * black formatting * Update sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_model_tools/_identity_manager.py Co-authored-by: kdestin <[email protected]> * Update sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_model_tools/_identity_manager.py Co-authored-by: kdestin <[email protected]> * Trying a fix for mypy --------- Co-authored-by: Nagkumar Arkalgud <[email protected]> Co-authored-by: kdestin <[email protected]>
1 parent 1be56df commit 287459c

File tree

3 files changed

+42
-2
lines changed

3 files changed

+42
-2
lines changed

sdk/evaluation/azure-ai-evaluation/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
- `credential` is now required to be passed in for all content safety evaluators and `ProtectedMaterialsEvaluator`. `DefaultAzureCredential` will no longer be chosen if a credential is not passed.
1111

1212
### Bugs Fixed
13+
- Adversarial Conversation simulations would fail with `Forbidden`. Added logic to re-fetch token in the exponential retry logic to retrive RAI Service response.
1314

1415
### Other Changes
1516

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_model_tools/_identity_manager.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
# ---------------------------------------------------------
44

55
import asyncio
6+
import inspect
67
import logging
78
import os
89
import time
910
from abc import ABC, abstractmethod
1011
from enum import Enum
1112
from typing import Optional, Union
1213

13-
from azure.core.credentials import TokenCredential
14+
from azure.core.credentials import TokenCredential, AccessToken
1415
from azure.identity import DefaultAzureCredential, ManagedIdentityCredential
1516

1617
AZURE_TOKEN_REFRESH_INTERVAL = 600 # seconds
@@ -87,6 +88,14 @@ def get_token(self) -> str:
8788
:rtype: str
8889
"""
8990

91+
@abstractmethod
92+
async def get_token_async(self) -> str:
93+
"""Async method to get the API token. Subclasses should implement this method.
94+
95+
:return: API token
96+
:rtype: str
97+
"""
98+
9099

91100
class ManagedIdentityAPITokenManager(APITokenManager):
92101
"""API Token Manager for Azure Managed Identity
@@ -127,6 +136,31 @@ def get_token(self) -> str:
127136

128137
return self.token
129138

139+
async def get_token_async(self) -> str:
140+
"""Get the API token synchronously. If the token is not available or has expired, refresh it.
141+
142+
:return: API token
143+
:rtype: str
144+
"""
145+
if (
146+
self.token is None
147+
or self.last_refresh_time is None
148+
or time.time() - self.last_refresh_time > AZURE_TOKEN_REFRESH_INTERVAL
149+
):
150+
self.last_refresh_time = time.time()
151+
get_token_method = self.credential.get_token(self.token_scope.value)
152+
if inspect.isawaitable(get_token_method):
153+
# If it's awaitable, await it
154+
token_response: AccessToken = await get_token_method
155+
else:
156+
# Otherwise, call it synchronously
157+
token_response = get_token_method
158+
159+
self.token = token_response.token
160+
self.logger.info("Refreshed Azure endpoint token.")
161+
162+
return self.token
163+
130164

131165
class PlainTokenManager(APITokenManager):
132166
"""Plain API Token Manager

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_model_tools/_proxy_completion_model.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,6 @@ async def request_api(
172172
}
173173
# add all additional headers
174174
headers.update(self.additional_headers) # type: ignore[arg-type]
175-
176175
params = {}
177176
if self.api_version:
178177
params["api-version"] = self.api_version
@@ -214,6 +213,12 @@ async def request_api(
214213
time.sleep(15)
215214

216215
async with get_async_http_client().with_policies(retry_policy=retry_policy) as exp_retry_client:
216+
token = await self.token_manager.get_token_async()
217+
proxy_headers = {
218+
"Authorization": f"Bearer {token}",
219+
"Content-Type": "application/json",
220+
"User-Agent": USER_AGENT,
221+
}
217222
response = await exp_retry_client.get( # pylint: disable=too-many-function-args,unexpected-keyword-arg
218223
self.result_url, headers=proxy_headers
219224
)

0 commit comments

Comments
 (0)