Skip to content

Commit 9894a26

Browse files
authored
Make use of typing-inspection (#1019)
1 parent 4b05ddb commit 9894a26

File tree

9 files changed

+54
-58
lines changed

9 files changed

+54
-58
lines changed

pydantic_ai_slim/pydantic_ai/_pydantic.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from __future__ import annotations as _annotations
77

88
from inspect import Parameter, signature
9-
from typing import TYPE_CHECKING, Any, Callable, TypedDict, cast, get_origin
9+
from typing import TYPE_CHECKING, Any, Callable, TypedDict, cast
1010

1111
from pydantic import ConfigDict
1212
from pydantic._internal import _decorators, _generate_schema, _typing_extra
@@ -15,6 +15,7 @@
1515
from pydantic.json_schema import GenerateJsonSchema
1616
from pydantic.plugin._schema_validator import create_schema_validator
1717
from pydantic_core import SchemaValidator, core_schema
18+
from typing_extensions import get_origin
1819

1920
from ._griffe import doc_descriptions
2021
from ._utils import check_object_json_schema, is_model_like
@@ -223,8 +224,7 @@ def _build_schema(
223224

224225

225226
def _is_call_ctx(annotation: Any) -> bool:
227+
"""Return whether the annotation is the `RunContext` class, parameterized or not."""
226228
from .tools import RunContext
227229

228-
return annotation is RunContext or (
229-
_typing_extra.is_generic_alias(annotation) and get_origin(annotation) is RunContext
230-
)
230+
return annotation is RunContext or get_origin(annotation) is RunContext

pydantic_ai_slim/pydantic_ai/_result.py

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
from __future__ import annotations as _annotations
22

33
import inspect
4-
import sys
5-
import types
64
from collections.abc import Awaitable, Iterable, Iterator
75
from dataclasses import dataclass, field
8-
from typing import Any, Callable, Generic, Literal, Union, cast, get_args, get_origin
6+
from typing import Any, Callable, Generic, Literal, Union, cast
97

108
from pydantic import TypeAdapter, ValidationError
11-
from typing_extensions import TypeAliasType, TypedDict, TypeVar
9+
from typing_extensions import TypedDict, TypeVar, get_args, get_origin
10+
from typing_inspection import typing_objects
11+
from typing_inspection.introspection import is_union_origin
1212

1313
from . import _utils, messages as _messages
1414
from .exceptions import ModelRetry
@@ -248,23 +248,12 @@ def extract_str_from_union(response_type: Any) -> _utils.Option[Any]:
248248

249249

250250
def get_union_args(tp: Any) -> tuple[Any, ...]:
251-
"""Extract the arguments of a Union type if `response_type` is a union, otherwise return an empty union."""
252-
if isinstance(tp, TypeAliasType):
251+
"""Extract the arguments of a Union type if `response_type` is a union, otherwise return an empty tuple."""
252+
if typing_objects.is_typealiastype(tp):
253253
tp = tp.__value__
254254

255255
origin = get_origin(tp)
256-
if origin_is_union(origin):
256+
if is_union_origin(origin):
257257
return get_args(tp)
258258
else:
259259
return ()
260-
261-
262-
if sys.version_info < (3, 10):
263-
264-
def origin_is_union(tp: type[Any] | None) -> bool:
265-
return tp is Union
266-
267-
else:
268-
269-
def origin_is_union(tp: type[Any] | None) -> bool:
270-
return tp is Union or tp is types.UnionType

pydantic_ai_slim/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ dependencies = [
3939
"pydantic-graph==0.0.32",
4040
"exceptiongroup; python_version < '3.11'",
4141
"opentelemetry-api>=1.28.0",
42+
"typing-inspection>=0.4.0",
4243
]
4344

4445
[project.optional-dependencies]

pydantic_graph/pydantic_graph/_utils.py

Lines changed: 14 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from __future__ import annotations as _annotations
22

33
import asyncio
4-
import sys
54
import types
65
from datetime import datetime, timezone
7-
from typing import Annotated, Any, TypeVar, Union, get_args, get_origin
6+
from typing import Any, TypeVar
87

9-
import typing_extensions
8+
from typing_extensions import TypeIs, get_args, get_origin
9+
from typing_inspection import typing_objects
10+
from typing_inspection.introspection import is_union_origin
1011

1112

1213
def get_event_loop():
@@ -19,13 +20,13 @@ def get_event_loop():
1920

2021

2122
def get_union_args(tp: Any) -> tuple[Any, ...]:
22-
"""Extract the arguments of a Union type if `response_type` is a union, otherwise return the original type."""
23+
"""Extract the arguments of a Union type if `response_type` is a union, otherwise return an empty tuple."""
2324
# similar to `pydantic_ai_slim/pydantic_ai/_result.py:get_union_args`
24-
if isinstance(tp, typing_extensions.TypeAliasType):
25+
if typing_objects.is_typealiastype(tp):
2526
tp = tp.__value__
2627

2728
origin = get_origin(tp)
28-
if origin_is_union(origin):
29+
if is_union_origin(origin):
2930
return get_args(tp)
3031
else:
3132
return (tp,)
@@ -38,35 +39,13 @@ def unpack_annotated(tp: Any) -> tuple[Any, list[Any]]:
3839
`(tp argument, ())` if not annotated, otherwise `(stripped type, annotations)`.
3940
"""
4041
origin = get_origin(tp)
41-
if origin is Annotated or origin is typing_extensions.Annotated:
42+
if typing_objects.is_annotated(origin):
4243
inner_tp, *args = get_args(tp)
4344
return inner_tp, args
4445
else:
4546
return tp, []
4647

4748

48-
def is_never(tp: Any) -> bool:
49-
"""Check if a type is `Never`."""
50-
if tp is typing_extensions.Never:
51-
return True
52-
elif typing_never := getattr(typing_extensions, 'Never', None):
53-
return tp is typing_never
54-
else:
55-
return False
56-
57-
58-
# same as `pydantic_ai_slim/pydantic_ai/_result.py:origin_is_union`
59-
if sys.version_info < (3, 10):
60-
61-
def origin_is_union(tp: type[Any] | None) -> bool:
62-
return tp is Union
63-
64-
else:
65-
66-
def origin_is_union(tp: type[Any] | None) -> bool:
67-
return tp is Union or tp is types.UnionType
68-
69-
7049
def comma_and(items: list[str]) -> str:
7150
"""Join with a comma and 'and' for the last item."""
7251
if len(items) == 1:
@@ -84,7 +63,11 @@ def get_parent_namespace(frame: types.FrameType | None) -> dict[str, Any] | None
8463
"""
8564
if frame is not None:
8665
if back := frame.f_back:
87-
if back.f_code.co_filename.endswith('/typing.py'):
66+
if back.f_globals.get('__name__') == 'typing':
67+
# If the class calling this function is generic, explicitly parameterizing the class
68+
# results in a `typing._GenericAlias` instance, which proxies instantiation calls to the
69+
# "real" class and thus adding an extra frame to the call. To avoid pulling anything
70+
# from the `typing` module, use the correct frame (the one before):
8871
return get_parent_namespace(back)
8972
else:
9073
return back.f_locals
@@ -107,5 +90,5 @@ class Unset:
10790
T = TypeVar('T')
10891

10992

110-
def is_set(t_or_unset: T | Unset) -> typing_extensions.TypeGuard[T]:
93+
def is_set(t_or_unset: T | Unset) -> TypeIs[T]:
11194
return t_or_unset is not UNSET

pydantic_graph/pydantic_graph/graph.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import pydantic
1414
import typing_extensions
1515
from logfire_api import LogfireSpan
16+
from typing_inspection import typing_objects
1617

1718
from . import _utils, exceptions, mermaid
1819
from .nodes import BaseNode, DepsT, End, GraphRunContext, NodeDef, RunEndT
@@ -505,7 +506,7 @@ def _get_run_end_type(self) -> type[RunEndT]:
505506
args = typing_extensions.get_args(base)
506507
if len(args) == 3:
507508
t = args[2]
508-
if not _utils.is_never(t):
509+
if not typing_objects.is_never(t):
509510
return t
510511
# break the inner (bases) loop
511512
break

pydantic_graph/pydantic_graph/nodes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
from abc import ABC, abstractmethod
44
from dataclasses import dataclass, is_dataclass
55
from functools import cache
6-
from typing import TYPE_CHECKING, Any, ClassVar, Generic, get_origin, get_type_hints
6+
from typing import TYPE_CHECKING, Any, ClassVar, Generic, get_type_hints
77

8-
from typing_extensions import Never, TypeVar
8+
from typing_extensions import Never, TypeVar, get_origin
99

1010
from . import _utils, exceptions
1111

pydantic_graph/pyproject.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,12 @@ classifiers = [
3131
"Topic :: Internet",
3232
]
3333
requires-python = ">=3.9"
34-
dependencies = ["httpx>=0.27", "logfire-api>=1.2.0", "pydantic>=2.10"]
34+
dependencies = [
35+
"httpx>=0.27",
36+
"logfire-api>=1.2.0",
37+
"pydantic>=2.10",
38+
"typing-inspection>=0.4.0",
39+
]
3540

3641
[project.urls]
3742
Homepage = "https://ai.pydantic.dev/graph/tree/main/pydantic_graph"

tests/models/test_model_names.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from collections.abc import Iterator
2-
from typing import Any, get_args
2+
from typing import Any
33

44
import pytest
5+
from typing_extensions import get_args
56

67
from pydantic_ai.models import KnownModelName
78

uv.lock

Lines changed: 16 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)