Skip to content

Commit 8decc7a

Browse files
committed
Let tool functions yield custom events
1 parent 3eaa11e commit 8decc7a

File tree

10 files changed

+189
-22
lines changed

10 files changed

+189
-22
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from dataclasses import field, replace
1313
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, cast
1414

15+
import anyio
1516
from opentelemetry.trace import Tracer
1617
from typing_extensions import TypeVar, assert_never
1718

@@ -643,32 +644,47 @@ async def _handle_tool_calls(
643644
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
644645
tool_calls: list[_messages.ToolCallPart],
645646
) -> AsyncIterator[_messages.HandleResponseEvent]:
647+
send_stream, receive_stream = anyio.create_memory_object_stream[_messages.HandleResponseEvent]()
648+
646649
run_context = build_run_context(ctx)
647650

648-
# This will raise errors for any tool name conflicts
649-
ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context)
651+
async def _run_tool_calls() -> tuple[list[_messages.ModelRequestPart], result.FinalResult[NodeRunEndT] | None]:
652+
async with send_stream:
653+
tool_run_context = replace(run_context, event_stream=send_stream)
650654

651-
output_parts: list[_messages.ModelRequestPart] = []
652-
output_final_result: deque[result.FinalResult[NodeRunEndT]] = deque(maxlen=1)
655+
# This will raise errors for any tool name conflicts
656+
ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(tool_run_context)
653657

654-
async for event in process_tool_calls(
655-
tool_manager=ctx.deps.tool_manager,
656-
tool_calls=tool_calls,
657-
tool_call_results=self.tool_call_results,
658-
final_result=None,
659-
ctx=ctx,
660-
output_parts=output_parts,
661-
output_final_result=output_final_result,
662-
):
663-
yield event
658+
output_parts: list[_messages.ModelRequestPart] = []
659+
output_final_result: deque[result.FinalResult[NodeRunEndT]] = deque(maxlen=1)
664660

665-
if output_final_result:
666-
final_result = output_final_result[0]
667-
self._next_node = self._handle_final_result(ctx, final_result, output_parts)
661+
async for event in process_tool_calls(
662+
tool_manager=ctx.deps.tool_manager,
663+
tool_calls=tool_calls,
664+
tool_call_results=self.tool_call_results,
665+
final_result=None,
666+
ctx=ctx,
667+
output_parts=output_parts,
668+
output_final_result=output_final_result,
669+
):
670+
await send_stream.send(event)
671+
672+
return output_parts, output_final_result[0] if output_final_result else None
673+
674+
task = asyncio.create_task(_run_tool_calls())
675+
676+
async with receive_stream:
677+
async for message in receive_stream:
678+
yield message
679+
680+
parts, final_result = await task
681+
682+
if final_result:
683+
self._next_node = self._handle_final_result(ctx, final_result, parts)
668684
else:
669685
instructions = await ctx.deps.get_instructions(run_context)
670686
self._next_node = ModelRequestNode[DepsT, NodeRunEndT](
671-
_messages.ModelRequest(parts=output_parts, instructions=instructions)
687+
_messages.ModelRequest(parts=parts, instructions=instructions)
672688
)
673689

