Skip to content

Commit 3523bbc

Browse files
committed
finalized structured outputs and agentic execution
1 parent 2f37eae commit 3523bbc

12 files changed

+745
-94
lines changed

runtime/prompty/prompty/azure/executor.py

Lines changed: 140 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
1+
import inspect
12
import json
23
import typing
34
from collections.abc import AsyncIterator, Iterator
45

56
import azure.identity
6-
from openai import APIResponse, AsyncAzureOpenAI, AzureOpenAI
7+
from openai import AsyncAzureOpenAI, AzureOpenAI
78
from openai.types.chat.chat_completion import ChatCompletion
89

910
from prompty.tracer import Tracer
1011

1112
from .._version import VERSION
1213
from ..common import convert_function_tools, convert_output_props
13-
from ..core import AsyncPromptyStream, Prompty, PromptyStream
14+
from ..core import AsyncPromptyStream, InputProperty, Prompty, PromptyStream, ToolProperty
1415
from ..invoker import Invoker, InvokerFactory
1516

1617

@@ -146,27 +147,48 @@ def _resolve_chat_args(self, data: typing.Any, ignore_thread_content=False) -> d
146147

147148
return args
148149

150+
def _execute_chat_completion(self, client: AzureOpenAI, args: dict, trace) -> typing.Any:
151+
if "stream" in args and args["stream"]:
152+
response = client.chat.completions.create(**args)
153+
else:
154+
raw = client.chat.completions.with_raw_response.create(**args)
155+
156+
response = ChatCompletion.model_validate_json(raw.text)
157+
158+
for k, v in raw.headers.raw:
159+
trace(k.decode("utf-8"), v.decode("utf-8"))
160+
161+
trace("request_id", raw.request_id)
162+
trace("retries_taken", raw.retries_taken)
163+
164+
return response
165+
149166
def _create_chat(self, client: AzureOpenAI, data: typing.Any, ignore_thread_content=False) -> typing.Any:
150167
with Tracer.start("create") as trace:
151168
trace("type", "LLM")
152169
trace("description", "Azure OpenAI Client")
153170
trace("signature", "AzureOpenAI.chat.completions.create")
154171
args = self._resolve_chat_args(data, ignore_thread_content)
155172
trace("inputs", args)
156-
if "stream" in args and args["stream"]:
157-
response = client.chat.completions.create(**args)
158-
else:
159-
raw = client.chat.completions.with_raw_response.create(**args)
173+
response = self._execute_chat_completion(client, args, trace)
174+
trace("result", response)
175+
return response
160176

161-
response = ChatCompletion.model_validate_json(raw.text)
177+
async def _execute_chat_completion_async(self, client: AsyncAzureOpenAI, args: dict, trace) -> typing.Any:
178+
if "stream" in args and args["stream"]:
179+
response = await client.chat.completions.create(**args)
180+
else:
181+
raw = await client.chat.completions.with_raw_response.create(**args)
162182

163-
for k, v in raw.headers.raw:
164-
trace(k.decode("utf-8"), v.decode("utf-8"))
183+
response = ChatCompletion.model_validate_json(raw.text)
165184

166-
trace("request_id", raw.request_id)
167-
trace("retries_taken", raw.retries_taken)
168-
trace("result", response)
169-
return response
185+
for k, v in raw.headers.raw:
186+
trace(k.decode("utf-8"), v.decode("utf-8"))
187+
188+
trace("request_id", raw.request_id)
189+
trace("retries_taken", raw.retries_taken)
190+
191+
return response
170192

