Skip to content

Commit f07f7f3

Browse files
authored
support TypeAliasType unions (#26)
1 parent 0b626f8 commit f07f7f3

File tree

6 files changed

+131
-39
lines changed

6 files changed

+131
-39
lines changed

pydantic_ai/_pydantic.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,7 @@
2626
from .shared import AgentDeps
2727

2828

29-
__all__ = 'function_schema', 'LazyTypeAdapter', 'is_union'
30-
31-
32-
def is_union(tp: Any) -> bool:
33-
origin = get_origin(tp)
34-
return _typing_extra.origin_is_union(origin)
29+
__all__ = 'function_schema', 'LazyTypeAdapter'
3530

3631

3732
class FunctionSchema(TypedDict):

pydantic_ai/_result.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
from __future__ import annotations as _annotations
22

33
import inspect
4+
import sys
5+
import types
46
from collections.abc import Awaitable
57
from dataclasses import dataclass
6-
from typing import Any, Callable, Generic, Union, cast, get_args
8+
from typing import Any, Callable, Generic, Union, cast, get_args, get_origin
79

810
from pydantic import TypeAdapter, ValidationError
9-
from typing_extensions import Self, TypedDict
11+
from typing_extensions import Self, TypeAliasType, TypedDict
1012

11-
from . import _pydantic, _utils, messages
13+
from . import _utils, messages
1214
from .messages import LLMToolCalls, ToolCall
1315
from .shared import AgentDeps, CallContext, ModelRetry, ResultData
1416

@@ -106,7 +108,7 @@ def _build_tool(a: Any, tool_name_: str, multiple: bool) -> ResultTool[ResultDat
106108
)
107109

108110
tools: dict[str, ResultTool[ResultData]] = {}
109-
if args := union_args(response_type):
111+
if args := get_union_args(response_type):
110112
for arg in args:
111113
tool_name = union_tool_name(name, arg)
112114
tools[tool_name] = _build_tool(arg, tool_name, True)
@@ -204,10 +206,11 @@ def union_arg_name(union_arg: Any) -> str:
204206

205207
def extract_str_from_union(response_type: Any) -> _utils.Option[Any]:
206208
"""Extract the string type from a Union, return the remaining union or remaining type."""
207-
if _pydantic.is_union(response_type) and any(t is str for t in get_args(response_type)):
209+
union_args = get_union_args(response_type)
210+
if any(t is str for t in union_args):
208211
remain_args: list[Any] = []
209212
includes_str = False
210-
for arg in get_args(response_type):
213+
for arg in union_args:
211214
if arg is str:
212215
includes_str = True
213216
else:
@@ -219,9 +222,24 @@ def extract_str_from_union(response_type: Any) -> _utils.Option[Any]:
219222
return _utils.Some(Union[tuple(remain_args)])
220223

221224

222-
def union_args(response_type: Any) -> tuple[Any, ...]:
225+
def get_union_args(tp: Any) -> tuple[Any, ...]:
223226
"""Extract the arguments of a Union type if `response_type` is a union, otherwise return an empty union."""
224-
if _pydantic.is_union(response_type):
225-
return get_args(response_type)
227+
if isinstance(tp, TypeAliasType):
228+
tp = tp.__value__
229+
230+
origin = get_origin(tp)
231+
if origin_is_union(origin):
232+
return get_args(tp)
226233
else:
227234
return ()
235+
236+
237+
if sys.version_info < (3, 10):
238+
239+
def origin_is_union(tp: type[Any] | None) -> bool:
240+
return tp is Union
241+
242+
else:
243+
244+
def origin_is_union(tp: type[Any] | None) -> bool:
245+
return tp is Union or tp is types.UnionType

pydantic_ai/models/gemini.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -417,47 +417,48 @@ def __init__(self, schema: _utils.ObjectJsonSchema):
417417
self.defs = self.schema.pop('$defs', {})
418418

419419
def simplify(self) -> dict[str, Any]:
420-
self._simplify(self.schema, allow_ref=True)
420+
self._simplify(self.schema, refs_stack=())
421421
return self.schema
422422

423-
def _simplify(self, schema: dict[str, Any], allow_ref: bool) -> None:
423+
def _simplify(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
424424
schema.pop('title', None)
425425
schema.pop('default', None)
426426
if ref := schema.pop('$ref', None):
427-
if not allow_ref:
428-
raise shared.UserError('Recursive `$ref`s in JSON Schema are not supported by Gemini')
429427
# noinspection PyTypeChecker
430428
key = re.sub(r'^#/\$defs/', '', ref)
429+
if key in refs_stack:
430+
raise shared.UserError('Recursive `$ref`s in JSON Schema are not supported by Gemini')
431+
refs_stack += (key,)
431432
schema_def = self.defs[key]
432-
self._simplify(schema_def, allow_ref=False)
433+
self._simplify(schema_def, refs_stack)
433434
schema.update(schema_def)
434435
return
435436

436437
if any_of := schema.get('anyOf'):
437438
for schema in any_of:
438-
self._simplify(schema, allow_ref)
439+
self._simplify(schema, refs_stack)
439440

440441
type_ = schema.get('type')
441442

442443
if type_ == 'object':
443-
self._object(schema, allow_ref)
444+
self._object(schema, refs_stack)
444445
elif type_ == 'array':
445-
return self._array(schema, allow_ref)
446+
return self._array(schema, refs_stack)
446447

447-
def _object(self, schema: dict[str, Any], allow_ref: bool) -> None:
448+
def _object(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
448449
ad_props = schema.pop('additionalProperties', None)
449450
if ad_props:
450451
raise shared.UserError('Additional properties in JSON Schema are not supported by Gemini')
451452

452453
if properties := schema.get('properties'): # pragma: no branch
453454
for value in properties.values():
454-
self._simplify(value, allow_ref)
455+
self._simplify(value, refs_stack)
455456

456-
def _array(self, schema: dict[str, Any], allow_ref: bool) -> None:
457+
def _array(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
457458
if prefix_items := schema.get('prefixItems'):
458459
# TODO I think this not is supported by Gemini, maybe we should raise an error?
459460
for prefix_item in prefix_items:
460-
self._simplify(prefix_item, allow_ref)
461+
self._simplify(prefix_item, refs_stack)
461462

462463
if items_schema := schema.get('items'): # pragma: no branch
463-
self._simplify(items_schema, allow_ref)
464+
self._simplify(items_schema, refs_stack)

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,8 @@ filterwarnings = [
132132

133133
# https://coverage.readthedocs.io/en/latest/config.html#run
134134
[tool.coverage.run]
135+
# required to avoid warnings about files created by create_module fixture
136+
include = ["pydantic_ai/**/*.py", "tests/**/*.py"]
135137
branch = true
136138

137139
# https://coverage.readthedocs.io/en/latest/config.html#report

tests/conftest.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
11
from __future__ import annotations as _annotations
22

3+
import importlib.util
34
import os
5+
import re
6+
import secrets
7+
import sys
48
from datetime import datetime
9+
from pathlib import Path
10+
from types import ModuleType
511
from typing import TYPE_CHECKING, Any, Callable
612

713
import httpx
814
import pytest
15+
from _pytest.assertion.rewrite import AssertionRewritingHook
916
from typing_extensions import TypeAlias
1017

1118
__all__ = 'IsNow', 'TestEnv'
@@ -78,3 +85,50 @@ def create_client(handler: Callable[[httpx.Request], httpx.Response]) -> httpx.A
7885

7986

8087
ClientWithHandler: TypeAlias = Callable[[Callable[[httpx.Request], httpx.Response]], httpx.AsyncClient]
88+
89+
90+
# pyright: reportUnknownMemberType=false, reportUnknownArgumentType=false
91+
@pytest.fixture
92+
def create_module(tmp_path: Path, request: pytest.FixtureRequest) -> Callable[[str], Any]:
93+
"""Taken from `pydantic/tests/conftest.py`, create module object, execute and return it."""
94+
95+
def run(
96+
source_code: str,
97+
rewrite_assertions: bool = True,
98+
module_name_prefix: str | None = None,
99+
) -> ModuleType:
100+
"""Create module object, execute and return it.
101+
102+
Can be used as a decorator of the function from the source code of which the module will be constructed.
103+
104+
Args:
105+
source_code: Python source code of the module
106+
rewrite_assertions: whether to rewrite assertions in module or not
107+
module_name_prefix: string prefix to use in the name of the module, does not affect the name of the file.
108+
109+
"""
110+
111+
# Max path length in Windows is 260. Leaving some buffer here
112+
max_name_len = 240 - len(str(tmp_path))
113+
# Windows does not allow these characters in paths. Linux bans slashes only.
114+
sanitized_name = re.sub('[' + re.escape('<>:"/\\|?*') + ']', '-', request.node.name)[:max_name_len]
115+
module_name = f'{sanitized_name}_{secrets.token_hex(5)}'
116+
path = tmp_path / f'{module_name}.py'
117+
path.write_text(source_code)
118+
filename = str(path)
119+
120+
if module_name_prefix:
121+
module_name = module_name_prefix + module_name
122+
123+
if rewrite_assertions:
124+
loader = AssertionRewritingHook(config=request.config)
125+
loader.mark_rewrite(module_name)
126+
else:
127+
loader = None
128+
129+
spec = importlib.util.spec_from_file_location(module_name, filename, loader=loader)
130+
sys.modules[module_name] = module = importlib.util.module_from_spec(spec) # pyright: ignore[reportArgumentType]
131+
spec.loader.exec_module(module) # pyright: ignore[reportOptionalMemberAccess]
132+
return module
133+
134+
return run

tests/test_agent.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import sys
12
from datetime import timezone
23
from typing import Any, Callable, Union
34

@@ -255,23 +256,44 @@ def validate_result(ctx: CallContext[None], r: Any) -> Any:
255256
)
256257

257258

259+
# pyright: reportUnknownMemberType=false, reportUnknownVariableType=false
260+
@pytest.mark.parametrize(
261+
'union_code',
262+
[
263+
pytest.param('ResultType = Union[Foo, Bar]'),
264+
pytest.param('ResultType = Foo | Bar', marks=pytest.mark.skipif(sys.version_info < (3, 10), reason='3.10+')),
265+
pytest.param(
266+
'ResultType: TypeAlias = Foo | Bar',
267+
marks=pytest.mark.skipif(sys.version_info < (3, 10), reason='Python 3.10+'),
268+
),
269+
pytest.param(
270+
'type ResultType = Foo | Bar', marks=pytest.mark.skipif(sys.version_info < (3, 12), reason='3.12+')
271+
),
272+
],
273+
)
274+
def test_response_multiple_return_tools(create_module: Callable[[str], Any], union_code: str):
275+
module_code = f'''
276+
from pydantic import BaseModel
277+
from typing import Union
278+
from typing_extensions import TypeAlias
279+
280+
class Foo(BaseModel):
281+
a: int
282+
b: str
283+
284+
258285
class Bar(BaseModel):
259286
"""This is a bar model."""
260287
261288
b: str
262289
290+
{union_code}
291+
'''
263292

264-
@pytest.mark.parametrize(
265-
'input_union_callable', [lambda: Union[Foo, Bar], lambda: Foo | Bar], ids=['Union[Foo, Bar]', 'Foo | Bar']
266-
)
267-
def test_response_multiple_return_tools(input_union_callable: Callable[[], Any]):
268-
try:
269-
union = input_union_callable()
270-
except TypeError:
271-
raise pytest.skip('Python version does not support `|` syntax for unions')
293+
mod = create_module(module_code)
272294

273295
m = TestModel()
274-
agent: Agent[None, Union[Foo, Bar]] = Agent(m, result_type=union)
296+
agent = Agent(m, result_type=mod.ResultType)
275297
got_tool_call_name = 'unset'
276298

277299
@agent.result_validator
@@ -281,7 +303,7 @@ def validate_result(ctx: CallContext[None], r: Any) -> Any:
281303
return r
282304

283305
result = agent.run_sync('Hello')
284-
assert result.response == Foo(a=0, b='a')
306+
assert result.response == mod.Foo(a=0, b='a')
285307
assert got_tool_call_name == snapshot('final_result_Foo')
286308

287309
assert m.agent_model_retrievers == snapshot({})
@@ -324,5 +346,5 @@ def validate_result(ctx: CallContext[None], r: Any) -> Any:
324346
)
325347

326348
result = agent.run_sync('Hello', model=TestModel(seed=1))
327-
assert result.response == Bar(b='b')
349+
assert result.response == mod.Bar(b='b')
328350
assert got_tool_call_name == snapshot('final_result_Bar')

0 commit comments

Comments
 (0)