|
| 1 | +import functools |
1 | 2 | import inspect
|
2 | 3 | import json
|
| 4 | +import sys |
| 5 | +import typing as t |
3 | 6 | from collections.abc import Awaitable, Callable, Sequence
|
4 | 7 | from itertools import chain
|
5 |
| -from types import GenericAlias |
| 8 | +from types import GenericAlias, MethodType, ModuleType |
6 | 9 | from typing import Annotated, Any, ForwardRef, cast, get_args, get_origin, get_type_hints
|
7 | 10 |
|
8 | 11 | import pydantic_core
|
@@ -468,10 +471,67 @@ def try_eval_type(value: Any, globalns: dict[str, Any], localns: dict[str, Any])
|
468 | 471 | return annotation
|
469 | 472 |
|
470 | 473 |
|
| 474 | +def _resolve_callable_and_globalns( |
| 475 | + callable_obj: Callable[..., object] | functools.partial[Any] | MethodType, |
| 476 | +) -> tuple[Callable[..., object], dict[str, object]]: |
| 477 | + """Unwrap a possibly-decorated/partial/method callable. |
| 478 | +
|
| 479 | + Returns the original function and robust global namespace for type-hint evaluation. |
| 480 | +
|
| 481 | + Args: |
| 482 | + callable_obj: A function, bound method, or functools.partial. |
| 483 | +
|
| 484 | + Returns: |
| 485 | + (unwrapped_callable, globalns) |
| 486 | + """ |
| 487 | + # Handle functools.partial |
| 488 | + base: object = callable_obj.func if isinstance(callable_obj, functools.partial) else callable_obj |
| 489 | + |
| 490 | + # Follow __wrapped__ chain (requires @functools.wraps in decorators) |
| 491 | + unwrapped_obj: object = inspect.unwrap(cast(Callable[..., object], base)) |
| 492 | + |
| 493 | + # Handle bound methods |
| 494 | + if inspect.ismethod(unwrapped_obj): |
| 495 | + unwrapped_callable: Callable[..., object] = cast(MethodType, unwrapped_obj).__func__ # type: ignore[assignment] |
| 496 | + else: |
| 497 | + unwrapped_callable = cast(Callable[..., object], unwrapped_obj) |
| 498 | + |
| 499 | + # Build globalns from module + function’s own globals |
| 500 | + globalns: dict[str, object] = {} |
| 501 | + module_name: str | None = getattr(unwrapped_callable, "__module__", None) # type: ignore[attr-defined] |
| 502 | + module_obj: ModuleType | None = sys.modules.get(module_name) if isinstance(module_name, str) else None |
| 503 | + if module_obj is not None: |
| 504 | + globalns.update(vars(module_obj)) |
| 505 | + |
| 506 | + func_globals: dict[str, object] | None = getattr(unwrapped_callable, "__globals__", None) # type: ignore[attr-defined] |
| 507 | + if isinstance(func_globals, dict): |
| 508 | + globalns.update(func_globals) |
| 509 | + |
| 510 | + # Seed common typing names for resilience |
| 511 | + globalns.setdefault("typing", t) |
| 512 | + for name in ( |
| 513 | + "Literal", |
| 514 | + "Annotated", |
| 515 | + "Optional", |
| 516 | + "Union", |
| 517 | + "Tuple", |
| 518 | + "Dict", |
| 519 | + "List", |
| 520 | + "Set", |
| 521 | + "Type", |
| 522 | + "Callable", |
| 523 | + ): |
| 524 | + if name not in globalns and hasattr(t, name): |
| 525 | + globalns[name] = getattr(t, name) # type: ignore[index] |
| 526 | + |
| 527 | + return unwrapped_callable, globalns |
| 528 | + |
| 529 | + |
471 | 530 | def _get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
|
472 | 531 | """Get function signature while evaluating forward references"""
|
473 |
| - signature = inspect.signature(call) |
474 |
| - globalns = getattr(call, "__globals__", {}) |
| 532 | + fn, globalns = _resolve_callable_and_globalns(call) |
| 533 | + signature = inspect.signature(fn) |
| 534 | + |
475 | 535 | typed_params = [
|
476 | 536 | inspect.Parameter(
|
477 | 537 | name=param.name,
|
|
0 commit comments