Skip to content

Commit 66290b6

Browse files
committed
fix: wrapped annotations handling in func_metadata
Use the original function's __globals__ for type hint resolution when dealing with wrapped functions. This ensures that any type hints defined in the original function's module are correctly resolved. This also includes adding common typing names for resiliency. Fixes #1391
1 parent c0f1657 commit 66290b6

File tree

4 files changed

+113
-3
lines changed

4 files changed

+113
-3
lines changed

src/mcp/server/fastmcp/utilities/func_metadata.py

Lines changed: 63 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
import functools
12
import inspect
23
import json
4+
import sys
5+
import typing as t
36
from collections.abc import Awaitable, Callable, Sequence
47
from itertools import chain
5-
from types import GenericAlias
8+
from types import GenericAlias, MethodType, ModuleType
69
from typing import Annotated, Any, ForwardRef, cast, get_args, get_origin, get_type_hints
710

811
import pydantic_core
@@ -468,10 +471,67 @@ def try_eval_type(value: Any, globalns: dict[str, Any], localns: dict[str, Any])
468471
return annotation
469472

470473

474+
def _resolve_callable_and_globalns(
475+
callable_obj: Callable[..., object] | functools.partial[Any] | MethodType,
476+
) -> tuple[Callable[..., object], dict[str, object]]:
477+
"""Unwrap a possibly-decorated/partial/method callable.
478+
479+
Returns the original function and robust global namespace for type-hint evaluation.
480+
481+
Args:
482+
callable_obj: A function, bound method, or functools.partial.
483+
484+
Returns:
485+
(unwrapped_callable, globalns)
486+
"""
487+
# Handle functools.partial
488+
base: object = callable_obj.func if isinstance(callable_obj, functools.partial) else callable_obj
489+
490+
# Follow __wrapped__ chain (requires @functools.wraps in decorators)
491+
unwrapped_obj: object = inspect.unwrap(cast(Callable[..., object], base))
492+
493+
# Handle bound methods
494+
if inspect.ismethod(unwrapped_obj):
495+
unwrapped_callable: Callable[..., object] = cast(MethodType, unwrapped_obj).__func__ # type: ignore[assignment]
496+
else:
497+
unwrapped_callable = cast(Callable[..., object], unwrapped_obj)
498+
499+
# Build globalns from module + function’s own globals
500+
globalns: dict[str, object] = {}
501+
module_name: str | None = getattr(unwrapped_callable, "__module__", None) # type: ignore[attr-defined]
502+
module_obj: ModuleType | None = sys.modules.get(module_name) if isinstance(module_name, str) else None
503+
if module_obj is not None:
504+
globalns.update(vars(module_obj))
505+
506+
func_globals: dict[str, object] | None = getattr(unwrapped_callable, "__globals__", None) # type: ignore[attr-defined]
507+
if isinstance(func_globals, dict):
508+
globalns.update(func_globals)
509+
510+
# Seed common typing names for resilience
511+
globalns.setdefault("typing", t)
512+
for name in (
513+
"Literal",
514+
"Annotated",
515+
"Optional",
516+
"Union",
517+
"Tuple",
518+
"Dict",
519+
"List",
520+
"Set",
521+
"Type",
522+
"Callable",
523+
):
524+
if name not in globalns and hasattr(t, name):
525+
globalns[name] = getattr(t, name) # type: ignore[index]
526+
527+
return unwrapped_callable, globalns
528+
529+
471530
def _get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
472531
"""Get function signature while evaluating forward references"""
473-
signature = inspect.signature(call)
474-
globalns = getattr(call, "__globals__", {})
532+
fn, globalns = _resolve_callable_and_globalns(call)
533+
signature = inspect.signature(fn)
534+
475535
typed_params = [
476536
inspect.Parameter(
477537
name=param.name,

tests/server/fastmcp/test_func_metadata.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
from mcp.server.fastmcp.utilities.func_metadata import func_metadata
1616

17+
from .test_wrapped import wrapped_function
18+
1719

1820
class SomeInputModelA(BaseModel):
1921
pass
@@ -1094,3 +1096,20 @@ def func_with_reserved_json(
10941096
assert result["json"] == {"nested": "data"}
10951097
assert result["model_dump"] == [1, 2, 3]
10961098
assert result["normal"] == "plain string"
1099+
1100+
1101+
@pytest.mark.anyio
1102+
async def test_wrapped_annotations_func() -> None:
1103+
"""Test that func_metadata works with wrapped annotations functions."""
1104+
meta = func_metadata(wrapped_function)
1105+
1106+
result = await meta.call_fn_with_arg_validation(
1107+
wrapped_function,
1108+
fn_is_async=False,
1109+
arguments_to_validate={
1110+
"literal": "test",
1111+
},
1112+
arguments_to_pass_directly=None,
1113+
)
1114+
1115+
assert result == "test"
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from collections.abc import Callable
2+
from functools import wraps
3+
from typing import TypeVar
4+
5+
from typing_extensions import ParamSpec
6+
7+
P = ParamSpec("P")
8+
R = TypeVar("R")
9+
10+
11+
def instrument(func: Callable[P, R]) -> Callable[P, R]:
12+
"""
13+
Example decorator that logs before/after the call
14+
while preserving the original function's type signature.
15+
"""
16+
17+
@wraps(func)
18+
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
19+
return func(*args, **kwargs)
20+
21+
return wrapper
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from __future__ import annotations
2+
3+
from typing import Literal
4+
5+
from .test_instrument import instrument
6+
7+
8+
@instrument
9+
def wrapped_function(literal: Literal["test"] | None = None) -> Literal["test"] | None:
10+
return literal

0 commit comments

Comments
 (0)