@@ -104,7 +104,6 @@ class Agent(Generic[AgentDeps, ResultData]):
104
104
_system_prompt_functions : list [_system_prompt .SystemPromptRunner [AgentDeps ]] = field (repr = False )
105
105
_deps_type : type [AgentDeps ] = field (repr = False )
106
106
_max_result_retries : int = field (repr = False )
107
- _current_result_retry : int = field (repr = False )
108
107
_override_deps : _utils .Option [AgentDeps ] = field (default = None , repr = False )
109
108
_override_model : _utils .Option [models .Model ] = field (default = None , repr = False )
110
109
@@ -180,7 +179,6 @@ def __init__(
180
179
self ._deps_type = deps_type
181
180
self ._system_prompt_functions = []
182
181
self ._max_result_retries = result_retries if result_retries is not None else retries
183
- self ._current_result_retry = 0
184
182
self ._result_validators = []
185
183
186
184
async def run (
@@ -234,7 +232,9 @@ async def run(
234
232
model_name = model_used .name (),
235
233
agent_name = self .name or 'agent' ,
236
234
) as run_span :
237
- self .last_run_messages = messages = await self ._prepare_messages (deps , user_prompt , message_history )
235
+ run_context = RunContext (deps , 0 , [], None , model_used )
236
+ messages = await self ._prepare_messages (user_prompt , message_history , run_context )
237
+ self .last_run_messages = run_context .messages = messages
238
238
239
239
for tool in self ._function_tools .values ():
240
240
tool .current_retry = 0
@@ -249,7 +249,7 @@ async def run(
249
249
250
250
run_step += 1
251
251
with _logfire .span ('preparing model and tools {run_step=}' , run_step = run_step ):
252
- agent_model = await self ._prepare_model (model_used , deps , messages )
252
+ agent_model = await self ._prepare_model (run_context )
253
253
254
254
with _logfire .span ('model request' , run_step = run_step ) as model_req_span :
255
255
model_response , request_usage = await agent_model .request (messages , model_settings )
@@ -262,7 +262,7 @@ async def run(
262
262
usage_limits .check_tokens (request_usage )
263
263
264
264
with _logfire .span ('handle model response' , run_step = run_step ) as handle_span :
265
- final_result , tool_responses = await self ._handle_model_response (model_response , deps , messages )
265
+ final_result , tool_responses = await self ._handle_model_response (model_response , run_context )
266
266
267
267
if tool_responses :
268
268
# Add parts to the conversation as a new message
@@ -391,7 +391,9 @@ async def main():
391
391
model_name = model_used .name (),
392
392
agent_name = self .name or 'agent' ,
393
393
) as run_span :
394
- self .last_run_messages = messages = await self ._prepare_messages (deps , user_prompt , message_history )
394
+ run_context = RunContext (deps , 0 , [], None , model_used )
395
+ messages = await self ._prepare_messages (user_prompt , message_history , run_context )
396
+ self .last_run_messages = run_context .messages = messages
395
397
396
398
for tool in self ._function_tools .values ():
397
399
tool .current_retry = 0
@@ -406,7 +408,7 @@ async def main():
406
408
usage_limits .check_before_request (usage )
407
409
408
410
with _logfire .span ('preparing model and tools {run_step=}' , run_step = run_step ):
409
- agent_model = await self ._prepare_model (model_used , deps , messages )
411
+ agent_model = await self ._prepare_model (run_context )
410
412
411
413
with _logfire .span ('model request {run_step=}' , run_step = run_step ) as model_req_span :
412
414
async with agent_model .request_stream (messages , model_settings ) as model_response :
@@ -417,9 +419,7 @@ async def main():
417
419
model_req_span .__exit__ (None , None , None )
418
420
419
421
with _logfire .span ('handle model response' ) as handle_span :
420
- maybe_final_result = await self ._handle_streamed_model_response (
421
- model_response , deps , messages
422
- )
422
+ maybe_final_result = await self ._handle_streamed_model_response (model_response , run_context )
423
423
424
424
# Check if we got a final result
425
425
if isinstance (maybe_final_result , _MarkFinalResult ):
@@ -439,7 +439,7 @@ async def on_complete():
439
439
part for part in last_message .parts if isinstance (part , _messages .ToolCallPart )
440
440
]
441
441
parts = await self ._process_function_tools (
442
- tool_calls , result_tool_name , deps , messages
442
+ tool_calls , result_tool_name , run_context
443
443
)
444
444
if parts :
445
445
messages .append (_messages .ModelRequest (parts ))
@@ -452,7 +452,7 @@ async def on_complete():
452
452
usage_limits ,
453
453
result_stream ,
454
454
self ._result_schema ,
455
- deps ,
455
+ run_context ,
456
456
self ._result_validators ,
457
457
result_tool_name ,
458
458
on_complete ,
@@ -815,41 +815,39 @@ async def _get_model(self, model: models.Model | models.KnownModelName | None) -
815
815
816
816
return model_ , mode_selection
817
817
818
- async def _prepare_model (
819
- self , model : models .Model , deps : AgentDeps , messages : list [_messages .ModelMessage ]
820
- ) -> models .AgentModel :
821
- """Create building tools and create an agent model."""
818
+ async def _prepare_model (self , run_context : RunContext [AgentDeps ]) -> models .AgentModel :
819
+ """Build tools and create an agent model."""
822
820
function_tools : list [ToolDefinition ] = []
823
821
824
822
async def add_tool (tool : Tool [AgentDeps ]) -> None :
825
- ctx = RunContext ( deps , tool .current_retry , messages , tool .name )
823
+ ctx = run_context . replace_with ( retry = tool .current_retry , tool_name = tool .name )
826
824
if tool_def := await tool .prepare_tool_def (ctx ):
827
825
function_tools .append (tool_def )
828
826
829
827
await asyncio .gather (* map (add_tool , self ._function_tools .values ()))
830
828
831
- return await model .agent_model (
829
+ return await run_context . model .agent_model (
832
830
function_tools = function_tools ,
833
831
allow_text_result = self ._allow_text_result ,
834
832
result_tools = self ._result_schema .tool_defs () if self ._result_schema is not None else [],
835
833
)
836
834
837
835
async def _prepare_messages (
838
- self , deps : AgentDeps , user_prompt : str , message_history : list [_messages .ModelMessage ] | None
836
+ self , user_prompt : str , message_history : list [_messages .ModelMessage ] | None , run_context : RunContext [ AgentDeps ]
839
837
) -> list [_messages .ModelMessage ]:
840
838
if message_history :
841
839
# shallow copy messages
842
840
messages = message_history .copy ()
843
841
messages .append (_messages .ModelRequest ([_messages .UserPromptPart (user_prompt )]))
844
842
else :
845
- parts = await self ._sys_parts (deps )
843
+ parts = await self ._sys_parts (run_context )
846
844
parts .append (_messages .UserPromptPart (user_prompt ))
847
845
messages : list [_messages .ModelMessage ] = [_messages .ModelRequest (parts )]
848
846
849
847
return messages
850
848
851
849
async def _handle_model_response (
852
- self , model_response : _messages .ModelResponse , deps : AgentDeps , conv_messages : list [ _messages . ModelMessage ]
850
+ self , model_response : _messages .ModelResponse , run_context : RunContext [ AgentDeps ]
853
851
) -> tuple [_MarkFinalResult [ResultData ] | None , list [_messages .ModelRequestPart ]]:
854
852
"""Process a non-streamed response from the model.
855
853
@@ -868,34 +866,34 @@ async def _handle_model_response(
868
866
869
867
if texts :
870
868
text = '\n \n ' .join (texts )
871
- return await self ._handle_text_response (text , deps , conv_messages )
869
+ return await self ._handle_text_response (text , run_context )
872
870
elif tool_calls :
873
- return await self ._handle_structured_response (tool_calls , deps , conv_messages )
871
+ return await self ._handle_structured_response (tool_calls , run_context )
874
872
else :
875
873
raise exceptions .UnexpectedModelBehavior ('Received empty model response' )
876
874
877
875
async def _handle_text_response (
878
- self , text : str , deps : AgentDeps , conv_messages : list [ _messages . ModelMessage ]
876
+ self , text : str , run_context : RunContext [ AgentDeps ]
879
877
) -> tuple [_MarkFinalResult [ResultData ] | None , list [_messages .ModelRequestPart ]]:
880
878
"""Handle a plain text response from the model for non-streaming responses."""
881
879
if self ._allow_text_result :
882
880
result_data_input = cast (ResultData , text )
883
881
try :
884
- result_data = await self ._validate_result (result_data_input , deps , None , conv_messages )
882
+ result_data = await self ._validate_result (result_data_input , run_context , None )
885
883
except _result .ToolRetryError as e :
886
- self ._incr_result_retry ()
884
+ self ._incr_result_retry (run_context )
887
885
return None , [e .tool_retry ]
888
886
else :
889
887
return _MarkFinalResult (result_data , None ), []
890
888
else :
891
- self ._incr_result_retry ()
889
+ self ._incr_result_retry (run_context )
892
890
response = _messages .RetryPromptPart (
893
891
content = 'Plain text responses are not permitted, please call one of the functions instead.' ,
894
892
)
895
893
return None , [response ]
896
894
897
895
async def _handle_structured_response (
898
- self , tool_calls : list [_messages .ToolCallPart ], deps : AgentDeps , conv_messages : list [ _messages . ModelMessage ]
896
+ self , tool_calls : list [_messages .ToolCallPart ], run_context : RunContext [ AgentDeps ]
899
897
) -> tuple [_MarkFinalResult [ResultData ] | None , list [_messages .ModelRequestPart ]]:
900
898
"""Handle a structured response containing tool calls from the model for non-streaming responses."""
901
899
assert tool_calls , 'Expected at least one tool call'
@@ -909,26 +907,23 @@ async def _handle_structured_response(
909
907
call , result_tool = match
910
908
try :
911
909
result_data = result_tool .validate (call )
912
- result_data = await self ._validate_result (result_data , deps , call , conv_messages )
910
+ result_data = await self ._validate_result (result_data , run_context , call )
913
911
except _result .ToolRetryError as e :
914
- self ._incr_result_retry ()
912
+ self ._incr_result_retry (run_context )
915
913
parts .append (e .tool_retry )
916
914
else :
917
915
final_result = _MarkFinalResult (result_data , call .tool_name )
918
916
919
917
# Then build the other request parts based on end strategy
920
- parts += await self ._process_function_tools (
921
- tool_calls , final_result and final_result .tool_name , deps , conv_messages
922
- )
918
+ parts += await self ._process_function_tools (tool_calls , final_result and final_result .tool_name , run_context )
923
919
924
920
return final_result , parts
925
921
926
922
async def _process_function_tools (
927
923
self ,
928
924
tool_calls : list [_messages .ToolCallPart ],
929
925
result_tool_name : str | None ,
930
- deps : AgentDeps ,
931
- conv_messages : list [_messages .ModelMessage ],
926
+ run_context : RunContext [AgentDeps ],
932
927
) -> list [_messages .ModelRequestPart ]:
933
928
"""Process function (non-result) tool calls in parallel.
934
929
@@ -961,7 +956,7 @@ async def _process_function_tools(
961
956
)
962
957
)
963
958
else :
964
- tasks .append (asyncio .create_task (tool .run (deps , call , conv_messages ), name = call .tool_name ))
959
+ tasks .append (asyncio .create_task (tool .run (call , run_context ), name = call .tool_name ))
965
960
elif self ._result_schema is not None and call .tool_name in self ._result_schema .tools :
966
961
# if tool_name is in _result_schema, it means we found a result tool but an error occurred in
967
962
# validation, we don't add another part here
@@ -974,7 +969,7 @@ async def _process_function_tools(
974
969
)
975
970
)
976
971
else :
977
- parts .append (self ._unknown_tool (call .tool_name ))
972
+ parts .append (self ._unknown_tool (call .tool_name , run_context ))
978
973
979
974
# Run all tool tasks in parallel
980
975
if tasks :
@@ -986,8 +981,7 @@ async def _process_function_tools(
986
981
async def _handle_streamed_model_response (
987
982
self ,
988
983
model_response : models .EitherStreamedResponse ,
989
- deps : AgentDeps ,
990
- conv_messages : list [_messages .ModelMessage ],
984
+ run_context : RunContext [AgentDeps ],
991
985
) -> (
992
986
_MarkFinalResult [models .EitherStreamedResponse ]
993
987
| tuple [_messages .ModelResponse , list [_messages .ModelRequestPart ]]
@@ -1003,7 +997,7 @@ async def _handle_streamed_model_response(
1003
997
if self ._allow_text_result :
1004
998
return _MarkFinalResult (model_response , None )
1005
999
else :
1006
- self ._incr_result_retry ()
1000
+ self ._incr_result_retry (run_context )
1007
1001
response = _messages .RetryPromptPart (
1008
1002
content = 'Plain text responses are not permitted, please call one of the functions instead.' ,
1009
1003
)
@@ -1043,9 +1037,9 @@ async def _handle_streamed_model_response(
1043
1037
if isinstance (item , _messages .ToolCallPart ):
1044
1038
call = item
1045
1039
if tool := self ._function_tools .get (call .tool_name ):
1046
- tasks .append (asyncio .create_task (tool .run (deps , call , conv_messages ), name = call .tool_name ))
1040
+ tasks .append (asyncio .create_task (tool .run (call , run_context ), name = call .tool_name ))
1047
1041
else :
1048
- parts .append (self ._unknown_tool (call .tool_name ))
1042
+ parts .append (self ._unknown_tool (call .tool_name , run_context ))
1049
1043
1050
1044
with _logfire .span ('running {tools=}' , tools = [t .get_name () for t in tasks ]):
1051
1045
task_results : Sequence [_messages .ModelRequestPart ] = await asyncio .gather (* tasks )
@@ -1057,33 +1051,30 @@ async def _handle_streamed_model_response(
1057
1051
async def _validate_result (
1058
1052
self ,
1059
1053
result_data : ResultData ,
1060
- deps : AgentDeps ,
1054
+ run_context : RunContext [ AgentDeps ] ,
1061
1055
tool_call : _messages .ToolCallPart | None ,
1062
- conv_messages : list [_messages .ModelMessage ],
1063
1056
) -> ResultData :
1064
1057
for validator in self ._result_validators :
1065
- result_data = await validator .validate (
1066
- result_data , deps , self ._current_result_retry , tool_call , conv_messages
1067
- )
1058
+ result_data = await validator .validate (result_data , tool_call , run_context )
1068
1059
return result_data
1069
1060
1070
- def _incr_result_retry (self ) -> None :
1071
- self . _current_result_retry += 1
1072
- if self . _current_result_retry > self ._max_result_retries :
1061
+ def _incr_result_retry (self , run_context : RunContext [ AgentDeps ] ) -> None :
1062
+ run_context . retry += 1
1063
+ if run_context . retry > self ._max_result_retries :
1073
1064
raise exceptions .UnexpectedModelBehavior (
1074
1065
f'Exceeded maximum retries ({ self ._max_result_retries } ) for result validation'
1075
1066
)
1076
1067
1077
- async def _sys_parts (self , deps : AgentDeps ) -> list [_messages .ModelRequestPart ]:
1068
+ async def _sys_parts (self , run_context : RunContext [ AgentDeps ] ) -> list [_messages .ModelRequestPart ]:
1078
1069
"""Build the initial messages for the conversation."""
1079
1070
messages : list [_messages .ModelRequestPart ] = [_messages .SystemPromptPart (p ) for p in self ._system_prompts ]
1080
1071
for sys_prompt_runner in self ._system_prompt_functions :
1081
- prompt = await sys_prompt_runner .run (deps )
1072
+ prompt = await sys_prompt_runner .run (run_context )
1082
1073
messages .append (_messages .SystemPromptPart (prompt ))
1083
1074
return messages
1084
1075
1085
- def _unknown_tool (self , tool_name : str ) -> _messages .RetryPromptPart :
1086
- self ._incr_result_retry ()
1076
+ def _unknown_tool (self , tool_name : str , run_context : RunContext [ AgentDeps ] ) -> _messages .RetryPromptPart :
1077
+ self ._incr_result_retry (run_context )
1087
1078
names = list (self ._function_tools .keys ())
1088
1079
if self ._result_schema :
1089
1080
names .extend (self ._result_schema .tool_names ())
0 commit comments