diff --git a/src/mcp/server/fastmcp/tools/base.py b/src/mcp/server/fastmcp/tools/base.py index e137e8456..92a216f56 100644 --- a/src/mcp/server/fastmcp/tools/base.py +++ b/src/mcp/server/fastmcp/tools/base.py @@ -2,7 +2,7 @@ import inspect from collections.abc import Callable -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, get_origin from pydantic import BaseModel, Field @@ -53,7 +53,9 @@ def from_function( if context_kwarg is None: sig = inspect.signature(fn) for param_name, param in sig.parameters.items(): - if param.annotation is Context: + if get_origin(param.annotation) is not None: + continue + if issubclass(param.annotation, Context): context_kwarg = param_name break diff --git a/tests/server/fastmcp/test_tool_manager.py b/tests/server/fastmcp/test_tool_manager.py index d2067583e..8f52e3d85 100644 --- a/tests/server/fastmcp/test_tool_manager.py +++ b/tests/server/fastmcp/test_tool_manager.py @@ -4,8 +4,11 @@ import pytest from pydantic import BaseModel +from mcp.server.fastmcp import Context, FastMCP from mcp.server.fastmcp.exceptions import ToolError from mcp.server.fastmcp.tools import ToolManager +from mcp.server.session import ServerSessionT +from mcp.shared.context import LifespanContextT class TestAddTools: @@ -194,8 +197,6 @@ def concat_strs(vals: list[str] | str) -> str: @pytest.mark.anyio async def test_call_tool_with_complex_model(self): - from mcp.server.fastmcp import Context - class MyShrimpTank(BaseModel): class Shrimp(BaseModel): name: str @@ -223,8 +224,6 @@ def name_shrimp(tank: MyShrimpTank, ctx: Context) -> list[str]: class TestToolSchema: @pytest.mark.anyio async def test_context_arg_excluded_from_schema(self): - from mcp.server.fastmcp import Context - def something(a: int, ctx: Context) -> int: return a @@ -241,7 +240,6 @@ class TestContextHandling: def test_context_parameter_detection(self): """Test that context parameters are properly detected in Tool.from_function().""" - from mcp.server.fastmcp import Context def tool_with_context(x: int, ctx: Context) -> str: return str(x) @@ -256,10 +254,17 @@ def tool_without_context(x: int) -> str: tool = manager.add_tool(tool_without_context) assert tool.context_kwarg is None + def tool_with_parametrized_context( + x: int, ctx: Context[ServerSessionT, LifespanContextT] + ) -> str: + return str(x) + + tool = manager.add_tool(tool_with_parametrized_context) + assert tool.context_kwarg == "ctx" + @pytest.mark.anyio async def test_context_injection(self): """Test that context is properly injected during tool execution.""" - from mcp.server.fastmcp import Context, FastMCP def tool_with_context(x: int, ctx: Context) -> str: assert isinstance(ctx, Context) @@ -276,7 +281,6 @@ def tool_with_context(x: int, ctx: Context) -> str: @pytest.mark.anyio async def test_context_injection_async(self): """Test that context is properly injected in async tools.""" - from mcp.server.fastmcp import Context, FastMCP async def async_tool(x: int, ctx: Context) -> str: assert isinstance(ctx, Context) @@ -293,7 +297,6 @@ async def async_tool(x: int, ctx: Context) -> str: @pytest.mark.anyio async def test_context_optional(self): """Test that context is optional when calling tools.""" - from mcp.server.fastmcp import Context def tool_with_context(x: int, ctx: Context | None = None) -> str: return str(x) @@ -307,7 +310,6 @@ def tool_with_context(x: int, ctx: Context | None = None) -> str: @pytest.mark.anyio async def test_context_error_handling(self): """Test error handling when context injection fails.""" - from mcp.server.fastmcp import Context, FastMCP def tool_with_context(x: int, ctx: Context) -> str: raise ValueError("Test error")