|
12 | 12 | from dataclasses import field, replace |
13 | 13 | from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, cast |
14 | 14 |
|
| 15 | +import anyio |
15 | 16 | from opentelemetry.trace import Tracer |
16 | 17 | from typing_extensions import TypeVar, assert_never |
17 | 18 |
|
@@ -643,32 +644,47 @@ async def _handle_tool_calls( |
643 | 644 | ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], |
644 | 645 | tool_calls: list[_messages.ToolCallPart], |
645 | 646 | ) -> AsyncIterator[_messages.HandleResponseEvent]: |
| 647 | + send_stream, receive_stream = anyio.create_memory_object_stream[_messages.HandleResponseEvent]() |
| 648 | + |
646 | 649 | run_context = build_run_context(ctx) |
647 | 650 |
|
648 | | - # This will raise errors for any tool name conflicts |
649 | | - ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context) |
| 651 | + async def _run_tool_calls() -> tuple[list[_messages.ModelRequestPart], result.FinalResult[NodeRunEndT] | None]: |
| 652 | + async with send_stream: |
| 653 | + tool_run_context = replace(run_context, event_stream=send_stream) |
650 | 654 |
|
651 | | - output_parts: list[_messages.ModelRequestPart] = [] |
652 | | - output_final_result: deque[result.FinalResult[NodeRunEndT]] = deque(maxlen=1) |
| 655 | + # This will raise errors for any tool name conflicts |
| 656 | + ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(tool_run_context) |
653 | 657 |
|
654 | | - async for event in process_tool_calls( |
655 | | - tool_manager=ctx.deps.tool_manager, |
656 | | - tool_calls=tool_calls, |
657 | | - tool_call_results=self.tool_call_results, |
658 | | - final_result=None, |
659 | | - ctx=ctx, |
660 | | - output_parts=output_parts, |
661 | | - output_final_result=output_final_result, |
662 | | - ): |
663 | | - yield event |
| 658 | + output_parts: list[_messages.ModelRequestPart] = [] |
| 659 | + output_final_result: deque[result.FinalResult[NodeRunEndT]] = deque(maxlen=1) |
664 | 660 |
|
665 | | - if output_final_result: |
666 | | - final_result = output_final_result[0] |
667 | | - self._next_node = self._handle_final_result(ctx, final_result, output_parts) |
| 661 | + async for event in process_tool_calls( |
| 662 | + tool_manager=ctx.deps.tool_manager, |
| 663 | + tool_calls=tool_calls, |
| 664 | + tool_call_results=self.tool_call_results, |
| 665 | + final_result=None, |
| 666 | + ctx=ctx, |
| 667 | + output_parts=output_parts, |
| 668 | + output_final_result=output_final_result, |
| 669 | + ): |
| 670 | + await send_stream.send(event) |
| 671 | + |
| 672 | + return output_parts, output_final_result[0] if output_final_result else None |
| 673 | + |
| 674 | + task = asyncio.create_task(_run_tool_calls()) |
| 675 | + |
| 676 | + async with receive_stream: |
| 677 | + async for message in receive_stream: |
| 678 | + yield message |
| 679 | + |
| 680 | + parts, final_result = await task |
| 681 | + |
| 682 | + if final_result: |
| 683 | + self._next_node = self._handle_final_result(ctx, final_result, parts) |
668 | 684 | else: |
669 | 685 | instructions = await ctx.deps.get_instructions(run_context) |
670 | 686 | self._next_node = ModelRequestNode[DepsT, NodeRunEndT]( |
671 | | - _messages.ModelRequest(parts=output_parts, instructions=instructions) |
| 687 | + _messages.ModelRequest(parts=parts, instructions=instructions) |
672 | 688 | ) |
673 | 689 |
|
674 | 690 | async def _handle_text_response( |
|
0 commit comments