@@ -243,7 +243,7 @@ async def run(
243
243
while True :
244
244
run_step += 1
245
245
with _logfire .span ('preparing model and tools {run_step=}' , run_step = run_step ):
246
- agent_model = await self ._prepare_model (model_used , deps )
246
+ agent_model = await self ._prepare_model (model_used , deps , messages )
247
247
248
248
with _logfire .span ('model request' , run_step = run_step ) as model_req_span :
249
249
model_response , request_cost = await agent_model .request (messages , model_settings )
@@ -255,7 +255,7 @@ async def run(
255
255
cost += request_cost
256
256
257
257
with _logfire .span ('handle model response' , run_step = run_step ) as handle_span :
258
- final_result , response_messages = await self ._handle_model_response (model_response , deps )
258
+ final_result , response_messages = await self ._handle_model_response (model_response , deps , messages )
259
259
260
260
# Add all messages to the conversation
261
261
messages .extend (response_messages )
@@ -391,7 +391,7 @@ async def main():
391
391
run_step += 1
392
392
393
393
with _logfire .span ('preparing model and tools {run_step=}' , run_step = run_step ):
394
- agent_model = await self ._prepare_model (model_used , deps )
394
+ agent_model = await self ._prepare_model (model_used , deps , messages )
395
395
396
396
with _logfire .span ('model request {run_step=}' , run_step = run_step ) as model_req_span :
397
397
async with agent_model .request_stream (messages , model_settings ) as model_response :
@@ -402,7 +402,7 @@ async def main():
402
402
403
403
with _logfire .span ('handle model response' ) as handle_span :
404
404
final_result , response_messages = await self ._handle_streamed_model_response (
405
- model_response , deps
405
+ model_response , deps , messages
406
406
)
407
407
408
408
# Add all messages to the conversation
@@ -773,12 +773,14 @@ async def _get_model(self, model: models.Model | models.KnownModelName | None) -
773
773
774
774
return model_ , mode_selection
775
775
776
- async def _prepare_model (self , model : models .Model , deps : AgentDeps ) -> models .AgentModel :
776
+ async def _prepare_model (
777
+ self , model : models .Model , deps : AgentDeps , messages : list [_messages .Message ]
778
+ ) -> models .AgentModel :
777
779
"""Create building tools and create an agent model."""
778
780
function_tools : list [ToolDefinition ] = []
779
781
780
782
async def add_tool (tool : Tool [AgentDeps ]) -> None :
781
- ctx = RunContext (deps , tool .current_retry , tool .name )
783
+ ctx = RunContext (deps , tool .current_retry , messages , tool .name )
782
784
if tool_def := await tool .prepare_tool_def (ctx ):
783
785
function_tools .append (tool_def )
784
786
@@ -807,7 +809,7 @@ async def _prepare_messages(
807
809
return new_message_index , messages
808
810
809
811
async def _handle_model_response (
810
- self , model_response : _messages .ModelResponse , deps : AgentDeps
812
+ self , model_response : _messages .ModelResponse , deps : AgentDeps , conv_messages : list [ _messages . Message ]
811
813
) -> tuple [_MarkFinalResult [ResultData ] | None , list [_messages .Message ]]:
812
814
"""Process a non-streamed response from the model.
813
815
@@ -824,20 +826,20 @@ async def _handle_model_response(
824
826
825
827
if texts :
826
828
text = '\n \n ' .join (texts )
827
- return await self ._handle_text_response (text , deps )
829
+ return await self ._handle_text_response (text , deps , conv_messages )
828
830
elif tool_calls :
829
- return await self ._handle_structured_response (tool_calls , deps )
831
+ return await self ._handle_structured_response (tool_calls , deps , conv_messages )
830
832
else :
831
833
raise exceptions .UnexpectedModelBehavior ('Received empty model response' )
832
834
833
835
async def _handle_text_response (
834
- self , text : str , deps : AgentDeps
836
+ self , text : str , deps : AgentDeps , conv_messages : list [ _messages . Message ]
835
837
) -> tuple [_MarkFinalResult [ResultData ] | None , list [_messages .Message ]]:
836
838
"""Handle a plain text response from the model for non-streaming responses."""
837
839
if self ._allow_text_result :
838
840
result_data_input = cast (ResultData , text )
839
841
try :
840
- result_data = await self ._validate_result (result_data_input , deps , None )
842
+ result_data = await self ._validate_result (result_data_input , deps , None , conv_messages )
841
843
except _result .ToolRetryError as e :
842
844
self ._incr_result_retry ()
843
845
return None , [e .tool_retry ]
@@ -851,26 +853,24 @@ async def _handle_text_response(
851
853
return None , [response ]
852
854
853
855
async def _handle_structured_response (
854
- self , tool_calls : list [_messages .ToolCallPart ], deps : AgentDeps
856
+ self , tool_calls : list [_messages .ToolCallPart ], deps : AgentDeps , conv_messages : list [ _messages . Message ]
855
857
) -> tuple [_MarkFinalResult [ResultData ] | None , list [_messages .Message ]]:
856
858
"""Handle a structured response containing tool calls from the model for non-streaming responses."""
857
859
assert tool_calls , 'Expected at least one tool call'
858
860
859
861
# First process any final result tool calls
860
- final_result , final_messages = await self ._process_final_tool_calls (tool_calls , deps )
862
+ final_result , final_messages = await self ._process_final_tool_calls (tool_calls , deps , conv_messages )
861
863
862
864
# Then process regular tools based on end strategy
863
865
if self .end_strategy == 'early' and final_result :
864
866
tool_messages = self ._mark_skipped_function_tools (tool_calls )
865
867
else :
866
- tool_messages = await self ._process_function_tools (tool_calls , deps )
868
+ tool_messages = await self ._process_function_tools (tool_calls , deps , conv_messages )
867
869
868
870
return final_result , [* final_messages , * tool_messages ]
869
871
870
872
async def _process_final_tool_calls (
871
- self ,
872
- tool_calls : list [_messages .ToolCallPart ],
873
- deps : AgentDeps ,
873
+ self , tool_calls : list [_messages .ToolCallPart ], deps : AgentDeps , conv_messages : list [_messages .Message ]
874
874
) -> tuple [_MarkFinalResult [ResultData ] | None , list [_messages .Message ]]:
875
875
"""Process any final result tool calls and return the first valid result."""
876
876
if not self ._result_schema :
@@ -888,7 +888,7 @@ async def _process_final_tool_calls(
888
888
# This is the first result tool - try to use it
889
889
try :
890
890
result_data = result_tool .validate (call )
891
- result_data = await self ._validate_result (result_data , deps , call )
891
+ result_data = await self ._validate_result (result_data , deps , call , conv_messages )
892
892
except _result .ToolRetryError as e :
893
893
self ._incr_result_retry ()
894
894
messages .append (e .tool_retry )
@@ -914,17 +914,15 @@ async def _process_final_tool_calls(
914
914
return final_result , messages
915
915
916
916
async def _process_function_tools (
917
- self ,
918
- tool_calls : list [_messages .ToolCallPart ],
919
- deps : AgentDeps ,
917
+ self , tool_calls : list [_messages .ToolCallPart ], deps : AgentDeps , conv_messages : list [_messages .Message ]
920
918
) -> list [_messages .Message ]:
921
919
"""Process function (non-final) tool calls in parallel."""
922
920
messages : list [_messages .Message ] = []
923
921
tasks : list [asyncio .Task [_messages .Message ]] = []
924
922
925
923
for call in tool_calls :
926
924
if tool := self ._function_tools .get (call .tool_name ):
927
- tasks .append (asyncio .create_task (tool .run (deps , call ), name = call .tool_name ))
925
+ tasks .append (asyncio .create_task (tool .run (deps , call , conv_messages ), name = call .tool_name ))
928
926
elif self ._result_schema is None or call .tool_name not in self ._result_schema .tools :
929
927
messages .append (self ._unknown_tool (call .tool_name ))
930
928
@@ -958,7 +956,7 @@ def _mark_skipped_function_tools(
958
956
return messages
959
957
960
958
async def _handle_streamed_model_response (
961
- self , model_response : models .EitherStreamedResponse , deps : AgentDeps
959
+ self , model_response : models .EitherStreamedResponse , deps : AgentDeps , conv_messages : list [ _messages . Message ]
962
960
) -> tuple [_MarkFinalResult [models .EitherStreamedResponse ] | None , list [_messages .Message ]]:
963
961
"""Process a streamed response from the model.
964
962
@@ -1015,7 +1013,7 @@ async def _handle_streamed_model_response(
1015
1013
if isinstance (item , _messages .ToolCallPart ):
1016
1014
call = item
1017
1015
if tool := self ._function_tools .get (call .tool_name ):
1018
- tasks .append (asyncio .create_task (tool .run (deps , call ), name = call .tool_name ))
1016
+ tasks .append (asyncio .create_task (tool .run (deps , call , conv_messages ), name = call .tool_name ))
1019
1017
else :
1020
1018
messages .append (self ._unknown_tool (call .tool_name ))
1021
1019
@@ -1025,10 +1023,16 @@ async def _handle_streamed_model_response(
1025
1023
return None , messages
1026
1024
1027
1025
async def _validate_result (
1028
- self , result_data : ResultData , deps : AgentDeps , tool_call : _messages .ToolCallPart | None
1026
+ self ,
1027
+ result_data : ResultData ,
1028
+ deps : AgentDeps ,
1029
+ tool_call : _messages .ToolCallPart | None ,
1030
+ conv_messages : list [_messages .Message ],
1029
1031
) -> ResultData :
1030
1032
for validator in self ._result_validators :
1031
- result_data = await validator .validate (result_data , deps , self ._current_result_retry , tool_call )
1033
+ result_data = await validator .validate (
1034
+ result_data , deps , self ._current_result_retry , tool_call , conv_messages
1035
+ )
1032
1036
return result_data
1033
1037
1034
1038
def _incr_result_retry (self ) -> None :
0 commit comments