674690
async def _handle_text_response(

pydantic_ai_slim/pydantic_ai/_function_schema.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,17 @@
1919
from pydantic_core import SchemaValidator, core_schema
2020
from typing_extensions import ParamSpec, TypeIs, TypeVar
2121

22+
from pydantic_ai.messages import CustomEvent, ToolReturn
23+
2224
from ._griffe import doc_descriptions
2325
from ._run_context import RunContext
24-
from ._utils import check_object_json_schema, is_async_callable, is_model_like, run_in_executor
26+
from ._utils import (
27+
check_object_json_schema,
28+
is_async_callable,
29+
is_async_iterator_callable,
30+
is_model_like,
31+
run_in_executor,
32+
)
2533

2634
if TYPE_CHECKING:
2735
from .tools import DocstringFormat, ObjectJsonSchema
@@ -41,13 +49,31 @@ class FunctionSchema:
4149
# if not None, the function takes a single by that name (besides potentially `info`)
4250
takes_ctx: bool
4351
is_async: bool
52+
is_async_iterator: bool
4453
single_arg_name: str | None = None
4554
positional_fields: list[str] = field(default_factory=list)
4655
var_positional_field: str | None = None
4756

4857
async def call(self, args_dict: dict[str, Any], ctx: RunContext[Any]) -> Any:
4958
args, kwargs = self._call_args(args_dict, ctx)
50-
if self.is_async:
59+
if self.is_async_iterator:
60+
assert ctx.event_stream is not None, (
61+
'RunContext.event_stream needs to be set to use FunctionSchema.call with async iterators'
62+
)
63+
64+
async for event_payload in self.function(*args, **kwargs):
65+
if isinstance(event_payload, ToolReturn):
66+
return event_payload
67+
68+
event = (
69+
cast(CustomEvent, event_payload)
70+
if isinstance(event_payload, CustomEvent)
71+
else CustomEvent(payload=event_payload)
72+
)
73+
await ctx.event_stream.send(event)
74+
# TODO (DouweM): Raise if events are yielded after ToolReturn
75+
return None
76+
elif self.is_async:
5177
function = cast(Callable[[Any], Awaitable[str]], self.function)
5278
return await function(*args, **kwargs)
5379
else:
@@ -221,6 +247,7 @@ def function_schema( # noqa: C901
221247
var_positional_field=var_positional_field,
222248
takes_ctx=takes_ctx,
223249
is_async=is_async_callable(function),
250+
is_async_iterator=is_async_iterator_callable(function),
224251
function=function,
225252
)
226253

pydantic_ai_slim/pydantic_ai/_run_context.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from dataclasses import field
66
from typing import TYPE_CHECKING, Generic
77

8+
from anyio.streams.memory import MemoryObjectSendStream
89
from opentelemetry.trace import NoOpTracer, Tracer
910
from typing_extensions import TypeVar
1011

@@ -36,6 +37,8 @@ class RunContext(Generic[AgentDepsT]):
3637
"""Messages exchanged in the conversation so far."""
3738
tracer: Tracer = field(default_factory=NoOpTracer)
3839
"""The tracer to use for tracing the run."""
40+
event_stream: MemoryObjectSendStream[_messages.CustomEvent] | None = None
41+
"""The event stream to use for handling custom events."""
3942
trace_include_content: bool = False
4043
"""Whether to include the content of the messages in the trace."""
4144
instrumentation_version: int = DEFAULT_INSTRUMENTATION_VERSION

pydantic_ai_slim/pydantic_ai/_tool_manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ async def for_run_step(self, ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDe
5050
"""Build a new tool manager for the next run step, carrying over the retries from the current run step."""
5151
if self.ctx is not None:
5252
if ctx.run_step == self.ctx.run_step:
53+
# TODO (DouweM): Refactor to make sure it's always set
54+
55+
if ctx.event_stream and not self.ctx.event_stream:
56+
self.ctx.event_stream = ctx.event_stream
5357
return self
5458

5559
retries = {

pydantic_ai_slim/pydantic_ai/_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,11 @@ def is_async_callable(obj: Any) -> Any:
369369
return inspect.iscoroutinefunction(obj) or (callable(obj) and inspect.iscoroutinefunction(obj.__call__)) # type: ignore
370370

371371

372+
def is_async_iterator_callable(obj: Any) -> bool:
373+
"""Check if a callable is an async iterator."""
374+
return inspect.isasyncgenfunction(obj) or (callable(obj) and inspect.isasyncgenfunction(obj.__call__))
375+
376+
372377
def _update_mapped_json_schema_refs(s: dict[str, Any], name_mapping: dict[str, str]) -> None:
373378
"""Update $refs in a schema to use the new names from name_mapping."""
374379
if '$ref' in s:

pydantic_ai_slim/pydantic_ai/ag_ui.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
BaseToolCallPart,
3333
BuiltinToolCallPart,
3434
BuiltinToolReturnPart,
35+
CustomEvent,
3536
FunctionToolResultEvent,
3637
ModelMessage,
3738
ModelRequest,
@@ -431,6 +432,8 @@ async def _agent_stream(run: AgentRun[AgentDepsT, Any]) -> AsyncIterator[BaseEve
431432
if isinstance(event, FunctionToolResultEvent):
432433
async for msg in _handle_tool_result_event(stream_ctx, event):
433434
yield msg
435+
elif isinstance(event, CustomEvent) and isinstance(event.payload, BaseEvent):
436+
yield event.payload
434437

435438

436439
async def _handle_model_request_event( # noqa: C901
@@ -582,6 +585,8 @@ async def _handle_tool_result_event(
582585
content=result.model_response_str(),
583586
)
584587

588+
# TODO (DouweM): Stream `event.content` as if they were user parts?
589+
585590
# Now check for AG-UI events returned by the tool calls.
586591
possible_event = result.metadata or result.content
587592
if isinstance(possible_event, BaseEvent):

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_function_toolset.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def __init__(
7171

7272
async def call_tool_activity(params: _CallToolParams, deps: AgentDepsT) -> _CallToolResult:
7373
name = params.name
74+
# TODO (DouweM): RunContext.event_stream -> call event_stream_handler directly?
7475
ctx = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
7576
try:
7677
tool = (await toolset.get_tools(ctx))[name]

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77
from dataclasses import KW_ONLY, dataclass, field, replace
88
from datetime import datetime
99
from mimetypes import guess_type
10-
from typing import TYPE_CHECKING, Annotated, Any, Literal, TypeAlias, cast, overload
10+
from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal, TypeAlias, cast, overload
1111

1212
import pydantic
1313
import pydantic_core
1414
from genai_prices import calc_price, types as genai_types
1515
from opentelemetry._events import Event # pyright: ignore[reportPrivateImportUsage]
16-
from typing_extensions import Self, deprecated
16+
from typing_extensions import Self, TypeVar, deprecated
1717

1818
from . import _otel_messages, _utils
1919
from ._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc
@@ -23,6 +23,8 @@
2323
if TYPE_CHECKING:
2424
from .models.instrumented import InstrumentationSettings
2525

26+
EventPayloadT = TypeVar('EventPayloadT', default=Any)
27+
2628

2729
AudioMediaType: TypeAlias = Literal['audio/wav', 'audio/mpeg', 'audio/ogg', 'audio/flac', 'audio/aiff', 'audio/aac']
2830
ImageMediaType: TypeAlias = Literal['image/jpeg', 'image/png', 'image/gif', 'image/webp']
@@ -1724,9 +1726,31 @@ class BuiltinToolResultEvent:
17241726
"""Event type identifier, used as a discriminator."""
17251727

17261728

1729+
@dataclass(repr=False)
1730+
class CustomEvent(Generic[EventPayloadT]):
1731+
"""An event indicating the result of a function tool call."""
1732+
1733+
payload: EventPayloadT
1734+
"""The payload of the custom event."""
1735+
1736+
_: KW_ONLY
1737+
1738+
name: str | None = None
1739+
"""The optional name of the custom event."""
1740+
1741+
id: str | None = None
1742+
"""The optional ID of the custom event."""
1743+
1744+
event_kind: Literal['custom'] = 'custom'
1745+
"""Event type identifier, used as a discriminator."""
1746+
1747+
__repr__ = _utils.dataclasses_no_defaults_repr
1748+
1749+
17271750
HandleResponseEvent = Annotated[
17281751
FunctionToolCallEvent
17291752
| FunctionToolResultEvent
1753+
| CustomEvent
17301754
| BuiltinToolCallEvent # pyright: ignore[reportDeprecated]
17311755
| BuiltinToolResultEvent, # pyright: ignore[reportDeprecated]
17321756
pydantic.Discriminator('event_kind'),

pydantic_ai_slim/pydantic_ai/tools.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,7 @@ def from_schema(
391391
json_schema=json_schema,
392392
takes_ctx=takes_ctx,
393393
is_async=_utils.is_async_callable(function),
394+
is_async_iterator=_utils.is_async_iterator_callable(function),
394395
)
395396

396397
return cls(

tests/test_ag_ui.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,20 @@ async def send_custom() -> ToolReturn:
192192
)
193193

194194

195+
async def yield_custom() -> AsyncIterator[CustomEvent | ToolReturn]:
196+
yield CustomEvent(
197+
type=EventType.CUSTOM,
198+
name='custom_event1',
199+
value={'key1': 'value1'},
200+
)
201+
yield CustomEvent(
202+
type=EventType.CUSTOM,
203+
name='custom_event2',
204+
value={'key2': 'value2'},
205+
)
206+
yield ToolReturn('Done')
207+
208+
195209
def uuid_str() -> str:
196210
"""Generate a random UUID string."""
197211
return uuid.uuid4().hex
@@ -815,6 +829,73 @@ async def stream_function(
815829
)
816830

817831

832+
async def test_tool_local_yield_events() -> None:
833+
"""Test local tool call that yields multiple events."""
834+
835+
async def stream_function(
836+
messages: list[ModelMessage], agent_info: AgentInfo
837+
) -> AsyncIterator[DeltaToolCalls | str]:
838+
if len(messages) == 1:
839+
# First call - make a tool call
840+
yield {0: DeltaToolCall(name='yield_custom')}
841+
yield {0: DeltaToolCall(json_args='{}')}
842+
else:
843+
# Second call - return text result
844+
yield 'success yield_custom called'
845+
846+
agent = Agent(
847+
model=FunctionModel(stream_function=stream_function),
848+
tools=[yield_custom],
849+
)
850+
851+
run_input = create_input(
852+
UserMessage(
853+
id='msg_1',
854+
content='Please call yield_custom',
855+
),
856+
)
857+
events = await run_and_collect_events(agent, run_input)
858+
859+
assert events == snapshot(
860+
[
861+
{
862+
'type': 'RUN_STARTED',
863+
'threadId': (thread_id := IsSameStr()),
864+
'runId': (run_id := IsSameStr()),
865+
},
866+
{
867+
'type': 'TOOL_CALL_START',
868+
'toolCallId': (tool_call_id := IsSameStr()),
869+
'toolCallName': 'yield_custom',
870+
'parentMessageId': IsStr(),
871+
},
872+
{'type': 'TOOL_CALL_ARGS', 'toolCallId': tool_call_id, 'delta': '{}'},
873+
{'type': 'TOOL_CALL_END', 'toolCallId': tool_call_id},
874+
{'type': 'CUSTOM', 'name': 'custom_event1', 'value': {'key1': 'value1'}},
875+
{'type': 'CUSTOM', 'name': 'custom_event2', 'value': {'key2': 'value2'}},
876+
{
877+
'type': 'TOOL_CALL_RESULT',
878+
'messageId': '8dd33273-d2f5-4e02-8483-8311f0a1cafe',
879+
'toolCallId': 'pyd_ai_219dd870d6e94958a1fca56df511fba4',
880+
'content': 'Done',
881+
'role': 'tool',
882+
},
883+
{'type': 'TEXT_MESSAGE_START', 'messageId': (message_id := IsSameStr()), 'role': 'assistant'},
884+
{
885+
'type': 'TEXT_MESSAGE_CONTENT',
886+
'messageId': message_id,
887+
'delta': 'success yield_custom called',
888+
},
889+
{'type': 'TEXT_MESSAGE_END', 'messageId': message_id},
890+
{
891+
'type': 'RUN_FINISHED',
892+
'threadId': thread_id,
893+
'runId': run_id,
894+
},
895+
]
896+
)
897+
898+
818899
async def test_tool_local_parts() -> None:
819900
"""Test local tool call with streaming/parts."""
820901

0 commit comments

Comments
 (0)