|
13 | 13 | BadRequestError, |
14 | 14 | ) |
15 | 15 | from openai.types.chat import ChatCompletion |
16 | | -from azure.ai.inference import ChatCompletionsClient |
17 | | -from azure.ai.inference.aio import ChatCompletionsClient as AsyncChatCompletionsClient |
18 | | -from azure.core.credentials import AzureKeyCredential |
19 | | -from azure.core.exceptions import ( |
20 | | - HttpResponseError, |
21 | | - ServiceRequestError, |
22 | | - ClientAuthenticationError, |
23 | | -) |
24 | 16 |
|
25 | 17 |
|
26 | 18 | class LLMProvider_A2VYBG(Enum): |
@@ -97,6 +89,11 @@ def get_client_openai_8e8a360e_3f7f_4cf9_ba80_8cb239e897d2( |
97 | 89 | api_key="dummy", |
98 | 90 | base_url="http://privatemode-proxy:8080/v1", |
99 | 91 | ) |
| 92 | + elif CLIENT_TYPE_A2VYBG == LLMProvider_A2VYBG.AZURE_FOUNDRY.value: |
| 93 | + return AsyncOpenAI( |
| 94 | + api_key=API_KEY_A2VYBG, |
| 95 | + base_url=API_BASE_A2VYBG, |
| 96 | + ) |
100 | 97 |
|
101 | 98 | if CLIENT_TYPE_A2VYBG == LLMProvider_A2VYBG.AZURE.value and ( |
102 | 99 | azure_endpoint is None or api_version is None |
@@ -254,29 +251,19 @@ async def get_chat_completion_async_4a90ecec_fc72_45af_ba0d_ae9a2dc4674c( |
254 | 251 | **kwargs, |
255 | 252 | ) -> ChatCompletion: |
256 | 253 | completion = None |
257 | | - if CLIENT_TYPE_A2VYBG == LLMProvider_A2VYBG.AZURE_FOUNDRY.value: |
258 | | - client = await get_client_azure_foundry_4a90ecec_fc72_45af_ba0d_ae9a2dc4674c( |
259 | | - use_async=True, |
260 | | - api_key=api_key, |
261 | | - azure_endpoint=azure_endpoint, |
262 | | - ) |
263 | | - completion = await client.complete( |
264 | | - messages=messages, response_format="json_object", **kwargs |
265 | | - ) |
266 | | - else: |
267 | | - client = get_client_openai_8e8a360e_3f7f_4cf9_ba80_8cb239e897d2( |
268 | | - use_async=True, |
269 | | - api_key=api_key, |
270 | | - azure_endpoint=azure_endpoint, |
271 | | - api_version=api_version, |
272 | | - prevent_cached_client=close_after, |
273 | | - ) |
274 | | - completion = await client.chat.completions.create( |
275 | | - model=model, |
276 | | - messages=messages, |
277 | | - response_format={"type": "json_object"}, |
278 | | - **kwargs, |
279 | | - ) |
| 254 | + client = get_client_openai_8e8a360e_3f7f_4cf9_ba80_8cb239e897d2( |
| 255 | + use_async=True, |
| 256 | + api_key=api_key, |
| 257 | + azure_endpoint=azure_endpoint, |
| 258 | + api_version=api_version, |
| 259 | + prevent_cached_client=close_after, |
| 260 | + ) |
| 261 | + completion = await client.chat.completions.create( |
| 262 | + model=model, |
| 263 | + messages=messages, |
| 264 | + response_format={"type": "json_object"}, |
| 265 | + **kwargs, |
| 266 | + ) |
280 | 267 |
|
281 | 268 | if close_after: |
282 | 269 | result = client.close() |
@@ -363,102 +350,3 @@ async def get_llm_response(record: dict, cached_records: dict): |
363 | 350 | print(m, flush=True) |
364 | 351 | cached_records[curr_running_id] = {"result": m} |
365 | 352 | return {"result": m} |
366 | | - |
367 | | - |
368 | | -# ------------------ AZURE FOUNDRY------------------ |
369 | | - |
370 | | - |
371 | | -async def get_client_azure_foundry_4a90ecec_fc72_45af_ba0d_ae9a2dc4674c( |
372 | | - use_async: bool, |
373 | | - api_key: str, |
374 | | - azure_endpoint: Optional[str] = None, |
375 | | - check_valid: bool = True, |
376 | | - prevent_cached_client: bool = True, |
377 | | -) -> Union[ChatCompletionsClient, AsyncChatCompletionsClient]: |
378 | | - |
379 | | - global CLIENT_LOOKUP_A2VYBG |
380 | | - |
381 | | - if CLIENT_TYPE_A2VYBG == LLMProvider_A2VYBG.AZURE_FOUNDRY.value and ( |
382 | | - azure_endpoint is None |
383 | | - ): |
384 | | - raise ValueError("azure_endpoint must be set for Azure Foundry") |
385 | | - |
386 | | - # tuples can be used as dict keys, primitive datatype comparison works flawless, caution with objects though! |
387 | | - config = (CLIENT_TYPE_A2VYBG, use_async, api_key, azure_endpoint) |
388 | | - use_cache = MAX_CACHED_CLIENTS_A2VYBG != 0 and not prevent_cached_client |
389 | | - if use_cache and config in CLIENT_LOOKUP_A2VYBG: |
390 | | - if check_valid: |
391 | | - exception = await __is_client_valid_ex_azure_foundry_4a90ecec_fc72_45af_ba0d_ae9a2dc4674c( |
392 | | - CLIENT_LOOKUP_A2VYBG[config][0] |
393 | | - ) |
394 | | - if exception is not None: |
395 | | - raise exception |
396 | | - |
397 | | - CLIENT_LOOKUP_A2VYBG[config] = (CLIENT_LOOKUP_A2VYBG[config][0], time.time()) |
398 | | - |
399 | | - return CLIENT_LOOKUP_A2VYBG[config][0] |
400 | | - |
401 | | - else: |
402 | | - if use_cache and len(CLIENT_LOOKUP_A2VYBG) >= MAX_CACHED_CLIENTS_A2VYBG: |
403 | | - # remove oldest client |
404 | | - tmp = sorted( |
405 | | - CLIENT_LOOKUP_A2VYBG.items(), key=lambda x: x[1][1], reverse=True |
406 | | - ) |
407 | | - (client, _) = tmp.pop() |
408 | | - client.close() |
409 | | - CLIENT_LOOKUP_A2VYBG = dict(tmp) |
410 | | - |
411 | | - client = __create_client_azure_foundry_4a90ecec_fc72_45af_ba0d_ae9a2dc4674c( |
412 | | - use_async=use_async, api_key=api_key, azure_endpoint=azure_endpoint |
413 | | - ) |
414 | | - |
415 | | - # test client with api key |
416 | | - if check_valid: |
417 | | - exception = await __is_client_valid_ex_azure_foundry_4a90ecec_fc72_45af_ba0d_ae9a2dc4674c( |
418 | | - client |
419 | | - ) |
420 | | - if exception is not None: |
421 | | - raise exception |
422 | | - |
423 | | - if use_cache: |
424 | | - CLIENT_LOOKUP_A2VYBG[config] = (client, time.time()) |
425 | | - return client |
426 | | - |
427 | | - |
428 | | -def __create_client_azure_foundry_4a90ecec_fc72_45af_ba0d_ae9a2dc4674c( |
429 | | - use_async: bool, |
430 | | - api_key: str, |
431 | | - azure_endpoint: Optional[str] = None, |
432 | | -) -> Union[ChatCompletionsClient, AsyncChatCompletionsClient]: |
433 | | - |
434 | | - if use_async: |
435 | | - client = AsyncChatCompletionsClient( |
436 | | - endpoint=azure_endpoint, credential=AzureKeyCredential(api_key) |
437 | | - ) |
438 | | - else: |
439 | | - client = ChatCompletionsClient( |
440 | | - endpoint=azure_endpoint, credential=AzureKeyCredential(api_key) |
441 | | - ) |
442 | | - |
443 | | - return client |
444 | | - |
445 | | - |
446 | | -async def __is_client_valid_ex_azure_foundry_4a90ecec_fc72_45af_ba0d_ae9a2dc4674c( |
447 | | - client: AsyncChatCompletionsClient, |
448 | | - tries: int = 3, |
449 | | -) -> Union[Exception, None]: |
450 | | - for i in range(tries + 1): |
451 | | - try: |
452 | | - await client.get_model_info() |
453 | | - return None |
454 | | - except ( |
455 | | - HttpResponseError, |
456 | | - ServiceRequestError, |
457 | | - ClientAuthenticationError, |
458 | | - Exception, |
459 | | - ) as e: |
460 | | - if i < tries: |
461 | | - await asyncio.sleep(0.05) |
462 | | - continue |
463 | | - return ValueError("Invalid Azure client: " + str(e)) |
464 | | - return None |
0 commit comments