@@ -215,23 +215,24 @@ async def run(
                 cost += request_cost

                 with _logfire.span('handle model response', run_step=run_step) as handle_span:
-                    either = await self._handle_model_response(model_response, deps)
+                    final_result, response_messages = await self._handle_model_response(model_response, deps)

-                    if isinstance(either, _MarkFinalResult):
-                        # we have a final result, end the conversation
-                        result_data = either.data
+                    # Add all messages to the conversation
+                    messages.extend(response_messages)
+
+                    # Check if we got a final result
+                    if final_result is not None:
+                        result_data = final_result.data
                         run_span.set_attribute('all_messages', messages)
                         run_span.set_attribute('cost', cost)
                         handle_span.set_attribute('result', result_data)
                         handle_span.message = 'handle model response -> final result'
                         return result.RunResult(messages, new_message_index, result_data, cost)
                     else:
                         # continue the conversation
-                        tool_responses = either
-                        handle_span.set_attribute('tool_responses', tool_responses)
-                        response_msgs = ' '.join(m.role for m in tool_responses)
+                        handle_span.set_attribute('tool_responses', response_messages)
+                        response_msgs = ' '.join(r.role for r in response_messages)
                         handle_span.message = f'handle model response -> {response_msgs}'
-                        messages.extend(tool_responses)

     def run_sync(
         self,
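
Editor's note: the switch from an `Either`-style union to a `(final_result, response_messages)` tuple means callers no longer need an `isinstance` check against the `_MarkFinalResult` marker. A minimal sketch of the two contracts, illustrative only (names as in the hunk above):

```python
# Old contract: one value, two possible shapes, disambiguated at runtime.
# either = await self._handle_model_response(model_response, deps)
# if isinstance(either, _MarkFinalResult): ...   # final result
# else: ...                                      # list of messages to send back

# New contract: always a pair; `final_result` is None while the run continues,
# and `response_messages` always carries the messages to append to history.
final_result, response_messages = await self._handle_model_response(model_response, deps)
messages.extend(response_messages)
if final_result is not None:
    ...  # end the run with final_result.data
```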
@@ -324,10 +325,16 @@ async def run_stream(
                        model_req_span.__exit__(None, None, None)

                        with _logfire.span('handle model response') as handle_span:
-                            either = await self._handle_streamed_model_response(model_response, deps)
+                            final_result, response_messages = await self._handle_streamed_model_response(
+                                model_response, deps
+                            )
+
+                            # Add all messages to the conversation
+                            messages.extend(response_messages)

-                            if isinstance(either, _MarkFinalResult):
-                                result_stream = either.data
+                            # Check if we got a final result
+                            if final_result is not None:
+                                result_stream = final_result.data
                                 run_span.set_attribute('all_messages', messages)
                                 handle_span.set_attribute('result_type', result_stream.__class__.__name__)
                                 handle_span.message = 'handle model response -> final result'
@@ -343,11 +350,10 @@ async def run_stream(
                                 )
                                 return
                             else:
-                                tool_responses = either
-                                handle_span.set_attribute('tool_responses', tool_responses)
-                                response_msgs = ' '.join(m.role for m in tool_responses)
+                                # continue the conversation
+                                handle_span.set_attribute('tool_responses', response_messages)
+                                response_msgs = ' '.join(r.role for r in response_messages)
                                 handle_span.message = f'handle model response -> {response_msgs}'
-                                messages.extend(tool_responses)

                         # the model_response should have been fully streamed by now, we can add it's cost
                         cost += model_response.cost()
@@ -725,11 +731,11 @@ async def _prepare_messages(

     async def _handle_model_response(
         self, model_response: _messages.ModelAnyResponse, deps: AgentDeps
-    ) -> _MarkFinalResult[ResultData] | list[_messages.Message]:
+    ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.Message]]:
         """Process a non-streamed response from the model.

         Returns:
-            Return `Either` — left: final result data, right: list of messages to send back to the model.
+            A tuple of `(final_result, messages)`. If `final_result` is not `None`, the conversation should end.
         """
         if model_response.role == 'model-text-response':
             # plain string response
@@ -739,15 +745,15 @@ async def _handle_model_response(
                     result_data = await self._validate_result(result_data_input, deps, None)
                 except _result.ToolRetryError as e:
                     self._incr_result_retry()
-                    return [e.tool_retry]
+                    return None, [e.tool_retry]
                 else:
-                    return _MarkFinalResult(result_data)
+                    return _MarkFinalResult(result_data), []
             else:
                 self._incr_result_retry()
                 response = _messages.RetryPrompt(
                     content='Plain text responses are not permitted, please call one of the functions instead.',
                 )
-                return [response]
+                return None, [response]
         elif model_response.role == 'model-structured-response':
             if self._result_schema is not None:
                 # if there's a result schema, and any of the calls match one of its tools, return the result
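
Editor's note: every retry path now returns `None` as the first tuple element, so the run loop treats a failed validation exactly like a tool call: extend the history and loop. A rough sketch of that round trip, illustrative only (the `'retry-prompt'` role string is an assumption):

```python
# Hypothetical flow for a response that fails result validation:
final_result, response_messages = await self._handle_model_response(bad_response, deps)
assert final_result is None                         # no final result yet
assert response_messages[0].role == 'retry-prompt'  # the RetryPrompt to send back
messages.extend(response_messages)                  # re-enter the model request loop
```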
@@ -759,9 +765,15 @@ async def _handle_model_response(
                         result_data = await self._validate_result(result_data, deps, call)
                     except _result.ToolRetryError as e:
                         self._incr_result_retry()
-                        return [e.tool_retry]
+                        return None, [e.tool_retry]
                     else:
-                        return _MarkFinalResult(result_data)
+                        # Add a ToolReturn message for the schema tool call
+                        tool_return = _messages.ToolReturn(
+                            tool_name=call.tool_name,
+                            content='Final result processed.',
+                            tool_id=call.tool_id,
+                        )
+                        return _MarkFinalResult(result_data), [tool_return]

         if not model_response.calls:
             raise exceptions.UnexpectedModelBehavior('Received empty tool call message')
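
Editor's note: returning a `ToolReturn` alongside the final result keeps the message history well formed; chat APIs generally reject a transcript in which a tool call has no matching tool response. A minimal sketch of the appended message (the `final_result` tool name and the id value are assumptions, not from the diff):

```python
from pydantic_ai import messages as _messages  # import path assumed

tool_return = _messages.ToolReturn(
    tool_name='final_result',            # assumed default result tool name
    content='Final result processed.',   # fixed acknowledgement text from the hunk above
    tool_id='call_abc123',               # hypothetical provider-assigned call id
)
```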
@@ -776,26 +788,24 @@ async def _handle_model_response(
                     messages.append(self._unknown_tool(call.tool_name))

             with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
-                messages += await asyncio.gather(*tasks)
-            return messages
+                task_results: Sequence[_messages.Message] = await asyncio.gather(*tasks)
+                messages.extend(task_results)
+            return None, messages
         else:
             assert_never(model_response)

     async def _handle_streamed_model_response(
         self, model_response: models.EitherStreamedResponse, deps: AgentDeps
-    ) -> _MarkFinalResult[models.EitherStreamedResponse] | list[_messages.Message]:
+    ) -> tuple[_MarkFinalResult[models.EitherStreamedResponse] | None, list[_messages.Message]]:
         """Process a streamed response from the model.

-        TODO: change the response type to `models.EitherStreamedResponse | list[_messages.Message]` once we drop 3.9
-        (with 3.9 we get `TypeError: Subscripted generics cannot be used with class and instance checks`)
-
         Returns:
-            Return `Either` — left: final result data, right: list of messages to send back to the model.
+            A tuple of (final_result, messages). If final_result is not None, the conversation should end.
         """
         if isinstance(model_response, models.StreamTextResponse):
             # plain string response
             if self._allow_text_result:
-                return _MarkFinalResult(model_response)
+                return _MarkFinalResult(model_response), []
             else:
                 self._incr_result_retry()
                 response = _messages.RetryPrompt(
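
Editor's note: the new `task_results: Sequence[_messages.Message]` annotation assumes `Sequence` is already imported in this module; if it is not, the change would also need something like:

```python
from collections.abc import Sequence
```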
@@ -805,7 +815,7 @@ async def _handle_streamed_model_response(
                 async for _ in model_response:
                     pass

-                return [response]
+                return None, [response]
         else:
             assert isinstance(model_response, models.StreamStructuredResponse), f'Unexpected response: {model_response}'
             if self._result_schema is not None:
@@ -819,8 +829,14 @@ async def _handle_streamed_model_response(
                         break
                     structured_msg = model_response.get()

-                if self._result_schema.find_tool(structured_msg):
-                    return _MarkFinalResult(model_response)
+                if match := self._result_schema.find_tool(structured_msg):
+                    call, _ = match
+                    tool_return = _messages.ToolReturn(
+                        tool_name=call.tool_name,
+                        content='Final result processed.',
+                        tool_id=call.tool_id,
+                    )
+                    return _MarkFinalResult(model_response), [tool_return]

             # the model is calling a tool function, consume the response to get the next message
             async for _ in model_response:
@@ -839,8 +855,9 @@ async def _handle_streamed_model_response(
                     messages.append(self._unknown_tool(call.tool_name))

             with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
-                messages += await asyncio.gather(*tasks)
-            return messages
+                task_results: Sequence[_messages.Message] = await asyncio.gather(*tasks)
+                messages.extend(task_results)
+            return None, messages

     async def _validate_result(
         self, result_data: ResultData, deps: AgentDeps, tool_call: _messages.ToolCall | None
@@ -912,6 +929,8 @@ class _MarkFinalResult(Generic[ResultData]):
     """Marker class to indicate that the result is the final result.

     This allows us to use `isinstance`, which wouldn't be possible if we were returning `ResultData` directly.
+
+    It also avoids problems in the case where the result type is itself `None`, but is set.
     """

     data: ResultData
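
Editor's note: the added docstring line is the motivation for keeping the wrapper even under the new tuple contract: when an agent's result type is `None` itself, a bare value could not be told apart from the "no result yet" sentinel. A small illustrative sketch:

```python
# Illustrative only: why the wrapper matters when ResultData is None.
final: _MarkFinalResult[None] | None

final = None                    # no final result yet -> keep looping
final = _MarkFinalResult(None)  # a genuine final result whose data is None
assert final is not None and final.data is None
```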