33from typing import Any , Optional , Union , List , Dict
44from enum import Enum
55import asyncio
6-
76from openai import OpenAI , AsyncOpenAI , AzureOpenAI , AsyncAzureOpenAI
87from openai import (
98 AuthenticationError ,
1413 BadRequestError ,
1514)
1615from openai .types .chat import ChatCompletion
16+ from azure .ai .inference import ChatCompletionsClient
17+ from azure .ai .inference .aio import ChatCompletionsClient as AsyncChatCompletionsClient
18+ from azure .core .credentials import AzureKeyCredential
19+ from azure .core .exceptions import (
20+ HttpResponseError ,
21+ ServiceRequestError ,
22+ ClientAuthenticationError ,
23+ )
1724
1825
1926class LLMProvider_A2VYBG (Enum ):
2027 OPEN_AI = "Open AI"
2128 OPEN_SOURCE = "Open-Source"
2229 AZURE = "Azure"
30+ AZURE_FOUNDRY = "Azure Foundry"
2331
2432
2533# OpenAI migration guides
@@ -42,7 +50,6 @@ class LLMProvider_A2VYBG(Enum):
4250CACHE_ACCESS_LINK_A2VYBG = "@@CACHE_ACCESS_LINK@@"
4351CACHE_FILE_UPLOAD_LINK_A2VYBG = "@@CACHE_FILE_UPLOAD_LINK@@"
4452LLM_KWARGS_A2VYBG = {
45- "response_format" : {"type" : "json_object" },
4653 "stream" : False ,
4754 # fmt:off
4855 "stop" : json .loads ('@@STOP_SEQUENCE@@' ),
@@ -67,37 +74,7 @@ class LLMProvider_A2VYBG(Enum):
6774# azure_endpoint = api_base (before 1.0) - basically the link to the api
6875
6976
70- def test_client_model_2c6ecfb1_9bce_4e89_80c8_cbc4e3fca9e5 (
71- client : Union [OpenAI , AsyncOpenAI , AzureOpenAI , AsyncAzureOpenAI ], model : str
72- ):
73- if __is_client_valid_ex_8840b3a8_92d2_4526_b054_3b83c5cccb5c (client ) is not None :
74- print (
75- "Error: Invalid OpenAI client config (api_key, api_version or endpoint)" ,
76- flush = True ,
77- )
78- return False
79-
80- try :
81- client .chat .completions .create (
82- model = model ,
83- messages = [
84- {
85- "role" : "user" ,
86- "content" : "A" ,
87- }
88- ],
89- stream = False ,
90- temperature = 1 ,
91- max_tokens = 1 ,
92- )
93- except Exception as e :
94- print ("Error: Test chat completion failed" , flush = True )
95- print (e , flush = True )
96- return False
97- return True
98-
99-
100- def get_client_8e8a360e_3f7f_4cf9_ba80_8cb239e897d2 (
77+ def get_client_openai_8e8a360e_3f7f_4cf9_ba80_8cb239e897d2 (
10178 use_async : bool ,
10279 api_key : str ,
10380 azure_endpoint : Optional [str ] = None ,
@@ -117,8 +94,10 @@ def get_client_8e8a360e_3f7f_4cf9_ba80_8cb239e897d2(
11794 use_cache = MAX_CACHED_CLIENTS_A2VYBG != 0 and not prevent_cached_client
11895 if use_cache and config in CLIENT_LOOKUP_A2VYBG :
11996 if check_valid :
120- exception = __is_client_valid_ex_8840b3a8_92d2_4526_b054_3b83c5cccb5c (
121- CLIENT_LOOKUP_A2VYBG [config ][0 ]
97+ exception = (
98+ __is_client_valid_ex_openai_8840b3a8_92d2_4526_b054_3b83c5cccb5c (
99+ CLIENT_LOOKUP_A2VYBG [config ][0 ]
100+ )
122101 )
123102 if exception is not None :
124103 raise exception
@@ -136,14 +115,14 @@ def get_client_8e8a360e_3f7f_4cf9_ba80_8cb239e897d2(
136115 client .close ()
137116 CLIENT_LOOKUP_A2VYBG = dict (tmp )
138117
139- client = __create_client_bf47529a_75f7_498b_a091_4e7d52d35b6b (
118+ client = __create_client_openai_bf47529a_75f7_498b_a091_4e7d52d35b6b (
140119 use_async , api_key , azure_endpoint , api_version
141120 )
142121
143122 # test client with api key
144123 if check_valid :
145- exception = __is_client_valid_ex_8840b3a8_92d2_4526_b054_3b83c5cccb5c (
146- client
124+ exception = (
125+ __is_client_valid_ex_openai_8840b3a8_92d2_4526_b054_3b83c5cccb5c ( client )
147126 )
148127 if exception is not None :
149128 raise exception
@@ -153,7 +132,7 @@ def get_client_8e8a360e_3f7f_4cf9_ba80_8cb239e897d2(
153132 return client
154133
155134
156- def __create_client_bf47529a_75f7_498b_a091_4e7d52d35b6b (
135+ def __create_client_openai_bf47529a_75f7_498b_a091_4e7d52d35b6b (
157136 use_async : bool ,
158137 api_key : str ,
159138 azure_endpoint : Optional [str ] = None ,
@@ -185,7 +164,7 @@ def __create_client_bf47529a_75f7_498b_a091_4e7d52d35b6b(
185164 return client
186165
187166
188- def __is_client_valid_ex_8840b3a8_92d2_4526_b054_3b83c5cccb5c (
167+ def __is_client_valid_ex_openai_8840b3a8_92d2_4526_b054_3b83c5cccb5c (
189168 client : Union [OpenAI , AsyncOpenAI , AzureOpenAI , AsyncAzureOpenAI ], tries : int = 3
190169) -> Union [AuthenticationError , Exception , None ]:
191170 i = 0
@@ -251,36 +230,6 @@ def convert_to_string(data):
251230 return str (data )
252231
253232
254- # all work similar but use different classes etc.
255- # note that kwargs is just passed to the openai client so adding unknown kwargs will result in issues
256- # named parameter are NOT considered kwargs, only unknown parameters are kwargs
257- def get_chat_completion_4a90ecec_fc72_45af_ba0d_ae9a2dc4674c (
258- model : str ,
259- messages : List [Dict [str , str ]],
260- api_key : str ,
261- azure_endpoint : Optional [str ] = None ,
262- api_version : Optional [str ] = None ,
263- close_after : bool = False ,
264- ** kwargs ,
265- ) -> ChatCompletion :
266- client = get_client_8e8a360e_3f7f_4cf9_ba80_8cb239e897d2 (
267- use_async = False ,
268- api_key = api_key ,
269- azure_endpoint = azure_endpoint ,
270- api_version = api_version ,
271- prevent_cached_client = close_after ,
272- )
273- completion = client .chat .completions .create (
274- model = model ,
275- messages = messages ,
276- ** kwargs ,
277- )
278- if close_after :
279- client .close ()
280-
281- return completion
282-
283-
284233async def get_chat_completion_async_4a90ecec_fc72_45af_ba0d_ae9a2dc4674c (
285234 model : str ,
286235 messages : List [Dict [str , str ]],
@@ -290,21 +239,32 @@ async def get_chat_completion_async_4a90ecec_fc72_45af_ba0d_ae9a2dc4674c(
290239 close_after : bool = False ,
291240 ** kwargs ,
292241) -> ChatCompletion :
293- client = get_client_8e8a360e_3f7f_4cf9_ba80_8cb239e897d2 (
294- use_async = True ,
295- api_key = api_key ,
296- azure_endpoint = azure_endpoint ,
297- api_version = api_version ,
298- prevent_cached_client = close_after ,
299- )
300- completion = await client .chat .completions .create (
301- model = model ,
302- messages = messages ,
303- ** kwargs ,
304- )
242+ completion = None
243+ if CLIENT_TYPE_A2VYBG == LLMProvider_A2VYBG .AZURE_FOUNDRY .value :
244+ client = await get_client_azure_foundry_4a90ecec_fc72_45af_ba0d_ae9a2dc4674c (
245+ use_async = True ,
246+ api_key = api_key ,
247+ azure_endpoint = azure_endpoint ,
248+ )
249+ completion = await client .complete (
250+ messages = messages , response_format = "json_object" , ** kwargs
251+ )
252+ else :
253+ client = get_client_openai_8e8a360e_3f7f_4cf9_ba80_8cb239e897d2 (
254+ use_async = True ,
255+ api_key = api_key ,
256+ azure_endpoint = azure_endpoint ,
257+ api_version = api_version ,
258+ prevent_cached_client = close_after ,
259+ )
260+ completion = await client .chat .completions .create (
261+ model = model ,
262+ messages = messages ,
263+ response_format = {"type" : "json_object" },
264+ ** kwargs ,
265+ )
305266 if close_after :
306267 await client .close ()
307-
308268 return completion
309269
310270
@@ -372,3 +332,102 @@ async def get_llm_response(record: dict, cached_records: dict):
372332 print (m , flush = True )
373333 cached_records [curr_running_id ] = {"result" : m }
374334 return {"result" : m }
335+
336+
337+ # ------------------ AZURE FOUNDRY------------------
338+
339+
340+ async def get_client_azure_foundry_4a90ecec_fc72_45af_ba0d_ae9a2dc4674c (
341+ use_async : bool ,
342+ api_key : str ,
343+ azure_endpoint : Optional [str ] = None ,
344+ check_valid : bool = True ,
345+ prevent_cached_client : bool = True ,
346+ ) -> Union [ChatCompletionsClient , AsyncChatCompletionsClient ]:
347+
348+ global CLIENT_LOOKUP_A2VYBG
349+
350+ if CLIENT_TYPE_A2VYBG == LLMProvider_A2VYBG .AZURE_FOUNDRY .value and (
351+ azure_endpoint is None
352+ ):
353+ raise ValueError ("azure_endpoint must be set for Azure Foundry" )
354+
355+ # tuples can be used as dict keys, primitive datatype comparison works flawless, caution with objects though!
356+ config = (CLIENT_TYPE_A2VYBG , use_async , api_key , azure_endpoint )
357+ use_cache = MAX_CACHED_CLIENTS_A2VYBG != 0 and not prevent_cached_client
358+ if use_cache and config in CLIENT_LOOKUP_A2VYBG :
359+ if check_valid :
360+ exception = await __is_client_valid_ex_azure_foundry_4a90ecec_fc72_45af_ba0d_ae9a2dc4674c (
361+ CLIENT_LOOKUP_A2VYBG [config ][0 ]
362+ )
363+ if exception is not None :
364+ raise exception
365+
366+ CLIENT_LOOKUP_A2VYBG [config ] = (CLIENT_LOOKUP_A2VYBG [config ][0 ], time .time ())
367+
368+ return CLIENT_LOOKUP_A2VYBG [config ][0 ]
369+
370+ else :
371+ if use_cache and len (CLIENT_LOOKUP_A2VYBG ) >= MAX_CACHED_CLIENTS_A2VYBG :
372+ # remove oldest client
373+ tmp = sorted (
374+ CLIENT_LOOKUP_A2VYBG .items (), key = lambda x : x [1 ][1 ], reverse = True
375+ )
376+ (client , _ ) = tmp .pop ()
377+ client .close ()
378+ CLIENT_LOOKUP_A2VYBG = dict (tmp )
379+
380+ client = __create_client_azure_foundry_4a90ecec_fc72_45af_ba0d_ae9a2dc4674c (
381+ use_async = use_async , api_key = api_key , azure_endpoint = azure_endpoint
382+ )
383+
384+ # test client with api key
385+ if check_valid :
386+ exception = await __is_client_valid_ex_azure_foundry_4a90ecec_fc72_45af_ba0d_ae9a2dc4674c (
387+ client
388+ )
389+ if exception is not None :
390+ raise exception
391+
392+ if use_cache :
393+ CLIENT_LOOKUP_A2VYBG [config ] = (client , time .time ())
394+ return client
395+
396+
397+ def __create_client_azure_foundry_4a90ecec_fc72_45af_ba0d_ae9a2dc4674c (
398+ use_async : bool ,
399+ api_key : str ,
400+ azure_endpoint : Optional [str ] = None ,
401+ ) -> Union [ChatCompletionsClient , AsyncChatCompletionsClient ]:
402+
403+ if use_async :
404+ client = AsyncChatCompletionsClient (
405+ endpoint = azure_endpoint , credential = AzureKeyCredential (api_key )
406+ )
407+ else :
408+ client = ChatCompletionsClient (
409+ endpoint = azure_endpoint , credential = AzureKeyCredential (api_key )
410+ )
411+
412+ return client
413+
414+
415+ async def __is_client_valid_ex_azure_foundry_4a90ecec_fc72_45af_ba0d_ae9a2dc4674c (
416+ client : AsyncChatCompletionsClient ,
417+ tries : int = 3 ,
418+ ) -> Union [Exception , None ]:
419+ for i in range (tries + 1 ):
420+ try :
421+ await client .get_model_info ()
422+ return None
423+ except (
424+ HttpResponseError ,
425+ ServiceRequestError ,
426+ ClientAuthenticationError ,
427+ Exception ,
428+ ) as e :
429+ if i < tries :
430+ await asyncio .sleep (0.05 )
431+ continue
432+ return ValueError ("Invalid Azure client: " + str (e ))
433+ return None
0 commit comments