diff --git a/src/mcp/server/fastmcp/tools/base.py b/src/mcp/server/fastmcp/tools/base.py index 21eb1841d..01fedcdc9 100644 --- a/src/mcp/server/fastmcp/tools/base.py +++ b/src/mcp/server/fastmcp/tools/base.py @@ -1,5 +1,6 @@ from __future__ import annotations as _annotations +import functools import inspect from collections.abc import Callable from typing import TYPE_CHECKING, Any, get_origin @@ -53,7 +54,7 @@ def from_function( raise ValueError("You must provide a name for lambda functions") func_doc = description or fn.__doc__ or "" - is_async = inspect.iscoroutinefunction(fn) + is_async = _is_async_callable(fn) if context_kwarg is None: sig = inspect.signature(fn) @@ -98,3 +99,12 @@ async def run( ) except Exception as e: raise ToolError(f"Error executing tool {self.name}: {e}") from e + + +def _is_async_callable(obj: Any) -> bool: + while isinstance(obj, functools.partial): + obj = obj.func + + return inspect.iscoroutinefunction(obj) or ( + callable(obj) and inspect.iscoroutinefunction(getattr(obj, "__call__", None)) + ) diff --git a/tests/server/fastmcp/test_tool_manager.py b/tests/server/fastmcp/test_tool_manager.py index 015974eb0..203a7172b 100644 --- a/tests/server/fastmcp/test_tool_manager.py +++ b/tests/server/fastmcp/test_tool_manager.py @@ -102,6 +102,39 @@ def create_user(user: UserInput, flag: bool) -> dict: assert "age" in tool.parameters["$defs"]["UserInput"]["properties"] assert "flag" in tool.parameters["properties"] + def test_add_callable_object(self): + """Test registering a callable object.""" + + class MyTool: + def __init__(self): + self.__name__ = "MyTool" + + def __call__(self, x: int) -> int: + return x * 2 + + manager = ToolManager() + tool = manager.add_tool(MyTool()) + assert tool.name == "MyTool" + assert tool.is_async is False + assert tool.parameters["properties"]["x"]["type"] == "integer" + + @pytest.mark.anyio + async def test_add_async_callable_object(self): + """Test registering an async callable object.""" + + class MyAsyncTool: + def __init__(self): + self.__name__ = "MyAsyncTool" + + async def __call__(self, x: int) -> int: + return x * 2 + + manager = ToolManager() + tool = manager.add_tool(MyAsyncTool()) + assert tool.name == "MyAsyncTool" + assert tool.is_async is True + assert tool.parameters["properties"]["x"]["type"] == "integer" + def test_add_invalid_tool(self): manager = ToolManager() with pytest.raises(AttributeError): @@ -168,6 +201,34 @@ async def double(n: int) -> int: result = await manager.call_tool("double", {"n": 5}) assert result == 10 + @pytest.mark.anyio + async def test_call_object_tool(self): + class MyTool: + def __init__(self): + self.__name__ = "MyTool" + + def __call__(self, x: int) -> int: + return x * 2 + + manager = ToolManager() + tool = manager.add_tool(MyTool()) + result = await tool.run({"x": 5}) + assert result == 10 + + @pytest.mark.anyio + async def test_call_async_object_tool(self): + class MyAsyncTool: + def __init__(self): + self.__name__ = "MyAsyncTool" + + async def __call__(self, x: int) -> int: + return x * 2 + + manager = ToolManager() + tool = manager.add_tool(MyAsyncTool()) + result = await tool.run({"x": 5}) + assert result == 10 + @pytest.mark.anyio async def test_call_tool_with_default_args(self): def add(a: int, b: int = 1) -> int: