
Commit 61df410

Support streaming events from output function (e.g. agent handoff)
1 parent 69be302 commit 61df410

5 files changed: +184 -154 lines changed

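In practical terms, this commit moves the anyio memory-object-stream plumbing that previously lived inside `_call_tools` up into the node-level `_run_stream`, and attaches the send side to the tool manager (`tool_manager.ctx.event_stream = send_stream`) before any response handling runs. Events emitted while an output function executes — for example, a sub-agent run during a handoff — can therefore surface on the parent run's event stream. A rough sketch of the behavior this enables, assuming the public `Agent.iter` / `node.stream` API; the `handoff` output function, model names, and prompt below are illustrative, not taken from this commit:

import asyncio

from pydantic_ai import Agent, RunContext

support_agent = Agent('openai:gpt-4o-mini')  # illustrative model choices


async def handoff(ctx: RunContext[None], query: str) -> str:
    """Output function that hands the conversation off to another agent."""
    # Events produced while this output function runs (including the
    # sub-agent's activity, if any is emitted) can now be forwarded to
    # the parent run's event stream instead of being dropped.
    result = await support_agent.run(query)
    return result.output


triage_agent = Agent('openai:gpt-4o', output_type=handoff)


async def main() -> None:
    async with triage_agent.iter('Please reset my password') as run:
        async for node in run:
            if Agent.is_call_tools_node(node):
                # Stream response-handling events, now including any events
                # emitted during the handoff() call above.
                async with node.stream(run.ctx) as event_stream:
                    async for event in event_stream:
                        print(event)


if __name__ == '__main__':
    asyncio.run(main())  # requires provider credentials to actually run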
pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 151 additions & 144 deletions
@@ -562,116 +562,141 @@ async def stream(
     async def _run_stream(  # noqa: C901
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
     ) -> AsyncIterator[_messages.HandleResponseEvent]:
+        # Ensure that the stream is only run once
         if self._events_iterator is None:
-            # Ensure that the stream is only run once
+            run_context = build_run_context(ctx)
 
-            output_schema = ctx.deps.output_schema
+            # This will raise errors for any tool name conflicts
+            ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context)
+            tool_manager = ctx.deps.tool_manager
 
             async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:  # noqa: C901
-                if not self.model_response.parts:
-                    # we got an empty response.
-                    # this sometimes happens with anthropic (and perhaps other models)
-                    # when the model has already returned text along side tool calls
-                    if text_processor := output_schema.text_processor:  # pragma: no branch
-                        # in this scenario, if text responses are allowed, we return text from the most recent model
-                        # response, if any
-                        for message in reversed(ctx.state.message_history):
-                            if isinstance(message, _messages.ModelResponse):
-                                text = ''
-                                for part in message.parts:
-                                    if isinstance(part, _messages.TextPart):
-                                        text += part.content
-                                    elif isinstance(part, _messages.BuiltinToolCallPart):
-                                        # Text parts before a built-in tool call are essentially thoughts,
-                                        # not part of the final result output, so we reset the accumulated text
-                                        text = ''  # pragma: no cover
-                                if text:
-                                    try:
-                                        self._next_node = await self._handle_text_response(ctx, text, text_processor)
-                                        return
-                                    except ToolRetryError:
-                                        # If the text from the preview response was invalid, ignore it.
-                                        pass
-
-                    # Go back to the model request node with an empty request, which means we'll essentially
-                    # resubmit the most recent request that resulted in an empty response,
-                    # as the empty response and request will not create any items in the API payload,
-                    # in the hope the model will return a non-empty response this time.
-                    ctx.state.increment_retries(ctx.deps.max_result_retries, model_settings=ctx.deps.model_settings)
-                    run_context = build_run_context(ctx)
-                    instructions = await ctx.deps.get_instructions(run_context)
-                    self._next_node = ModelRequestNode[DepsT, NodeRunEndT](
-                        _messages.ModelRequest(parts=[], instructions=instructions)
-                    )
-                    return
-
-                text = ''
-                tool_calls: list[_messages.ToolCallPart] = []
-                files: list[_messages.BinaryContent] = []
-
-                for part in self.model_response.parts:
-                    if isinstance(part, _messages.TextPart):
-                        text += part.content
-                    elif isinstance(part, _messages.ToolCallPart):
-                        tool_calls.append(part)
-                    elif isinstance(part, _messages.FilePart):
-                        files.append(part.content)
-                    elif isinstance(part, _messages.BuiltinToolCallPart):
-                        # Text parts before a built-in tool call are essentially thoughts,
-                        # not part of the final result output, so we reset the accumulated text
-                        text = ''
-                        yield _messages.BuiltinToolCallEvent(part)  # pyright: ignore[reportDeprecated]
-                    elif isinstance(part, _messages.BuiltinToolReturnPart):
-                        yield _messages.BuiltinToolResultEvent(part)  # pyright: ignore[reportDeprecated]
-                    elif isinstance(part, _messages.ThinkingPart):
-                        pass
-                    else:
-                        assert_never(part)
-
-                try:
-                    # At the moment, we prioritize at least executing tool calls if they are present.
-                    # In the future, we'd consider making this configurable at the agent or run level.
-                    # This accounts for cases like anthropic returns that might contain a text response
-                    # and a tool call response, where the text response just indicates the tool call will happen.
-                    alternatives: list[str] = []
-                    if tool_calls:
-                        async for event in self._handle_tool_calls(ctx, tool_calls):
-                            yield event
-                        return
-                    elif output_schema.toolset:
-                        alternatives.append('include your response in a tool call')
-                    else:
-                        alternatives.append('call a tool')
-
-                    if output_schema.allows_image:
-                        if image := next((file for file in files if isinstance(file, _messages.BinaryImage)), None):
-                            self._next_node = await self._handle_image_response(ctx, image)
+                send_stream, receive_stream = anyio.create_memory_object_stream[_messages.HandleResponseEvent]()
+
+                async def _run():  # noqa: C901
+                    async with send_stream:
+                        assert tool_manager.ctx is not None, 'ToolManager.ctx needs to be set'
+                        tool_manager.ctx.event_stream = send_stream
+
+                        output_schema = ctx.deps.output_schema
+                        if not self.model_response.parts:
+                            # we got an empty response.
+                            # this sometimes happens with anthropic (and perhaps other models)
+                            # when the model has already returned text along side tool calls
+                            if text_processor := output_schema.text_processor:  # pragma: no branch
+                                # in this scenario, if text responses are allowed, we return text from the most recent model
+                                # response, if any
+                                for message in reversed(ctx.state.message_history):
+                                    if isinstance(message, _messages.ModelResponse):
+                                        text = ''
+                                        for part in message.parts:
+                                            if isinstance(part, _messages.TextPart):
+                                                text += part.content
+                                            elif isinstance(part, _messages.BuiltinToolCallPart):
+                                                # Text parts before a built-in tool call are essentially thoughts,
+                                                # not part of the final result output, so we reset the accumulated text
+                                                text = ''  # pragma: no cover
+                                        if text:
+                                            try:
+                                                self._next_node = await self._handle_text_response(
+                                                    ctx, text, text_processor
+                                                )
+                                                return
+                                            except ToolRetryError:
+                                                # If the text from the preview response was invalid, ignore it.
+                                                pass
+
+                            # Go back to the model request node with an empty request, which means we'll essentially
+                            # resubmit the most recent request that resulted in an empty response,
+                            # as the empty response and request will not create any items in the API payload,
+                            # in the hope the model will return a non-empty response this time.
+                            ctx.state.increment_retries(
+                                ctx.deps.max_result_retries, model_settings=ctx.deps.model_settings
+                            )
+                            run_context = build_run_context(ctx)
+                            instructions = await ctx.deps.get_instructions(run_context)
+                            self._next_node = ModelRequestNode[DepsT, NodeRunEndT](
+                                _messages.ModelRequest(parts=[], instructions=instructions)
+                            )
                             return
-                        alternatives.append('return an image')
 
-                    if text_processor := output_schema.text_processor:
-                        if text:
-                            # TODO (DouweM): This could call an output function that yields custom events, but we're not in an event stream here?
-                            self._next_node = await self._handle_text_response(ctx, text, text_processor)
-                            return
-                        alternatives.insert(0, 'return text')
+                        text = ''
+                        tool_calls: list[_messages.ToolCallPart] = []
+                        files: list[_messages.BinaryContent] = []
+
+                        for part in self.model_response.parts:
+                            if isinstance(part, _messages.TextPart):
+                                text += part.content
+                            elif isinstance(part, _messages.ToolCallPart):
+                                tool_calls.append(part)
+                            elif isinstance(part, _messages.FilePart):
+                                files.append(part.content)
+                            elif isinstance(part, _messages.BuiltinToolCallPart):
+                                # Text parts before a built-in tool call are essentially thoughts,
+                                # not part of the final result output, so we reset the accumulated text
+                                text = ''
+                                await send_stream.send(_messages.BuiltinToolCallEvent(part))  # pyright: ignore[reportDeprecated]
+                            elif isinstance(part, _messages.BuiltinToolReturnPart):
+                                await send_stream.send(_messages.BuiltinToolResultEvent(part))  # pyright: ignore[reportDeprecated]
+                            elif isinstance(part, _messages.ThinkingPart):
+                                pass
+                            else:
+                                assert_never(part)
+
+                        try:
+                            # At the moment, we prioritize at least executing tool calls if they are present.
+                            # In the future, we'd consider making this configurable at the agent or run level.
+                            # This accounts for cases like anthropic returns that might contain a text response
+                            # and a tool call response, where the text response just indicates the tool call will happen.
+                            alternatives: list[str] = []
+                            if tool_calls:
+                                async for event in self._handle_tool_calls(ctx, tool_calls):
+                                    await send_stream.send(event)
+                                return
+                            elif output_schema.toolset:
+                                alternatives.append('include your response in a tool call')
+                            else:
+                                alternatives.append('call a tool')
 
-                    # handle responses with only parts that don't constitute output.
-                    # This can happen with models that support thinking mode when they don't provide
-                    # actionable output alongside their thinking content. so we tell the model to try again.
-                    m = _messages.RetryPromptPart(
-                        content=f'Please {" or ".join(alternatives)}.',
-                    )
-                    raise ToolRetryError(m)
-                except ToolRetryError as e:
-                    ctx.state.increment_retries(
-                        ctx.deps.max_result_retries, error=e, model_settings=ctx.deps.model_settings
-                    )
-                    run_context = build_run_context(ctx)
-                    instructions = await ctx.deps.get_instructions(run_context)
-                    self._next_node = ModelRequestNode[DepsT, NodeRunEndT](
-                        _messages.ModelRequest(parts=[e.tool_retry], instructions=instructions)
-                    )
+                            if output_schema.allows_image:
+                                if image := next(
+                                    (file for file in files if isinstance(file, _messages.BinaryImage)), None
+                                ):
+                                    self._next_node = await self._handle_image_response(ctx, image)
+                                    return
+                                alternatives.append('return an image')
+
+                            if text_processor := output_schema.text_processor:
+                                if text:
+                                    # TODO (DouweM): This could call an output function that yields custom events, but we're not in an event stream here?
+                                    self._next_node = await self._handle_text_response(ctx, text, text_processor)
+                                    return
+                                alternatives.insert(0, 'return text')
+
+                            # handle responses with only parts that don't constitute output.
+                            # This can happen with models that support thinking mode when they don't provide
+                            # actionable output alongside their thinking content. so we tell the model to try again.
+                            m = _messages.RetryPromptPart(
+                                content=f'Please {" or ".join(alternatives)}.',
+                            )
+                            raise ToolRetryError(m)
+                        except ToolRetryError as e:
+                            ctx.state.increment_retries(
+                                ctx.deps.max_result_retries, error=e, model_settings=ctx.deps.model_settings
+                            )
+                            run_context = build_run_context(ctx)
+                            instructions = await ctx.deps.get_instructions(run_context)
+                            self._next_node = ModelRequestNode[DepsT, NodeRunEndT](
+                                _messages.ModelRequest(parts=[e.tool_retry], instructions=instructions)
+                            )
+
+                task = asyncio.create_task(_run())
+
+                async with receive_stream:
+                    async for message in receive_stream:
+                        yield message
+
+                await task
 
             self._events_iterator = _run_stream()
 

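The restructured `_run_stream` above is a producer/consumer bridge: response handling runs in a background task that sends `HandleResponseEvent`s into an anyio memory object stream (whose send side the tool manager also holds as `event_stream`), while the outer async generator drains the receive side and finally awaits the task so any producer exception propagates. A self-contained sketch of the same pattern, with plain strings standing in for the event objects:

import asyncio

import anyio


async def stream_events() -> None:
    # Unbuffered by default (max_buffer_size=0): each send() waits until
    # the consumer receives, so events flow through in order.
    send_stream, receive_stream = anyio.create_memory_object_stream[str]()

    async def _run() -> None:
        # Producer: anything holding send_stream (here a stand-in for the
        # tool manager and the output function) can emit events.
        async with send_stream:
            for event in ('tool_call', 'handoff_event', 'tool_result'):
                await send_stream.send(event)

    task = asyncio.create_task(_run())

    # Consumer: the generator body just drains the receive side; the loop
    # ends when the producer closes its end of the stream.
    async with receive_stream:
        async for event in receive_stream:
            print(event)

    # Awaiting the task re-raises anything the producer failed with.
    await task


asyncio.run(stream_events())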
@@ -685,9 +710,6 @@ async def _handle_tool_calls(
     ) -> AsyncIterator[_messages.HandleResponseEvent]:
         run_context = build_run_context(ctx)
 
-        # This will raise errors for any tool name conflicts
-        ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context)
-
         output_parts: list[_messages.ModelRequestPart] = []
         output_final_result: deque[result.FinalResult[NodeRunEndT]] = deque(maxlen=1)
 

@@ -937,7 +959,7 @@ async def process_tool_calls(  # noqa: C901
         output_final_result.append(final_result)
 
 
-async def _call_tools(  # noqa: C901
+async def _call_tools(
     tool_manager: ToolManager[DepsT],
     tool_calls: list[_messages.ToolCallPart],
     tool_call_results: dict[str, DeferredToolResult],
@@ -995,45 +1017,30 @@ async def handle_call_or_result(
 
         return _messages.FunctionToolResultEvent(tool_part, content=tool_user_content)
 
-    send_stream, receive_stream = anyio.create_memory_object_stream[_messages.HandleResponseEvent]()
-
-    async def _run_tools():
-        async with send_stream:
-            assert tool_manager.ctx is not None, 'ToolManager.ctx needs to be set'
-            tool_manager.ctx.event_stream = send_stream
-
-            if tool_manager.should_call_sequentially(tool_calls):
-                for index, call in enumerate(tool_calls):
-                    if event := await handle_call_or_result(
-                        _call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id)),
-                        index,
-                    ):
-                        await send_stream.send(event)
-
-            else:
-                tasks = [
-                    asyncio.create_task(
-                        _call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id)),
-                        name=call.tool_name,
-                    )
-                    for call in tool_calls
-                ]
-
-                pending = tasks
-                while pending:
-                    done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
-                    for task in done:
-                        index = tasks.index(task)
-                        if event := await handle_call_or_result(coro_or_task=task, index=index):
-                            await send_stream.send(event)
-
-    task = asyncio.create_task(_run_tools())
+    if tool_manager.should_call_sequentially(tool_calls):
+        for index, call in enumerate(tool_calls):
+            if event := await handle_call_or_result(
+                _call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id)),
+                index,
+            ):
+                yield event
 
-    async with receive_stream:
-        async for message in receive_stream:
-            yield message
+    else:
+        tasks = [
+            asyncio.create_task(
+                _call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id)),
+                name=call.tool_name,
+            )
+            for call in tool_calls
+        ]
 
-    await task
+        pending = tasks
+        while pending:
+            done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
+            for task in done:
+                index = tasks.index(task)
+                if event := await handle_call_or_result(coro_or_task=task, index=index):
+                    yield event
 
     # We append the results at the end, rather than as they are received, to retain a consistent ordering
     # This is mostly just to simplify testing

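With the stream plumbing hoisted out, `_call_tools` drops its `# noqa: C901` and is a plain async generator again: both branches `yield` events directly, and the concurrent branch still yields in completion order via `asyncio.wait(..., return_when=asyncio.FIRST_COMPLETED)`. A minimal standalone sketch of that loop, with toy coroutines standing in for the diff's `_call_tool`:

import asyncio
from collections.abc import AsyncIterator


async def _toy_tool(name: str, delay: float) -> str:
    await asyncio.sleep(delay)  # stand-in for real tool work
    return f'{name} done'


async def call_tools_concurrently(specs: list[tuple[str, float]]) -> AsyncIterator[str]:
    # Mirrors the concurrent branch: start every call as a task up front,
    # then yield each result as its task finishes, not in submission order.
    tasks = [asyncio.create_task(_toy_tool(name, delay), name=name) for name, delay in specs]

    pending = set(tasks)
    while pending:
        done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            yield task.result()


async def main() -> None:
    async for event in call_tools_concurrently([('slow', 0.2), ('fast', 0.05)]):
        print(event)  # prints 'fast done' before 'slow done'


asyncio.run(main())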