1+ import inspect
12import json
23import typing
34from collections .abc import AsyncIterator , Iterator
45
56import azure .identity
6- from openai import APIResponse , AsyncAzureOpenAI , AzureOpenAI
7+ from openai import AsyncAzureOpenAI , AzureOpenAI
78from openai .types .chat .chat_completion import ChatCompletion
89
910from prompty .tracer import Tracer
1011
1112from .._version import VERSION
1213from ..common import convert_function_tools , convert_output_props
13- from ..core import AsyncPromptyStream , Prompty , PromptyStream
14+ from ..core import AsyncPromptyStream , InputProperty , Prompty , PromptyStream , ToolProperty
1415from ..invoker import Invoker , InvokerFactory
1516
1617
@@ -146,27 +147,48 @@ def _resolve_chat_args(self, data: typing.Any, ignore_thread_content=False) -> d
146147
147148 return args
148149
def _execute_chat_completion(self, client: AzureOpenAI, args: dict, trace) -> typing.Any:
    """Issue a chat completion request and trace the response metadata.

    Parameters
    ----------
    client : AzureOpenAI
        Client used to issue the request.
    args : dict
        Keyword arguments forwarded to ``chat.completions.create``.
    trace
        Callable invoked as ``trace(key, value)`` to record trace entries.

    Returns
    -------
    typing.Any
        The value returned by ``create`` when ``args["stream"]`` is truthy;
        otherwise a ``ChatCompletion`` validated from the raw response text.
    """
    if args.get("stream"):
        # Streaming requests are handed back untouched; no headers are traced.
        return client.chat.completions.create(**args)

    raw_response = client.chat.completions.with_raw_response.create(**args)
    completion = ChatCompletion.model_validate_json(raw_response.text)

    # Record every response header as its own trace entry.
    for header_name, header_value in raw_response.headers.raw:
        trace(header_name.decode("utf-8"), header_value.decode("utf-8"))

    trace("request_id", raw_response.request_id)
    trace("retries_taken", raw_response.retries_taken)

    return completion
165+
def _create_chat(self, client: AzureOpenAI, data: typing.Any, ignore_thread_content=False) -> typing.Any:
    """Resolve the chat arguments, run the completion and trace the call.

    Parameters
    ----------
    client : AzureOpenAI
        Client used to issue the request.
    data : typing.Any
        Input passed to ``_resolve_chat_args`` to build the request arguments.
    ignore_thread_content : bool
        Forwarded unchanged to ``_resolve_chat_args``.

    Returns
    -------
    typing.Any
        The value produced by ``_execute_chat_completion``.
    """
    static_entries = (
        ("type", "LLM"),
        ("description", "Azure OpenAI Client"),
        ("signature", "AzureOpenAI.chat.completions.create"),
    )
    with Tracer.start("create") as trace:
        for entry_key, entry_value in static_entries:
            trace(entry_key, entry_value)

        request_args = self._resolve_chat_args(data, ignore_thread_content)
        trace("inputs", request_args)

        completion = self._execute_chat_completion(client, request_args, trace)
        trace("result", completion)
        return completion
160176
161- response = ChatCompletion .model_validate_json (raw .text )
async def _execute_chat_completion_async(self, client: AsyncAzureOpenAI, args: dict, trace) -> typing.Any:
    """Issue a chat completion request asynchronously and trace its metadata.

    Parameters
    ----------
    client : AsyncAzureOpenAI
        Async client used to issue the request.
    args : dict
        Keyword arguments forwarded to ``chat.completions.create``.
    trace
        Callable invoked as ``trace(key, value)`` to record trace entries.

    Returns
    -------
    typing.Any
        The awaited result of ``create`` when ``args["stream"]`` is truthy;
        otherwise a ``ChatCompletion`` validated from the raw response text.
    """
    if args.get("stream"):
        # Streaming requests are handed back untouched; no headers are traced.
        return await client.chat.completions.create(**args)

    raw_response = await client.chat.completions.with_raw_response.create(**args)
    completion = ChatCompletion.model_validate_json(raw_response.text)

    # Record every response header as its own trace entry.
    for header_name, header_value in raw_response.headers.raw:
        trace(header_name.decode("utf-8"), header_value.decode("utf-8"))

    trace("request_id", raw_response.request_id)
    trace("retries_taken", raw_response.retries_taken)

    return completion
