Skip to content

Commit ad42cfc

Browse files
authored
Foundry for access env change (#354)
1 parent 00873d8 commit ad42cfc

File tree

2 files changed

+22
-132
lines changed

2 files changed

+22
-132
lines changed

controller/attribute/llm_response_tmpl.py

Lines changed: 18 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,6 @@
1313
BadRequestError,
1414
)
1515
from openai.types.chat import ChatCompletion
16-
from azure.ai.inference import ChatCompletionsClient
17-
from azure.ai.inference.aio import ChatCompletionsClient as AsyncChatCompletionsClient
18-
from azure.core.credentials import AzureKeyCredential
19-
from azure.core.exceptions import (
20-
HttpResponseError,
21-
ServiceRequestError,
22-
ClientAuthenticationError,
23-
)
2416

2517

2618
class LLMProvider_A2VYBG(Enum):
@@ -97,6 +89,11 @@ def get_client_openai_8e8a360e_3f7f_4cf9_ba80_8cb239e897d2(
9789
api_key="dummy",
9890
base_url="http://privatemode-proxy:8080/v1",
9991
)
92+
elif CLIENT_TYPE_A2VYBG == LLMProvider_A2VYBG.AZURE_FOUNDRY.value:
93+
return AsyncOpenAI(
94+
api_key=API_KEY_A2VYBG,
95+
base_url=API_BASE_A2VYBG,
96+
)
10097

