
Commit a29e8b3
LLM Attribute Calculation Foundry (#293)
* provider
* Adds config test for foundry
* restructure file
* client cache
* async client validation

Co-authored-by: JWittmeyer <[email protected]>
1 parent 539ebd1

2 files changed: +170 / -85 lines

controller/attribute/llm_response_tmpl.py

Lines changed: 142 additions & 83 deletions
@@ -3,7 +3,6 @@
 from typing import Any, Optional, Union, List, Dict
 from enum import Enum
 import asyncio
-
 from openai import OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI
 from openai import (
     AuthenticationError,
@@ -14,12 +13,21 @@
     BadRequestError,
 )
 from openai.types.chat import ChatCompletion
+from azure.ai.inference import ChatCompletionsClient
+from azure.ai.inference.aio import ChatCompletionsClient as AsyncChatCompletionsClient
+from azure.core.credentials import AzureKeyCredential
+from azure.core.exceptions import (
+    HttpResponseError,
+    ServiceRequestError,
+    ClientAuthenticationError,
+)


 class LLMProvider_A2VYBG(Enum):
     OPEN_AI = "Open AI"
     OPEN_SOURCE = "Open-Source"
     AZURE = "Azure"
+    AZURE_FOUNDRY = "Azure Foundry"


 # OpenAI migration guides
@@ -42,7 +50,6 @@ class LLMProvider_A2VYBG(Enum):
 CACHE_ACCESS_LINK_A2VYBG = "@@CACHE_ACCESS_LINK@@"
 CACHE_FILE_UPLOAD_LINK_A2VYBG = "@@CACHE_FILE_UPLOAD_LINK@@"
 LLM_KWARGS_A2VYBG = {
-    "response_format": {"type": "json_object"},
     "stream": False,
     # fmt:off
     "stop": json.loads('@@STOP_SEQUENCE@@'),
@@ -67,37 +74,7 @@
 # azure_endpoint = api_base (before 1.0) - basically the link to the api


-def test_client_model_2c6ecfb1_9bce_4e89_80c8_cbc4e3fca9e5(
-    client: Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI], model: str
-):
-    if __is_client_valid_ex_8840b3a8_92d2_4526_b054_3b83c5cccb5c(client) is not None:
-        print(
-            "Error: Invalid OpenAI client config (api_key, api_version or endpoint)",
-            flush=True,
-        )
-        return False
-
-    try:
-        client.chat.completions.create(
-            model=model,
-            messages=[
-                {
-                    "role": "user",
-                    "content": "A",
-                }
-            ],
-            stream=False,
-            temperature=1,
-            max_tokens=1,
-        )
-    except Exception as e:
-        print("Error: Test chat completion failed", flush=True)
-        print(e, flush=True)
-        return False
-    return True
-
-
-def get_client_8e8a360e_3f7f_4cf9_ba80_8cb239e897d2(
+def get_client_openai_8e8a360e_3f7f_4cf9_ba80_8cb239e897d2(
     use_async: bool,
     api_key: str,
     azure_endpoint: Optional[str] = None,
@@ -117,8 +94,10 @@ def get_client_8e8a360e_3f7f_4cf9_ba80_8cb239e897d2(
     use_cache = MAX_CACHED_CLIENTS_A2VYBG != 0 and not prevent_cached_client
     if use_cache and config in CLIENT_LOOKUP_A2VYBG:
         if check_valid:
-            exception = __is_client_valid_ex_8840b3a8_92d2_4526_b054_3b83c5cccb5c(
-                CLIENT_LOOKUP_A2VYBG[config][0]
+            exception = (
+                __is_client_valid_ex_openai_8840b3a8_92d2_4526_b054_3b83c5cccb5c(
+                    CLIENT_LOOKUP_A2VYBG[config][0]
+                )
             )
             if exception is not None:
                 raise exception
@@ -136,14 +115,14 @@ def get_client_8e8a360e_3f7f_4cf9_ba80_8cb239e897d2(
             client.close()
             CLIENT_LOOKUP_A2VYBG = dict(tmp)

-        client = __create_client_bf47529a_75f7_498b_a091_4e7d52d35b6b(
+        client = __create_client_openai_bf47529a_75f7_498b_a091_4e7d52d35b6b(
             use_async, api_key, azure_endpoint, api_version
         )

         # test client with api key
         if check_valid:
-            exception = __is_client_valid_ex_8840b3a8_92d2_4526_b054_3b83c5cccb5c(
-                client
+            exception = (
+                __is_client_valid_ex_openai_8840b3a8_92d2_4526_b054_3b83c5cccb5c(client)
             )
             if exception is not None:
                 raise exception
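Note: the cache behind these lookups is a plain dict keyed by the config tuple, and each value holds the client plus a last-used timestamp; on overflow the entry with the oldest timestamp is closed and evicted. A stand-alone sketch of the same pattern (all identifiers here are illustrative, none exist in the module):

import time

MAX_CACHED = 5
_cache = {}  # (provider, use_async, api_key, endpoint) -> (client, last_used)

def get_or_create(config, create_fn):
    if config in _cache:
        client, _ = _cache[config]
        _cache[config] = (client, time.time())  # refresh recency on a hit
        return client
    if len(_cache) >= MAX_CACHED:
        oldest = min(_cache, key=lambda k: _cache[k][1])
        _cache.pop(oldest)[0].close()  # close and drop the stalest client
    client = create_fn()
    _cache[config] = (client, time.time())
    return client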
@@ -153,7 +132,7 @@ def get_client_8e8a360e_3f7f_4cf9_ba80_8cb239e897d2(
     return client


-def __create_client_bf47529a_75f7_498b_a091_4e7d52d35b6b(
+def __create_client_openai_bf47529a_75f7_498b_a091_4e7d52d35b6b(
     use_async: bool,
     api_key: str,
     azure_endpoint: Optional[str] = None,
@@ -185,7 +164,7 @@ def __create_client_bf47529a_75f7_498b_a091_4e7d52d35b6b(
     return client


-def __is_client_valid_ex_8840b3a8_92d2_4526_b054_3b83c5cccb5c(
+def __is_client_valid_ex_openai_8840b3a8_92d2_4526_b054_3b83c5cccb5c(
     client: Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI], tries: int = 3
 ) -> Union[AuthenticationError, Exception, None]:
     i = 0
@@ -251,36 +230,6 @@ def convert_to_string(data):
     return str(data)


-# all work similar but use different classes etc.
-# note that kwargs is just passed to the openai client so adding unknown kwargs will result in issues
-# named parameter are NOT considered kwargs, only unknown parameters are kwargs
-def get_chat_completion_4a90ecec_fc72_45af_ba0d_ae9a2dc4674c(
-    model: str,
-    messages: List[Dict[str, str]],
-    api_key: str,
-    azure_endpoint: Optional[str] = None,
-    api_version: Optional[str] = None,
-    close_after: bool = False,
-    **kwargs,
-) -> ChatCompletion:
-    client = get_client_8e8a360e_3f7f_4cf9_ba80_8cb239e897d2(
-        use_async=False,
-        api_key=api_key,
-        azure_endpoint=azure_endpoint,
-        api_version=api_version,
-        prevent_cached_client=close_after,
-    )
-    completion = client.chat.completions.create(
-        model=model,
-        messages=messages,
-        **kwargs,
-    )
-    if close_after:
-        client.close()
-
-    return completion
-
-
 async def get_chat_completion_async_4a90ecec_fc72_45af_ba0d_ae9a2dc4674c(
     model: str,
     messages: List[Dict[str, str]],
@@ -290,21 +239,32 @@ async def get_chat_completion_async_4a90ecec_fc72_45af_ba0d_ae9a2dc4674c(
     close_after: bool = False,
     **kwargs,
 ) -> ChatCompletion:
-    client = get_client_8e8a360e_3f7f_4cf9_ba80_8cb239e897d2(
-        use_async=True,
-        api_key=api_key,
-        azure_endpoint=azure_endpoint,
-        api_version=api_version,
-        prevent_cached_client=close_after,
-    )
-    completion = await client.chat.completions.create(
-        model=model,
-        messages=messages,
-        **kwargs,
-    )
+    completion = None
+    if CLIENT_TYPE_A2VYBG == LLMProvider_A2VYBG.AZURE_FOUNDRY.value:
+        client = await get_client_azure_foundry_4a90ecec_fc72_45af_ba0d_ae9a2dc4674c(
+            use_async=True,
+            api_key=api_key,
+            azure_endpoint=azure_endpoint,
+        )
+        completion = await client.complete(
+            messages=messages, response_format="json_object", **kwargs
+        )
+    else:
+        client = get_client_openai_8e8a360e_3f7f_4cf9_ba80_8cb239e897d2(
+            use_async=True,
+            api_key=api_key,
+            azure_endpoint=azure_endpoint,
+            api_version=api_version,
+            prevent_cached_client=close_after,
+        )
+        completion = await client.chat.completions.create(
+            model=model,
+            messages=messages,
+            response_format={"type": "json_object"},
+            **kwargs,
+        )
     if close_after:
         await client.close()
-
     return completion


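Note: the async helper now dispatches on the provider baked into the template. Azure Foundry traffic goes through complete() on the azure.ai.inference.aio client, which targets one deployment per endpoint, so model is not forwarded on that branch; everything else goes through the OpenAI-style chat.completions.create(). Both branches force JSON output. A hedged usage sketch; the key, endpoint, and model values are placeholders:

import asyncio

async def demo():
    completion = await get_chat_completion_async_4a90ecec_fc72_45af_ba0d_ae9a2dc4674c(
        model="my-deployment",  # used only by the OpenAI-style branch
        messages=[{"role": "user", "content": "Return an empty JSON object."}],
        api_key="<api-key>",  # placeholder
        azure_endpoint="https://example.services.ai.azure.com/models",  # placeholder
        close_after=True,  # bypass the client cache, close the client afterwards
    )
    print(completion)

asyncio.run(demo())  # assumes the @@...@@ template placeholders have been filled in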
@@ -372,3 +332,102 @@ async def get_llm_response(record: dict, cached_records: dict):
         print(m, flush=True)
         cached_records[curr_running_id] = {"result": m}
         return {"result": m}
+
+
+# ------------------ AZURE FOUNDRY ------------------
+
+
+async def get_client_azure_foundry_4a90ecec_fc72_45af_ba0d_ae9a2dc4674c(
+    use_async: bool,
+    api_key: str,
+    azure_endpoint: Optional[str] = None,
+    check_valid: bool = True,
+    prevent_cached_client: bool = True,
+) -> Union[ChatCompletionsClient, AsyncChatCompletionsClient]:
+
+    global CLIENT_LOOKUP_A2VYBG
+
+    if CLIENT_TYPE_A2VYBG == LLMProvider_A2VYBG.AZURE_FOUNDRY.value and (
+        azure_endpoint is None
+    ):
+        raise ValueError("azure_endpoint must be set for Azure Foundry")
+
+    # tuples can be used as dict keys, primitive datatype comparison works flawlessly, caution with objects though!
+    config = (CLIENT_TYPE_A2VYBG, use_async, api_key, azure_endpoint)
+    use_cache = MAX_CACHED_CLIENTS_A2VYBG != 0 and not prevent_cached_client
+    if use_cache and config in CLIENT_LOOKUP_A2VYBG:
+        if check_valid:
+            exception = await __is_client_valid_ex_azure_foundry_4a90ecec_fc72_45af_ba0d_ae9a2dc4674c(
+                CLIENT_LOOKUP_A2VYBG[config][0]
+            )
+            if exception is not None:
+                raise exception
+
+        CLIENT_LOOKUP_A2VYBG[config] = (CLIENT_LOOKUP_A2VYBG[config][0], time.time())
+
+        return CLIENT_LOOKUP_A2VYBG[config][0]
+
+    else:
+        if use_cache and len(CLIENT_LOOKUP_A2VYBG) >= MAX_CACHED_CLIENTS_A2VYBG:
+            # remove oldest client
+            tmp = sorted(
+                CLIENT_LOOKUP_A2VYBG.items(), key=lambda x: x[1][1], reverse=True
+            )
+            (client, _) = tmp.pop()
+            client.close()
+            CLIENT_LOOKUP_A2VYBG = dict(tmp)
+
+        client = __create_client_azure_foundry_4a90ecec_fc72_45af_ba0d_ae9a2dc4674c(
+            use_async=use_async, api_key=api_key, azure_endpoint=azure_endpoint
+        )
+
+        # test client with api key
+        if check_valid:
+            exception = await __is_client_valid_ex_azure_foundry_4a90ecec_fc72_45af_ba0d_ae9a2dc4674c(
+                client
+            )
+            if exception is not None:
+                raise exception
+
+        if use_cache:
+            CLIENT_LOOKUP_A2VYBG[config] = (client, time.time())
+        return client
+
+
+def __create_client_azure_foundry_4a90ecec_fc72_45af_ba0d_ae9a2dc4674c(
+    use_async: bool,
+    api_key: str,
+    azure_endpoint: Optional[str] = None,
+) -> Union[ChatCompletionsClient, AsyncChatCompletionsClient]:
+
+    if use_async:
+        client = AsyncChatCompletionsClient(
+            endpoint=azure_endpoint, credential=AzureKeyCredential(api_key)
+        )
+    else:
+        client = ChatCompletionsClient(
+            endpoint=azure_endpoint, credential=AzureKeyCredential(api_key)
+        )
+
+    return client
+
+
+async def __is_client_valid_ex_azure_foundry_4a90ecec_fc72_45af_ba0d_ae9a2dc4674c(
+    client: AsyncChatCompletionsClient,
+    tries: int = 3,
+) -> Union[Exception, None]:
+    for i in range(tries + 1):
+        try:
+            await client.get_model_info()
+            return None
+        except (
+            HttpResponseError,
+            ServiceRequestError,
+            ClientAuthenticationError,
+            Exception,
+        ) as e:
+            if i < tries:
+                await asyncio.sleep(0.05)
+                continue
+            return ValueError("Invalid Azure client: " + str(e))
+    return None
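Note on the validity probe: it calls get_model_info() up to tries + 1 times, sleeping 50 ms between attempts, and returns the failure as a value instead of raising so the caller decides whether to raise. Because the except tuple ends in Exception, the three Azure-specific types are effectively documentation only; every error is retried. The generic shape of that pattern, sketched with illustrative names:

import asyncio
from typing import Awaitable, Callable, Optional

async def probe_with_retries(
    probe: Callable[[], Awaitable[object]],
    tries: int = 3,
    delay_s: float = 0.05,
) -> Optional[Exception]:
    # Attempt `probe` up to tries + 1 times; hand back the last failure instead of raising.
    for i in range(tries + 1):
        try:
            await probe()
            return None
        except Exception as e:
            if i < tries:
                await asyncio.sleep(delay_s)
                continue
            return ValueError(f"probe failed after {tries + 1} attempts: {e}")
    return None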

controller/attribute/util.py

Lines changed: 28 additions & 2 deletions
@@ -92,7 +92,7 @@ def test_openai_llm_connection(api_key: str, model: str):
         "messages": [
             {"role": "user", "content": [{"type": "text", "text": "only say 'hello'"}]},
         ],
-        "max_tokens": 20,
+        "max_tokens": 5,
    }

    response = requests.post(
@@ -102,6 +102,27 @@
     return response.json()["choices"][0]["message"]["content"]


+def test_azure_foundry_llm_connection(api_key: str, base_endpoint: str):
+    # more here: https://learn.microsoft.com/en-us/rest/api/aifoundry/modelinference/
+    base_endpoint = base_endpoint.rstrip("/")
+    final_endpoint = f"{base_endpoint}/chat/completions"
+    headers = {
+        "Content-Type": "application/json",
+        "Authorization": f"Bearer {api_key}",
+    }
+
+    payload = {
+        "messages": [
+            {"role": "user", "content": [{"type": "text", "text": "only say 'hello'"}]},
+        ],
+        "max_tokens": 5,
+    }
+
+    response = requests.post(final_endpoint, headers=headers, json=payload)
+    response.raise_for_status()
+    return response.json()["choices"][0]["message"]["content"]
+
+
 def test_azure_llm_connection(
     api_key: str, base_endpoint: str, api_version: str, model: str
 ):
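Note: the new connectivity test posts one tiny chat request straight to the Foundry REST endpoint and returns the model's reply; HTTP errors surface via raise_for_status(). A hedged usage sketch with placeholder values:

reply = test_azure_foundry_llm_connection(
    api_key="<api-key>",  # placeholder
    base_endpoint="https://example.services.ai.azure.com/models",  # placeholder
)
print(reply)  # expect something close to "hello"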
@@ -129,7 +150,7 @@ def test_azure_llm_connection(
         "messages": [
             {"role": "user", "content": [{"type": "text", "text": "only say 'hello'"}]},
         ],
-        "max_tokens": 20,
+        "max_tokens": 5,
    }

    response = requests.post(final_endpoint, headers=headers, json=payload)
@@ -177,6 +198,11 @@ def validate_llm_config(llm_config: Dict[str, Any]):
             base_endpoint=llm_config["apiBase"],
             api_version=llm_config["apiVersion"],
         )
+    elif llm_config["llmIdentifier"] == enums.LLMProvider.AZURE_FOUNDRY.value:
+        test_azure_foundry_llm_connection(
+            api_key=llm_config["apiKey"],
+            base_endpoint=llm_config["apiBase"],
+        )
     else:
         raise LlmResponseError(
             "LLM Identifier must be either Open AI or Azure, got: "
