@@ -562,6 +562,8 @@ async def on_complete():
562
562
parts = await self ._process_function_tools (
563
563
tool_calls , result_tool_name , run_context , result_schema
564
564
)
565
+ if any (isinstance (part , _messages .RetryPromptPart ) for part in parts ):
566
+ self ._incr_result_retry (run_context )
565
567
if parts :
566
568
messages .append (_messages .ModelRequest (parts ))
567
569
run_span .set_attribute ('all_messages' , messages )
@@ -1147,7 +1149,6 @@ async def _handle_structured_response(
1147
1149
result_data = result_tool .validate (call )
1148
1150
result_data = await self ._validate_result (result_data , run_context , call )
1149
1151
except _result .ToolRetryError as e :
1150
- self ._incr_result_retry (run_context )
1151
1152
parts .append (e .tool_retry )
1152
1153
else :
1153
1154
final_result = _MarkFinalResult (result_data , call .tool_name )
@@ -1157,6 +1158,9 @@ async def _handle_structured_response(
1157
1158
tool_calls , final_result and final_result .tool_name , run_context , result_schema
1158
1159
)
1159
1160
1161
+ if any (isinstance (part , _messages .RetryPromptPart ) for part in parts ):
1162
+ self ._incr_result_retry (run_context )
1163
+
1160
1164
return final_result , parts
1161
1165
1162
1166
async def _process_function_tools (
@@ -1210,7 +1214,7 @@ async def _process_function_tools(
1210
1214
)
1211
1215
)
1212
1216
else :
1213
- parts .append (self ._unknown_tool (call .tool_name , run_context , result_schema ))
1217
+ parts .append (self ._unknown_tool (call .tool_name , result_schema ))
1214
1218
1215
1219
# Run all tool tasks in parallel
1216
1220
if tasks :
@@ -1257,7 +1261,7 @@ async def _handle_streamed_response(
1257
1261
if tool := self ._function_tools .get (p .tool_name ):
1258
1262
tasks .append (asyncio .create_task (tool .run (p , run_context ), name = p .tool_name ))
1259
1263
else :
1260
- parts .append (self ._unknown_tool (p .tool_name , run_context , result_schema ))
1264
+ parts .append (self ._unknown_tool (p .tool_name , result_schema ))
1261
1265
1262
1266
if received_text and not tasks and not parts :
1263
1267
# Can only get here if self._allow_text_result returns `False` for the provided result_schema
@@ -1270,6 +1274,10 @@ async def _handle_streamed_response(
1270
1274
with _logfire .span ('running {tools=}' , tools = [t .get_name () for t in tasks ]):
1271
1275
task_results : Sequence [_messages .ModelRequestPart ] = await asyncio .gather (* tasks )
1272
1276
parts .extend (task_results )
1277
+
1278
+ if any (isinstance (part , _messages .RetryPromptPart ) for part in parts ):
1279
+ self ._incr_result_retry (run_context )
1280
+
1273
1281
return model_response , parts
1274
1282
1275
1283
async def _validate_result (
@@ -1307,10 +1315,8 @@ async def _sys_parts(self, run_context: RunContext[AgentDepsT]) -> list[_message
1307
1315
def _unknown_tool (
1308
1316
self ,
1309
1317
tool_name : str ,
1310
- run_context : RunContext [AgentDepsT ],
1311
1318
result_schema : _result .ResultSchema [RunResultData ] | None ,
1312
1319
) -> _messages .RetryPromptPart :
1313
- self ._incr_result_retry (run_context )
1314
1320
names = list (self ._function_tools .keys ())
1315
1321
if result_schema :
1316
1322
names .extend (result_schema .tool_names ())
0 commit comments