Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions src/mcp/server/fastmcp/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import inspect
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, get_type_hints

from pydantic import BaseModel, Field

Expand All @@ -15,6 +15,18 @@
from mcp.shared.context import LifespanContextT


def _is_context_type(annotation: type[Any]) -> bool:
from mcp.server.fastmcp import Context

if annotation is Context:
return True
if (
generic_metadata := getattr(annotation, "__pydantic_generic_metadata__", None)
) is not None:
return _is_context_type(generic_metadata["origin"])
return False


class Tool(BaseModel):
"""Internal tool registration info."""

Expand All @@ -40,8 +52,6 @@ def from_function(
context_kwarg: str | None = None,
) -> Tool:
"""Create a Tool from a function."""
from mcp.server.fastmcp import Context

func_name = name or fn.__name__

if func_name == "<lambda>":
Expand All @@ -51,9 +61,11 @@ def from_function(
is_async = inspect.iscoroutinefunction(fn)

if context_kwarg is None:
sig = inspect.signature(fn)
for param_name, param in sig.parameters.items():
if param.annotation is Context:
type_hints = get_type_hints(fn)
for param_name, param_type in type_hints.items():
if param_name == "return":
continue
if _is_context_type(param_type):
context_kwarg = param_name
break

Expand Down
27 changes: 27 additions & 0 deletions tests/server/fastmcp/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

if TYPE_CHECKING:
from mcp.server.fastmcp import Context
from mcp.server.session import ServerSession


class TestServer:
Expand Down Expand Up @@ -480,6 +481,32 @@ def tool_with_context(x: int, ctx: Context) -> str:
tool = mcp._tool_manager.add_tool(tool_with_context)
assert tool.context_kwarg == "ctx"

@pytest.mark.anyio
async def test_context_detection_forward_ref(self):
"""
Test that context parameters are properly detected with forward references.
"""
mcp = FastMCP()

def tool_with_context(x: int, ctx: "Context") -> str:
return f"Request {ctx.request_id}: {x}"

tool = mcp._tool_manager.add_tool(tool_with_context)
assert tool.context_kwarg == "ctx"

@pytest.mark.anyio
async def test_context_detection_generic_alias(self):
"""Test that context parameters are properly detected with generic alias."""
mcp = FastMCP()

class AppContext: ...

def tool_with_context(x: int, ctx: Context["ServerSession", AppContext]) -> str:
return f"Request {ctx.request_id}: {x}"

tool = mcp._tool_manager.add_tool(tool_with_context)
assert tool.context_kwarg == "ctx"

@pytest.mark.anyio
async def test_context_injection(self):
"""Test that context is properly injected into tool calls."""
Expand Down