171193
async def _create_chat_async(
172194
self, client: AsyncAzureOpenAI, data: typing.Any, ignore_thread_content=False
@@ -178,84 +200,130 @@ async def _create_chat_async(
178200
trace("signature", "AzureOpenAIAsync.chat.completions.create")
179201
args = self._resolve_chat_args(data, ignore_thread_content)
180202
trace("inputs", args)
203+
response = await self._execute_chat_completion_async(client, args, trace)
204+
trace("result", response)
181205

182-
response = None
183-
if "stream" in args and args["stream"]:
184-
response = await client.chat.completions.create(**args)
185-
else:
186-
raw: APIResponse = await client.chat.completions.with_raw_response.create(**args)
187-
if raw is not None and raw.text is not None and isinstance(raw.text, str):
188-
response = ChatCompletion.model_validate_json(raw.text)
206+
return response
189207

190-
for k, v in raw.headers.raw:
191-
trace(k.decode("utf-8"), v.decode("utf-8"))
208+
def _get_thread(self) -> InputProperty:
209+
thread = self.prompty.get_input("thread")
210+
if thread is None:
211+
raise ValueError("thread requires thread input")
192212

193-
trace("request_id", raw.request_id)
194-
trace("retries_taken", raw.retries_taken)
195-
trace("result", response)
213+
return thread
196214

197-
return response
215+
def _retrieve_tool(self, tool_name: str) -> ToolProperty:
216+
tool = self.prompty.get_tool(tool_name)
217+
if tool is None:
218+
raise ValueError(f"Tool {tool_name} does not exist")
219+
220+
if tool.type != "function":
221+
raise ValueError(f"Server tool ({tool_name}) is currently not supported")
222+
223+
if tool.value is None:
224+
raise ValueError(f"Tool {tool_name} has not been initialized")
225+
226+
return tool
198227

199228
def _execute_agent(self, client: AzureOpenAI, data: typing.Any) -> typing.Any:
200229
with Tracer.start("create") as trace:
201230
trace("type", "LLM")
202231
trace("description", "Azure OpenAI Client")
203-
204232
trace("signature", "AzureOpenAI.chat.agent.create")
233+
205234
trace("inputs", data)
206235

207236
response = self._create_chat(client, data)
208-
if isinstance(response, ChatCompletion):
209-
message = response.choices[0].message
210-
if message.tool_calls:
211-
thread = self.prompty.get_input("thread")
212-
if thread is None:
213-
raise ValueError("thread requires thread input")
214237

215-
thread.value.append(
216-
{
217-
"role": "assistant",
218-
"tool_calls": [t.model_dump() for t in message.tool_calls],
219-
}
220-
)
238+
# execute tool calls if any (until no more tool calls)
239+
while (
240+
isinstance(response, ChatCompletion)
241+
and response.choices[0].finish_reason == "tool_calls"
242+
and response.choices[0].message.tool_calls is not None
243+
and len(response.choices[0].message.tool_calls) > 0
244+
):
245+
246+
tool_calls = response.choices[0].message.tool_calls
247+
thread = self._get_thread()
248+
thread.value.append(
249+
{
250+
"role": "assistant",
251+
"tool_calls": [t.model_dump() for t in tool_calls],
252+
}
253+
)
221254

222-
for tool_call in message.tool_calls:
223-
tool = self.prompty.get_tool(tool_call.function.name)
224-
if tool is None:
225-
raise ValueError(f"Tool {tool_call.function.name} does not exist")
255+
for tool_call in tool_calls:
256+
tool = self._retrieve_tool(tool_call.function.name)
257+
function_args = json.loads(tool_call.function.arguments)
226258

227-
function_args = json.loads(tool_call.function.arguments)
259+
if inspect.iscoroutinefunction(tool.value):
260+
raise ValueError("Cannot execute async tool in sync mode")
228261

229-
if tool.value is None:
230-
raise ValueError(f"Tool {tool_call.function.name} does not have a value")
262+
r = tool.value(**function_args)
231263

232-
r = tool.value(**function_args)
264+
thread.value.append(
265+
{
266+
"role": "tool",
267+
"tool_call_id": tool_call.id,
268+
"name": tool_call.function.name,
269+
"content": r,
270+
}
271+
)
233272

234-
thread.value.append(
235-
{
236-
"role": "tool",
237-
"tool_call_id": tool_call.id,
238-
"name": tool_call.function.name,
239-
"content": r,
240-
}
241-
)
242-
else:
243-
trace("result", response)
244-
return response
273+
response = self._create_chat(client, data, True)
245274

246-
response = self._create_chat(client, data, True)
247275
trace("result", response)
248-
249276
return response
250277

251278
async def _execute_agent_async(self, client: AsyncAzureOpenAI, data: typing.Any) -> typing.Any:
252279
with Tracer.start("create") as trace:
253280
trace("type", "LLM")
254281
trace("description", "Azure OpenAI Client")
255-
trace("signature", "AzureOpenAI.chat.agent.create")
256-
args = self._resolve_chat_args(data)
257-
trace("inputs", args)
258-
response = 5
282+
trace("signature", "AzureOpenAIAsync.chat.agent.create")
283+
284+
trace("inputs", data)
285+
286+
response = await self._create_chat_async(client, data)
287+
288+
# execute tool calls if any (until no more tool calls)
289+
while (
290+
isinstance(response, ChatCompletion)
291+
and response.choices[0].finish_reason == "tool_calls"
292+
and response.choices[0].message.tool_calls is not None
293+
and len(response.choices[0].message.tool_calls) > 0
294+
):
295+
296+
tool_calls = response.choices[0].message.tool_calls
297+
thread = self._get_thread()
298+
thread.value.append(
299+
{
300+
"role": "assistant",
301+
"tool_calls": [t.model_dump() for t in tool_calls],
302+
}
303+
)
304+
305+
for tool_call in tool_calls:
306+
tool = self._retrieve_tool(tool_call.function.name)
307+
function_args = json.loads(tool_call.function.arguments)
308+
309+
if inspect.iscoroutinefunction(tool.value):
310+
# if the tool is async, we need to await it
311+
r = await tool.value(**function_args)
312+
else:
313+
# if the tool is not async, we can call it directly
314+
r = tool.value(**function_args)
315+
316+
thread.value.append(
317+
{
318+
"role": "tool",
319+
"tool_call_id": tool_call.id,
320+
"name": tool_call.function.name,
321+
"content": r,
322+
}
323+
)
324+
325+
response = await self._create_chat_async(client, data, True)
326+
259327
trace("result", response)
260328
return response
261329

@@ -360,7 +428,7 @@ async def _create_image_async(self, client: AsyncAzureOpenAI, data: typing.Any)
360428

361429
return response
362430

363-
def invoke(self, data: typing.Any) -> typing.Union[str, PromptyStream]:
431+
def invoke(self, data: typing.Any) -> typing.Any:
364432
"""Invoke the Azure OpenAI API
365433
366434
Parameters
@@ -379,8 +447,8 @@ def invoke(self, data: typing.Any) -> typing.Union[str, PromptyStream]:
379447
r = None
380448
if self.api == "chat":
381449
r = self._create_chat(client, data)
382-
#elif self.api == "agent":
383-
# r = self._execute_agent(client, data)
450+
elif self.api == "agent":
451+
r = self._execute_agent(client, data)
384452
elif self.api == "completion":
385453
r = self._create_completion(client, data)
386454
elif self.api == "embedding":
@@ -395,12 +463,10 @@ def invoke(self, data: typing.Any) -> typing.Union[str, PromptyStream]:
395463
return PromptyStream("AzureOpenAIExecutor", r)
396464
else:
397465
return PromptyStream("AzureOpenAIExecutor", r)
398-
elif isinstance(r, str):
399-
return r
400466
else:
401-
raise ValueError(f"Unexpected response type: {type(r)}")
467+
return r
402468

403-
async def invoke_async(self, data: str) -> typing.Union[str, AsyncPromptyStream]:
469+
async def invoke_async(self, data: str) -> typing.Any:
404470
"""Invoke the Prompty Chat Parser (Async)
405471
406472
Parameters
@@ -418,8 +484,8 @@ async def invoke_async(self, data: str) -> typing.Union[str, AsyncPromptyStream]
418484
r = None
419485
if self.api == "chat":
420486
r = await self._create_chat_async(client, data)
421-
#elif self.api == "agent":
422-
# r = await self._execute_agent_async(client, data)
487+
elif self.api == "agent":
488+
r = await self._execute_agent_async(client, data)
423489
elif self.api == "completion":
424490
r = await self._create_completion_async(client, data)
425491
elif self.api == "embedding":
@@ -434,7 +500,5 @@ async def invoke_async(self, data: str) -> typing.Union[str, AsyncPromptyStream]
434500
return AsyncPromptyStream("AzureOpenAIExecutorAsync", r)
435501
else:
436502
return AsyncPromptyStream("AzureOpenAIExecutorAsync", r)
437-
elif isinstance(r, str):
438-
return r
439503
else:
440-
raise ValueError(f"Unexpected response type: {type(r)}")
504+
return r

runtime/prompty/prompty/invoker.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ def process(self, data: typing.Any) -> typing.Any:
270270
"""
271271
pass
272272

273+
@trace
273274
def run(self, data: typing.Any) -> typing.Any:
274275
"""Method to run the invoker
275276
@@ -287,6 +288,7 @@ def run(self, data: typing.Any) -> typing.Any:
287288
parsed = self.invoke(data)
288289
return self.process(parsed)
289290

291+
@trace
290292
async def run_async(self, data: typing.Any) -> typing.Any:
291293
"""Method to run the invoker asynchronously
292294
@@ -304,6 +306,10 @@ async def run_async(self, data: typing.Any) -> typing.Any:
304306
return self.process(parsed)
305307

306308

309+
310+
InvokerTypes = Literal["renderer", "parser", "executor", "processor"]
311+
312+
307313
class InvokerFactory:
308314
"""Factory class for Invoker"""
309315

@@ -328,6 +334,7 @@ def add_executor(cls, name: str, invoker: type[Invoker]) -> None:
328334
def add_processor(cls, name: str, invoker: type[Invoker]) -> None:
329335
cls._processors[name] = invoker
330336

337+
331338
@classmethod
332339
def register_renderer(cls, name: str) -> Callable:
333340

@@ -364,10 +371,11 @@ def inner_wrapper(wrapped_class: type[Invoker]) -> type[Invoker]:
364371

365372
return inner_wrapper
366373

374+
367375
@classmethod
368376
def _get_name(
369377
cls,
370-
type: Literal["renderer", "parser", "executor", "processor"],
378+
type: InvokerTypes,
371379
prompty: Prompty,
372380
) -> str:
373381
if type == "renderer":
@@ -384,7 +392,7 @@ def _get_name(
384392
@classmethod
385393
def _get_invoker(
386394
cls,
387-
type: Literal["renderer", "parser", "executor", "processor"],
395+
type: InvokerTypes,
388396
prompty: Prompty,
389397
) -> Invoker:
390398
if type == "renderer":
@@ -421,7 +429,7 @@ def _get_invoker(
421429
@classmethod
422430
def run(
423431
cls,
424-
type: Literal["renderer", "parser", "executor", "processor"],
432+
type: InvokerTypes,
425433
prompty: Prompty,
426434
data: typing.Any,
427435
default: typing.Any = None,
@@ -439,7 +447,7 @@ def run(
439447
@classmethod
440448
async def run_async(
441449
cls,
442-
type: Literal["renderer", "parser", "executor", "processor"],
450+
type: InvokerTypes,
443451
prompty: Prompty,
444452
data: typing.Any,
445453
default: typing.Any = None,

0 commit comments

Comments
 (0)