@@ -454,8 +454,7 @@ async def _handle_tool_calls(
454
454
final_result : result .FinalResult [NodeRunEndT ] | None = None
455
455
parts : list [_messages .ModelRequestPart ] = []
456
456
if result_schema is not None :
457
- if match := result_schema .find_tool (tool_calls ):
458
- call , result_tool = match
457
+ for call , result_tool in result_schema .find_tool (tool_calls ):
459
458
try :
460
459
result_data = result_tool .validate (call )
461
460
result_data = await _validate_result (result_data , ctx , call )
@@ -465,12 +464,17 @@ async def _handle_tool_calls(
465
464
ctx .state .increment_retries (ctx .deps .max_result_retries )
466
465
parts .append (e .tool_retry )
467
466
else :
468
- final_result = result .FinalResult (result_data , call .tool_name )
467
+ final_result = result .FinalResult (result_data , call .tool_name , call .tool_call_id )
468
+ break
469
469
470
470
# Then build the other request parts based on end strategy
471
471
tool_responses : list [_messages .ModelRequestPart ] = self ._tool_responses
472
472
async for event in process_function_tools (
473
- tool_calls , final_result and final_result .tool_name , ctx , tool_responses
473
+ tool_calls ,
474
+ final_result and final_result .tool_name ,
475
+ final_result and final_result .tool_call_id ,
476
+ ctx ,
477
+ tool_responses ,
474
478
):
475
479
yield event
476
480
@@ -518,7 +522,7 @@ async def _handle_text_response(
518
522
return ModelRequestNode [DepsT , NodeRunEndT ](_messages .ModelRequest (parts = [e .tool_retry ]))
519
523
else :
520
524
# The following cast is safe because we know `str` is an allowed result type
521
- return self ._handle_final_result (ctx , result .FinalResult (result_data , tool_name = None ), [])
525
+ return self ._handle_final_result (ctx , result .FinalResult (result_data , None , None ), [])
522
526
else :
523
527
ctx .state .increment_retries (ctx .deps .max_result_retries )
524
528
return ModelRequestNode [DepsT , NodeRunEndT ](
@@ -547,6 +551,7 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
547
551
async def process_function_tools (
548
552
tool_calls : list [_messages .ToolCallPart ],
549
553
result_tool_name : str | None ,
554
+ result_tool_call_id : str | None ,
550
555
ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , NodeRunEndT ]],
551
556
output_parts : list [_messages .ModelRequestPart ],
552
557
) -> AsyncIterator [_messages .HandleResponseEvent ]:
@@ -566,7 +571,11 @@ async def process_function_tools(
566
571
calls_to_run : list [tuple [Tool [DepsT ], _messages .ToolCallPart ]] = []
567
572
call_index_to_event_id : dict [int , str ] = {}
568
573
for call in tool_calls :
569
- if call .tool_name == result_tool_name and not found_used_result_tool :
574
+ if (
575
+ call .tool_name == result_tool_name
576
+ and call .tool_call_id == result_tool_call_id
577
+ and not found_used_result_tool
578
+ ):
570
579
found_used_result_tool = True
571
580
output_parts .append (
572
581
_messages .ToolReturnPart (
@@ -593,9 +602,14 @@ async def process_function_tools(
593
602
# if tool_name is in _result_schema, it means we found a result tool but an error occurred in
594
603
# validation, we don't add another part here
595
604
if result_tool_name is not None :
605
+ if found_used_result_tool :
606
+ content = 'Result tool not used - a final result was already processed.'
607
+ else :
608
+ # TODO: Include information about the validation failure, and/or merge this with the ModelRetry part
609
+ content = 'Result tool not used - result failed validation.'
596
610
part = _messages .ToolReturnPart (
597
611
tool_name = call .tool_name ,
598
- content = 'Result tool not used - a final result was already processed.' ,
612
+ content = content ,
599
613
tool_call_id = call .tool_call_id ,
600
614
)
601
615
output_parts .append (part )
0 commit comments