@@ -562,116 +562,141 @@ async def stream(
562562 async def _run_stream ( # noqa: C901
563563 self , ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , NodeRunEndT ]]
564564 ) -> AsyncIterator [_messages .HandleResponseEvent ]:
565+ # Ensure that the stream is only run once
565566 if self ._events_iterator is None :
566- # Ensure that the stream is only run once
567+ run_context = build_run_context ( ctx )
567568
568- output_schema = ctx .deps .output_schema
569+ # This will raise errors for any tool name conflicts
570+ ctx .deps .tool_manager = await ctx .deps .tool_manager .for_run_step (run_context )
571+ tool_manager = ctx .deps .tool_manager
569572
570573 async def _run_stream () -> AsyncIterator [_messages .HandleResponseEvent ]: # noqa: C901
571- if not self .model_response .parts :
572- # we got an empty response.
573- # this sometimes happens with anthropic (and perhaps other models)
574- # when the model has already returned text along side tool calls
575- if text_processor := output_schema .text_processor : # pragma: no branch
576- # in this scenario, if text responses are allowed, we return text from the most recent model
577- # response, if any
578- for message in reversed (ctx .state .message_history ):
579- if isinstance (message , _messages .ModelResponse ):
580- text = ''
581- for part in message .parts :
582- if isinstance (part , _messages .TextPart ):
583- text += part .content
584- elif isinstance (part , _messages .BuiltinToolCallPart ):
585- # Text parts before a built-in tool call are essentially thoughts,
586- # not part of the final result output, so we reset the accumulated text
587- text = '' # pragma: no cover
588- if text :
589- try :
590- self ._next_node = await self ._handle_text_response (ctx , text , text_processor )
591- return
592- except ToolRetryError :
593- # If the text from the preview response was invalid, ignore it.
594- pass
595-
596- # Go back to the model request node with an empty request, which means we'll essentially
597- # resubmit the most recent request that resulted in an empty response,
598- # as the empty response and request will not create any items in the API payload,
599- # in the hope the model will return a non-empty response this time.
600- ctx .state .increment_retries (ctx .deps .max_result_retries , model_settings = ctx .deps .model_settings )
601- run_context = build_run_context (ctx )
602- instructions = await ctx .deps .get_instructions (run_context )
603- self ._next_node = ModelRequestNode [DepsT , NodeRunEndT ](
604- _messages .ModelRequest (parts = [], instructions = instructions )
605- )
606- return
607-
608- text = ''
609- tool_calls : list [_messages .ToolCallPart ] = []
610- files : list [_messages .BinaryContent ] = []
611-
612- for part in self .model_response .parts :
613- if isinstance (part , _messages .TextPart ):
614- text += part .content
615- elif isinstance (part , _messages .ToolCallPart ):
616- tool_calls .append (part )
617- elif isinstance (part , _messages .FilePart ):
618- files .append (part .content )
619- elif isinstance (part , _messages .BuiltinToolCallPart ):
620- # Text parts before a built-in tool call are essentially thoughts,
621- # not part of the final result output, so we reset the accumulated text
622- text = ''
623- yield _messages .BuiltinToolCallEvent (part ) # pyright: ignore[reportDeprecated]
624- elif isinstance (part , _messages .BuiltinToolReturnPart ):
625- yield _messages .BuiltinToolResultEvent (part ) # pyright: ignore[reportDeprecated]
626- elif isinstance (part , _messages .ThinkingPart ):
627- pass
628- else :
629- assert_never (part )
630-
631- try :
632- # At the moment, we prioritize at least executing tool calls if they are present.
633- # In the future, we'd consider making this configurable at the agent or run level.
634- # This accounts for cases like anthropic returns that might contain a text response
635- # and a tool call response, where the text response just indicates the tool call will happen.
636- alternatives : list [str ] = []
637- if tool_calls :
638- async for event in self ._handle_tool_calls (ctx , tool_calls ):
639- yield event
640- return
641- elif output_schema .toolset :
642- alternatives .append ('include your response in a tool call' )
643- else :
644- alternatives .append ('call a tool' )
645-
646- if output_schema .allows_image :
647- if image := next ((file for file in files if isinstance (file , _messages .BinaryImage )), None ):
648- self ._next_node = await self ._handle_image_response (ctx , image )
574+ send_stream , receive_stream = anyio .create_memory_object_stream [_messages .HandleResponseEvent ]()
575+
576+ async def _run (): # noqa: C901
577+ async with send_stream :
578+ assert tool_manager .ctx is not None , 'ToolManager.ctx needs to be set'
579+ tool_manager .ctx .event_stream = send_stream
580+
581+ output_schema = ctx .deps .output_schema
582+ if not self .model_response .parts :
583+ # We got an empty response.
584+ # This sometimes happens with Anthropic (and perhaps other models)
585+ # when the model has already returned text alongside tool calls.
586+ if text_processor := output_schema .text_processor : # pragma: no branch
587+ # in this scenario, if text responses are allowed, we return text from the most recent model
588+ # response, if any
589+ for message in reversed (ctx .state .message_history ):
590+ if isinstance (message , _messages .ModelResponse ):
591+ text = ''
592+ for part in message .parts :
593+ if isinstance (part , _messages .TextPart ):
594+ text += part .content
595+ elif isinstance (part , _messages .BuiltinToolCallPart ):
596+ # Text parts before a built-in tool call are essentially thoughts,
597+ # not part of the final result output, so we reset the accumulated text
598+ text = '' # pragma: no cover
599+ if text :
600+ try :
601+ self ._next_node = await self ._handle_text_response (
602+ ctx , text , text_processor
603+ )
604+ return
605+ except ToolRetryError :
606+ # If the text from the previous response was invalid, ignore it.
607+ pass
608+
609+ # Go back to the model request node with an empty request, which means we'll essentially
610+ # resubmit the most recent request that resulted in an empty response,
611+ # as the empty response and request will not create any items in the API payload,
612+ # in the hope the model will return a non-empty response this time.
613+ ctx .state .increment_retries (
614+ ctx .deps .max_result_retries , model_settings = ctx .deps .model_settings
615+ )
616+ run_context = build_run_context (ctx )
617+ instructions = await ctx .deps .get_instructions (run_context )
618+ self ._next_node = ModelRequestNode [DepsT , NodeRunEndT ](
619+ _messages .ModelRequest (parts = [], instructions = instructions )
620+ )
649621 return
650- alternatives .append ('return an image' )
651622
652- if text_processor := output_schema .text_processor :
653- if text :
654- # TODO (DouweM): This could call an output function that yields custom events, but we're not in an event stream here?
655- self ._next_node = await self ._handle_text_response (ctx , text , text_processor )
656- return
657- alternatives .insert (0 , 'return text' )
623+ text = ''
624+ tool_calls : list [_messages .ToolCallPart ] = []
625+ files : list [_messages .BinaryContent ] = []
626+
627+ for part in self .model_response .parts :
628+ if isinstance (part , _messages .TextPart ):
629+ text += part .content
630+ elif isinstance (part , _messages .ToolCallPart ):
631+ tool_calls .append (part )
632+ elif isinstance (part , _messages .FilePart ):
633+ files .append (part .content )
634+ elif isinstance (part , _messages .BuiltinToolCallPart ):
635+ # Text parts before a built-in tool call are essentially thoughts,
636+ # not part of the final result output, so we reset the accumulated text
637+ text = ''
638+ await send_stream .send (_messages .BuiltinToolCallEvent (part )) # pyright: ignore[reportDeprecated]
639+ elif isinstance (part , _messages .BuiltinToolReturnPart ):
640+ await send_stream .send (_messages .BuiltinToolResultEvent (part )) # pyright: ignore[reportDeprecated]
641+ elif isinstance (part , _messages .ThinkingPart ):
642+ pass
643+ else :
644+ assert_never (part )
645+
646+ try :
647+ # At the moment, we prioritize at least executing tool calls if they are present.
648+ # In the future, we'd consider making this configurable at the agent or run level.
649+ # This accounts for cases like anthropic returns that might contain a text response
650+ # and a tool call response, where the text response just indicates the tool call will happen.
651+ alternatives : list [str ] = []
652+ if tool_calls :
653+ async for event in self ._handle_tool_calls (ctx , tool_calls ):
654+ await send_stream .send (event )
655+ return
656+ elif output_schema .toolset :
657+ alternatives .append ('include your response in a tool call' )
658+ else :
659+ alternatives .append ('call a tool' )
658660
659- # handle responses with only parts that don't constitute output.
660- # This can happen with models that support thinking mode when they don't provide
661- # actionable output alongside their thinking content. so we tell the model to try again.
662- m = _messages .RetryPromptPart (
663- content = f'Please { " or " .join (alternatives )} .' ,
664- )
665- raise ToolRetryError (m )
666- except ToolRetryError as e :
667- ctx .state .increment_retries (
668- ctx .deps .max_result_retries , error = e , model_settings = ctx .deps .model_settings
669- )
670- run_context = build_run_context (ctx )
671- instructions = await ctx .deps .get_instructions (run_context )
672- self ._next_node = ModelRequestNode [DepsT , NodeRunEndT ](
673- _messages .ModelRequest (parts = [e .tool_retry ], instructions = instructions )
674- )
661+ if output_schema .allows_image :
662+ if image := next (
663+ (file for file in files if isinstance (file , _messages .BinaryImage )), None
664+ ):
665+ self ._next_node = await self ._handle_image_response (ctx , image )
666+ return
667+ alternatives .append ('return an image' )
668+
669+ if text_processor := output_schema .text_processor :
670+ if text :
671+ # TODO (DouweM): This could call an output function that yields custom events, but we're not in an event stream here?
672+ self ._next_node = await self ._handle_text_response (ctx , text , text_processor )
673+ return
674+ alternatives .insert (0 , 'return text' )
675+
676+ # Handle responses with only parts that don't constitute output.
677+ # This can happen with models that support thinking mode when they don't provide
678+ # actionable output alongside their thinking content, so we tell the model to try again.
679+ m = _messages .RetryPromptPart (
680+ content = f'Please { " or " .join (alternatives )} .' ,
681+ )
682+ raise ToolRetryError (m )
683+ except ToolRetryError as e :
684+ ctx .state .increment_retries (
685+ ctx .deps .max_result_retries , error = e , model_settings = ctx .deps .model_settings
686+ )
687+ run_context = build_run_context (ctx )
688+ instructions = await ctx .deps .get_instructions (run_context )
689+ self ._next_node = ModelRequestNode [DepsT , NodeRunEndT ](
690+ _messages .ModelRequest (parts = [e .tool_retry ], instructions = instructions )
691+ )
692+
693+ task = asyncio .create_task (_run ())
694+
695+ async with receive_stream :
696+ async for message in receive_stream :
697+ yield message
698+
699+ await task
675700
676701 self ._events_iterator = _run_stream ()
677702
@@ -685,9 +710,6 @@ async def _handle_tool_calls(
685710 ) -> AsyncIterator [_messages .HandleResponseEvent ]:
686711 run_context = build_run_context (ctx )
687712
688- # This will raise errors for any tool name conflicts
689- ctx .deps .tool_manager = await ctx .deps .tool_manager .for_run_step (run_context )
690-
691713 output_parts : list [_messages .ModelRequestPart ] = []
692714 output_final_result : deque [result .FinalResult [NodeRunEndT ]] = deque (maxlen = 1 )
693715
@@ -937,7 +959,7 @@ async def process_tool_calls( # noqa: C901
937959 output_final_result .append (final_result )
938960
939961
940- async def _call_tools ( # noqa: C901
962+ async def _call_tools (
941963 tool_manager : ToolManager [DepsT ],
942964 tool_calls : list [_messages .ToolCallPart ],
943965 tool_call_results : dict [str , DeferredToolResult ],
@@ -995,45 +1017,30 @@ async def handle_call_or_result(
9951017
9961018 return _messages .FunctionToolResultEvent (tool_part , content = tool_user_content )
9971019
998- send_stream , receive_stream = anyio .create_memory_object_stream [_messages .HandleResponseEvent ]()
999-
1000- async def _run_tools ():
1001- async with send_stream :
1002- assert tool_manager .ctx is not None , 'ToolManager.ctx needs to be set'
1003- tool_manager .ctx .event_stream = send_stream
1004-
1005- if tool_manager .should_call_sequentially (tool_calls ):
1006- for index , call in enumerate (tool_calls ):
1007- if event := await handle_call_or_result (
1008- _call_tool (tool_manager , call , tool_call_results .get (call .tool_call_id )),
1009- index ,
1010- ):
1011- await send_stream .send (event )
1012-
1013- else :
1014- tasks = [
1015- asyncio .create_task (
1016- _call_tool (tool_manager , call , tool_call_results .get (call .tool_call_id )),
1017- name = call .tool_name ,
1018- )
1019- for call in tool_calls
1020- ]
1021-
1022- pending = tasks
1023- while pending :
1024- done , pending = await asyncio .wait (pending , return_when = asyncio .FIRST_COMPLETED )
1025- for task in done :
1026- index = tasks .index (task )
1027- if event := await handle_call_or_result (coro_or_task = task , index = index ):
1028- await send_stream .send (event )
1029-
1030- task = asyncio .create_task (_run_tools ())
1020+ if tool_manager .should_call_sequentially (tool_calls ):
1021+ for index , call in enumerate (tool_calls ):
1022+ if event := await handle_call_or_result (
1023+ _call_tool (tool_manager , call , tool_call_results .get (call .tool_call_id )),
1024+ index ,
1025+ ):
1026+ yield event
10311027
1032- async with receive_stream :
1033- async for message in receive_stream :
1034- yield message
1028+ else :
1029+ tasks = [
1030+ asyncio .create_task (
1031+ _call_tool (tool_manager , call , tool_call_results .get (call .tool_call_id )),
1032+ name = call .tool_name ,
1033+ )
1034+ for call in tool_calls
1035+ ]
10351036
1036- await task
1037+ pending = tasks
1038+ while pending :
1039+ done , pending = await asyncio .wait (pending , return_when = asyncio .FIRST_COMPLETED )
1040+ for task in done :
1041+ index = tasks .index (task )
1042+ if event := await handle_call_or_result (coro_or_task = task , index = index ):
1043+ yield event
10371044
10381045 # We append the results at the end, rather than as they are received, to retain a consistent ordering
10391046 # This is mostly just to simplify testing
0 commit comments