10198
if CLIENT_TYPE_A2VYBG == LLMProvider_A2VYBG.AZURE.value and (
10299
azure_endpoint is None or api_version is None
@@ -254,29 +251,19 @@ async def get_chat_completion_async_4a90ecec_fc72_45af_ba0d_ae9a2dc4674c(
254251
**kwargs,
255252
) -> ChatCompletion:
256253
completion = None
257-
if CLIENT_TYPE_A2VYBG == LLMProvider_A2VYBG.AZURE_FOUNDRY.value:
258-
client = await get_client_azure_foundry_4a90ecec_fc72_45af_ba0d_ae9a2dc4674c(
259-
use_async=True,
260-
api_key=api_key,
261-
azure_endpoint=azure_endpoint,
262-
)
263-
completion = await client.complete(
264-
messages=messages, response_format="json_object", **kwargs
265-
)
266-
else:
267-
client = get_client_openai_8e8a360e_3f7f_4cf9_ba80_8cb239e897d2(
268-
use_async=True,
269-
api_key=api_key,
270-
azure_endpoint=azure_endpoint,
271-
api_version=api_version,
272-
prevent_cached_client=close_after,
273-
)
274-
completion = await client.chat.completions.create(
275-
model=model,
276-
messages=messages,
277-
response_format={"type": "json_object"},
278-
**kwargs,
279-
)
254+
client = get_client_openai_8e8a360e_3f7f_4cf9_ba80_8cb239e897d2(
255+
use_async=True,
256+
api_key=api_key,
257+
azure_endpoint=azure_endpoint,
258+
api_version=api_version,
259+
prevent_cached_client=close_after,
260+
)
261+
completion = await client.chat.completions.create(
262+
model=model,
263+
messages=messages,
264+
response_format={"type": "json_object"},
265+
**kwargs,
266+
)
280267

281268
if close_after:
282269
result = client.close()
@@ -363,102 +350,3 @@ async def get_llm_response(record: dict, cached_records: dict):
363350
print(m, flush=True)
364351
cached_records[curr_running_id] = {"result": m}
365352
return {"result": m}
366-
367-
368-
# ------------------ AZURE FOUNDRY------------------
369-
370-
371-
async def get_client_azure_foundry_4a90ecec_fc72_45af_ba0d_ae9a2dc4674c(
372-
use_async: bool,
373-
api_key: str,
374-
azure_endpoint: Optional[str] = None,
375-
check_valid: bool = True,
376-
prevent_cached_client: bool = True,
377-
) -> Union[ChatCompletionsClient, AsyncChatCompletionsClient]:
378-
379-
global CLIENT_LOOKUP_A2VYBG
380-
381-
if CLIENT_TYPE_A2VYBG == LLMProvider_A2VYBG.AZURE_FOUNDRY.value and (
382-
azure_endpoint is None
383-
):
384-
raise ValueError("azure_endpoint must be set for Azure Foundry")
385-
386-
# tuples can be used as dict keys, primitive datatype comparison works flawless, caution with objects though!
387-
config = (CLIENT_TYPE_A2VYBG, use_async, api_key, azure_endpoint)
388-
use_cache = MAX_CACHED_CLIENTS_A2VYBG != 0 and not prevent_cached_client
389-
if use_cache and config in CLIENT_LOOKUP_A2VYBG:
390-
if check_valid:
391-
exception = await __is_client_valid_ex_azure_foundry_4a90ecec_fc72_45af_ba0d_ae9a2dc4674c(
392-
CLIENT_LOOKUP_A2VYBG[config][0]
393-
)
394-
if exception is not None:
395-
raise exception
396-
397-
CLIENT_LOOKUP_A2VYBG[config] = (CLIENT_LOOKUP_A2VYBG[config][0], time.time())
398-
399-
return CLIENT_LOOKUP_A2VYBG[config][0]
400-
401-
else:
402-
if use_cache and len(CLIENT_LOOKUP_A2VYBG) >= MAX_CACHED_CLIENTS_A2VYBG:
403-
# remove oldest client
404-
tmp = sorted(
405-
CLIENT_LOOKUP_A2VYBG.items(), key=lambda x: x[1][1], reverse=True
406-
)
407-
(client, _) = tmp.pop()
408-
client.close()
409-
CLIENT_LOOKUP_A2VYBG = dict(tmp)
410-
411-
client = __create_client_azure_foundry_4a90ecec_fc72_45af_ba0d_ae9a2dc4674c(
412-
use_async=use_async, api_key=api_key, azure_endpoint=azure_endpoint
413-
)
414-
415-
# test client with api key
416-
if check_valid:
417-
exception = await __is_client_valid_ex_azure_foundry_4a90ecec_fc72_45af_ba0d_ae9a2dc4674c(
418-
client
419-
)
420-
if exception is not None:
421-
raise exception
422-
423-
if use_cache:
424-
CLIENT_LOOKUP_A2VYBG[config] = (client, time.time())
425-
return client
426-
427-
428-
def __create_client_azure_foundry_4a90ecec_fc72_45af_ba0d_ae9a2dc4674c(
429-
use_async: bool,
430-
api_key: str,
431-
azure_endpoint: Optional[str] = None,
432-
) -> Union[ChatCompletionsClient, AsyncChatCompletionsClient]:
433-
434-
if use_async:
435-
client = AsyncChatCompletionsClient(
436-
endpoint=azure_endpoint, credential=AzureKeyCredential(api_key)
437-
)
438-
else:
439-
client = ChatCompletionsClient(
440-
endpoint=azure_endpoint, credential=AzureKeyCredential(api_key)
441-
)
442-
443-
return client
444-
445-
446-
async def __is_client_valid_ex_azure_foundry_4a90ecec_fc72_45af_ba0d_ae9a2dc4674c(
447-
client: AsyncChatCompletionsClient,
448-
tries: int = 3,
449-
) -> Union[Exception, None]:
450-
for i in range(tries + 1):
451-
try:
452-
await client.get_model_info()
453-
return None
454-
except (
455-
HttpResponseError,
456-
ServiceRequestError,
457-
ClientAuthenticationError,
458-
Exception,
459-
) as e:
460-
if i < tries:
461-
await asyncio.sleep(0.05)
462-
continue
463-
return ValueError("Invalid Azure client: " + str(e))
464-
return None

controller/attribute/util.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def test_openai_llm_connection(api_key: str, model: str, is_o_series: bool = Fal
118118
return response.json()["choices"][0]["message"]["content"]
119119

120120

121-
def test_azure_foundry_llm_connection(api_key: str, base_endpoint: str):
121+
def test_azure_foundry_llm_connection(api_key: str, base_endpoint: str, model: str):
122122
# more here: https://learn.microsoft.com/en-us/rest/api/aifoundry/modelinference/
123123
base_endpoint = base_endpoint.rstrip("/")
124124
final_endpoint = f"{base_endpoint}/chat/completions"
@@ -132,6 +132,7 @@ def test_azure_foundry_llm_connection(api_key: str, base_endpoint: str):
132132
{"role": "user", "content": [{"type": "text", "text": "only say 'hello'"}]},
133133
],
134134
"max_tokens": 5,
135+
"model": model,
135136
}
136137

137138
response = requests.post(final_endpoint, headers=headers, json=payload)
@@ -243,13 +244,14 @@ def validate_llm_config(llm_config: Dict[str, Any]):
243244
api_key=llm_config["apiKey"],
244245
model=llm_config["model"],
245246
base_endpoint=llm_config["apiBase"],
246-
api_version=llm_config["apiVersion"],
247+
api_version=llm_config.get("apiVersion"),
247248
is_o_series=llm_config.get("openAioSeries", False),
248249
)
249250
elif llm_config["llmIdentifier"] == enums.LLMProvider.AZURE_FOUNDRY.value:
250251
test_azure_foundry_llm_connection(
251252
api_key=llm_config["apiKey"],
252253
base_endpoint=llm_config["apiBase"],
254+
model=llm_config["model"],
253255
)
254256
elif llm_config["llmIdentifier"] == enums.LLMProvider.PRIVATEMODE_AI.value:
255257
test_privatemode_ai_llm_connection(

0 commit comments

Comments (0)