Skip to content

Commit 4603d5f

Browse files
committed
Unwrap Annotated in get_union_args util
1 parent 8eba30b commit 4603d5f

File tree

3 files changed

+17
-10
lines changed

3 files changed

+17
-10
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -864,7 +864,7 @@ async def _call_tools(
864864
output_deferred_calls[deferred_calls_by_index[k]].append(tool_calls[k])
865865

866866

867-
async def _call_tool( # noqa: C901
867+
async def _call_tool(
868868
tool_manager: ToolManager[DepsT],
869869
tool_call: _messages.ToolCallPart,
870870
tool_call_result: DeferredToolResult | None,
@@ -893,10 +893,6 @@ async def _call_tool( # noqa: C901
893893
tool_call_result.tool_name = tool_call.tool_name
894894
tool_call_result.tool_call_id = tool_call.tool_call_id
895895
raise ToolRetryError(tool_call_result)
896-
elif isinstance(tool_call_result, _messages.ToolReturnPart):
897-
tool_call_result.tool_name = tool_call.tool_name
898-
tool_call_result.tool_call_id = tool_call.tool_call_id
899-
return tool_call_result, None
900896
else:
901897
tool_result = tool_call_result
902898
except ToolRetryError as e:

pydantic_ai_slim/pydantic_ai/_utils.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -459,14 +459,23 @@ def strip_markdown_fences(text: str) -> str:
459459
return text
460460

461461

462+
def _unwrap_annotated(tp: Any) -> Any:
463+
origin = get_origin(tp)
464+
while typing_objects.is_annotated(origin):
465+
tp = tp.__origin__
466+
origin = get_origin(tp)
467+
return tp
468+
469+
462470
def get_union_args(tp: Any) -> tuple[Any, ...]:
463471
"""Extract the arguments of a Union type if `tp` is a union, otherwise return an empty tuple."""
464472
if typing_objects.is_typealiastype(tp):
465473
tp = tp.__value__
466474

475+
tp = _unwrap_annotated(tp)
467476
origin = get_origin(tp)
468477
if is_union_origin(origin):
469-
return get_args(tp)
478+
return tuple(_unwrap_annotated(arg) for arg in get_args(tp))
470479
else:
471480
return ()
472481

pydantic_ai_slim/pydantic_ai/agent/__init__.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -628,10 +628,12 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
628628
approval = ToolDenied()
629629
tool_call_results[tool_call_id] = approval
630630

631-
for tool_call_id, result in deferred_tool_results.calls.items():
632-
if not isinstance(result, _utils.get_union_args(DeferredToolCallResult)):
633-
result = _messages.ToolReturn(result)
634-
tool_call_results[tool_call_id] = result
631+
if calls := deferred_tool_results.calls:
632+
call_result_types = _utils.get_union_args(DeferredToolCallResult)
633+
for tool_call_id, result in calls.items():
634+
if not isinstance(result, call_result_types):
635+
result = _messages.ToolReturn(result)
636+
tool_call_results[tool_call_id] = result
635637

636638
graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT](
637639
user_deps=deps,

0 commit comments

Comments
 (0)