File tree Expand file tree Collapse file tree 3 files changed +17
-10
lines changed
pydantic_ai_slim/pydantic_ai Expand file tree Collapse file tree 3 files changed +17
-10
lines changed Original file line number Diff line number Diff line change @@ -864,7 +864,7 @@ async def _call_tools(
864864 output_deferred_calls [deferred_calls_by_index [k ]].append (tool_calls [k ])
865865
866866
867- async def _call_tool ( # noqa: C901
867+ async def _call_tool (
868868 tool_manager : ToolManager [DepsT ],
869869 tool_call : _messages .ToolCallPart ,
870870 tool_call_result : DeferredToolResult | None ,
@@ -893,10 +893,6 @@ async def _call_tool( # noqa: C901
893893 tool_call_result .tool_name = tool_call .tool_name
894894 tool_call_result .tool_call_id = tool_call .tool_call_id
895895 raise ToolRetryError (tool_call_result )
896- elif isinstance (tool_call_result , _messages .ToolReturnPart ):
897- tool_call_result .tool_name = tool_call .tool_name
898- tool_call_result .tool_call_id = tool_call .tool_call_id
899- return tool_call_result , None
900896 else :
901897 tool_result = tool_call_result
902898 except ToolRetryError as e :
Original file line number Diff line number Diff line change @@ -459,14 +459,23 @@ def strip_markdown_fences(text: str) -> str:
459459 return text
460460
461461
462+ def _unwrap_annotated (tp : Any ) -> Any :
463+ origin = get_origin (tp )
464+ while typing_objects .is_annotated (origin ):
465+ tp = tp .__origin__
466+ origin = get_origin (tp )
467+ return tp
468+
469+
462470def get_union_args (tp : Any ) -> tuple [Any , ...]:
463471 """Extract the arguments of a Union type if `tp` is a union, otherwise return an empty tuple."""
464472 if typing_objects .is_typealiastype (tp ):
465473 tp = tp .__value__
466474
475+ tp = _unwrap_annotated (tp )
467476 origin = get_origin (tp )
468477 if is_union_origin (origin ):
469- return get_args (tp )
478+ return tuple ( _unwrap_annotated ( arg ) for arg in get_args (tp ) )
470479 else :
471480 return ()
472481
Original file line number Diff line number Diff line change @@ -628,10 +628,12 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
628628 approval = ToolDenied ()
629629 tool_call_results [tool_call_id ] = approval
630630
631- for tool_call_id , result in deferred_tool_results .calls .items ():
632- if not isinstance (result , _utils .get_union_args (DeferredToolCallResult )):
633- result = _messages .ToolReturn (result )
634- tool_call_results [tool_call_id ] = result
631+ if calls := deferred_tool_results .calls :
632+ call_result_types = _utils .get_union_args (DeferredToolCallResult )
633+ for tool_call_id , result in calls .items ():
634+ if not isinstance (result , call_result_types ):
635+ result = _messages .ToolReturn (result )
636+ tool_call_results [tool_call_id ] = result
635637
636638 graph_deps = _agent_graph .GraphAgentDeps [AgentDepsT , RunOutputDataT ](
637639 user_deps = deps ,
You can’t perform that action at this time.
0 commit comments