170192
171193 async def _create_chat_async (
172194 self , client : AsyncAzureOpenAI , data : typing .Any , ignore_thread_content = False
@@ -178,84 +200,130 @@ async def _create_chat_async(
178200 trace ("signature" , "AzureOpenAIAsync.chat.completions.create" )
179201 args = self ._resolve_chat_args (data , ignore_thread_content )
180202 trace ("inputs" , args )
203+ response = await self ._execute_chat_completion_async (client , args , trace )
204+ trace ("result" , response )
181205
182- response = None
183- if "stream" in args and args ["stream" ]:
184- response = await client .chat .completions .create (** args )
185- else :
186- raw : APIResponse = await client .chat .completions .with_raw_response .create (** args )
187- if raw is not None and raw .text is not None and isinstance (raw .text , str ):
188- response = ChatCompletion .model_validate_json (raw .text )
206+ return response
189207
190- for k , v in raw .headers .raw :
191- trace (k .decode ("utf-8" ), v .decode ("utf-8" ))
def _get_thread(self) -> InputProperty:
    """Return the prompty's ``thread`` input.

    Returns
    -------
    InputProperty
        The value returned by ``self.prompty.get_input("thread")``.

    Raises
    ------
    ValueError
        If the prompty has no ``thread`` input.
    """
    thread_input = self.prompty.get_input("thread")
    if thread_input is not None:
        return thread_input

    raise ValueError("thread requires thread input")
196214
197- return response
def _retrieve_tool(self, tool_name: str) -> ToolProperty:
    """Look up a tool by name and check that it can be executed locally.

    Parameters
    ----------
    tool_name : str
        Name passed to ``self.prompty.get_tool``.

    Returns
    -------
    ToolProperty
        The matching tool, whose ``type`` is ``"function"`` and whose
        ``value`` is not ``None``.

    Raises
    ------
    ValueError
        If the tool is unknown, is not a function tool, or has no value.
    """
    candidate = self.prompty.get_tool(tool_name)

    # The checks run in this order so the first failing condition is reported.
    if candidate is None:
        raise ValueError(f"Tool {tool_name} does not exist")
    if candidate.type != "function":
        raise ValueError(f"Server tool ({tool_name}) is currently not supported")
    if candidate.value is None:
        raise ValueError(f"Tool {tool_name} has not been initialized")

    return candidate
198227
def _execute_agent(self, client: AzureOpenAI, data: typing.Any) -> typing.Any:
    """Run a chat request and keep executing requested tools until none remain.

    Each round appends the assistant's tool calls and every tool result to
    the ``thread`` input, then re-issues the chat request with
    ``ignore_thread_content`` set to ``True``.

    Parameters
    ----------
    client : AzureOpenAI
        Client used for every chat request.
    data : typing.Any
        Input forwarded to ``_create_chat``.

    Returns
    -------
    typing.Any
        The last response returned by ``_create_chat``.

    Raises
    ------
    ValueError
        If the thread input is missing, a tool fails validation in
        ``_retrieve_tool``, or a tool is a coroutine function.
    """
    with Tracer.start("create") as trace:
        trace("type", "LLM")
        trace("description", "Azure OpenAI Client")
        trace("signature", "AzureOpenAI.chat.agent.create")

        trace("inputs", data)

        response = self._create_chat(client, data)

        # Keep looping while the model asks for tool calls.
        while isinstance(response, ChatCompletion):
            first_choice = response.choices[0]
            if first_choice.finish_reason != "tool_calls":
                break
            pending_calls = first_choice.message.tool_calls
            if not pending_calls:
                break

            thread = self._get_thread()
            assistant_entry = {
                "role": "assistant",
                "tool_calls": [call.model_dump() for call in pending_calls],
            }
            thread.value.append(assistant_entry)

            for call in pending_calls:
                tool = self._retrieve_tool(call.function.name)
                call_kwargs = json.loads(call.function.arguments)

                # A coroutine function cannot be awaited from this sync path.
                if inspect.iscoroutinefunction(tool.value):
                    raise ValueError("Cannot execute async tool in sync mode")

                tool_output = tool.value(**call_kwargs)

                tool_entry = {
                    "role": "tool",
                    "tool_call_id": call.id,
                    "name": call.function.name,
                    "content": tool_output,
                }
                thread.value.append(tool_entry)

            response = self._create_chat(client, data, True)

        trace("result", response)
        return response
250277
async def _execute_agent_async(self, client: AsyncAzureOpenAI, data: typing.Any) -> typing.Any:
    """Run a chat request asynchronously, executing requested tools until none remain.

    Each round appends the assistant's tool calls and every tool result to
    the ``thread`` input, then re-issues the chat request with
    ``ignore_thread_content`` set to ``True``.

    Parameters
    ----------
    client : AsyncAzureOpenAI
        Async client used for every chat request.
    data : typing.Any
        Input forwarded to ``_create_chat_async``.

    Returns
    -------
    typing.Any
        The last response returned by ``_create_chat_async``.

    Raises
    ------
    ValueError
        If the thread input is missing or a tool fails validation in
        ``_retrieve_tool``.
    """
    with Tracer.start("create") as trace:
        trace("type", "LLM")
        trace("description", "Azure OpenAI Client")
        trace("signature", "AzureOpenAIAsync.chat.agent.create")

        trace("inputs", data)

        response = await self._create_chat_async(client, data)

        # execute tool calls if any (until no more tool calls)
        while (
            isinstance(response, ChatCompletion)
            and response.choices[0].finish_reason == "tool_calls"
            and response.choices[0].message.tool_calls
        ):
            tool_calls = response.choices[0].message.tool_calls
            thread = self._get_thread()
            thread.value.append(
                {
                    "role": "assistant",
                    "tool_calls": [t.model_dump() for t in tool_calls],
                }
            )

            for tool_call in tool_calls:
                tool = self._retrieve_tool(tool_call.function.name)
                function_args = json.loads(tool_call.function.arguments)

                # Call first, then await if the result is awaitable. Checking
                # inspect.iscoroutinefunction(tool.value) up front misses async
                # tools wrapped in functools.partial or implemented as objects
                # with an async __call__, which would leave an un-awaited
                # coroutine as the tool message content.
                r = tool.value(**function_args)
                if inspect.isawaitable(r):
                    r = await r

                thread.value.append(
                    {
                        "role": "tool",
                        "tool_call_id": tool_call.id,
                        "name": tool_call.function.name,
                        "content": r,
                    }
                )

            response = await self._create_chat_async(client, data, True)

        trace("result", response)
        return response
261329
@@ -360,7 +428,7 @@ async def _create_image_async(self, client: AsyncAzureOpenAI, data: typing.Any)
360428
361429 return response
362430
363- def invoke (self , data : typing .Any ) -> typing .Union [ str , PromptyStream ] :
431+ def invoke (self , data : typing .Any ) -> typing .Any :
364432 """Invoke the Azure OpenAI API
365433
366434 Parameters
@@ -379,8 +447,8 @@ def invoke(self, data: typing.Any) -> typing.Union[str, PromptyStream]:
379447 r = None
380448 if self .api == "chat" :
381449 r = self ._create_chat (client , data )
382- # elif self.api == "agent":
383- # r = self._execute_agent(client, data)
450+ elif self .api == "agent" :
451+ r = self ._execute_agent (client , data )
384452 elif self .api == "completion" :
385453 r = self ._create_completion (client , data )
386454 elif self .api == "embedding" :
@@ -395,12 +463,10 @@ def invoke(self, data: typing.Any) -> typing.Union[str, PromptyStream]:
395463 return PromptyStream ("AzureOpenAIExecutor" , r )
396464 else :
397465 return PromptyStream ("AzureOpenAIExecutor" , r )
398- elif isinstance (r , str ):
399- return r
400466 else :
401- raise ValueError ( f"Unexpected response type: { type ( r ) } " )
467+ return r
402468
403- async def invoke_async (self , data : str ) -> typing .Union [ str , AsyncPromptyStream ] :
469+ async def invoke_async (self , data : str ) -> typing .Any :
404470 """Invoke the Prompty Chat Parser (Async)
405471
406472 Parameters
@@ -418,8 +484,8 @@ async def invoke_async(self, data: str) -> typing.Union[str, AsyncPromptyStream]
418484 r = None
419485 if self .api == "chat" :
420486 r = await self ._create_chat_async (client , data )
421- # elif self.api == "agent":
422- # r = await self._execute_agent_async(client, data)
487+ elif self .api == "agent" :
488+ r = await self ._execute_agent_async (client , data )
423489 elif self .api == "completion" :
424490 r = await self ._create_completion_async (client , data )
425491 elif self .api == "embedding" :
@@ -434,7 +500,5 @@ async def invoke_async(self, data: str) -> typing.Union[str, AsyncPromptyStream]
434500 return AsyncPromptyStream ("AzureOpenAIExecutorAsync" , r )
435501 else :
436502 return AsyncPromptyStream ("AzureOpenAIExecutorAsync" , r )
437- elif isinstance (r , str ):
438- return r
439503 else :
440- raise ValueError ( f"Unexpected response type: { type ( r ) } " )
504+ return r
0 commit comments