diff --git a/jupyter_server_mcp/mcp_server.py b/jupyter_server_mcp/mcp_server.py index 4f015e1..7b8e070 100644 --- a/jupyter_server_mcp/mcp_server.py +++ b/jupyter_server_mcp/mcp_server.py @@ -1,17 +1,246 @@ """Simple MCP server for registering Python functions as tools.""" +import inspect +import json import logging from collections.abc import Callable -from inspect import iscoroutinefunction -from typing import Any +from functools import wraps +from inspect import iscoroutinefunction, signature +from typing import Any, Union, get_args, get_origin from fastmcp import FastMCP -from traitlets import Bool, Int, Unicode +from traitlets import Int, Unicode from traitlets.config.configurable import LoggingConfigurable logger = logging.getLogger(__name__) +def _is_dict_compatible_annotation(annotation) -> bool: + """Check if an annotation expects dict values that can be JSON-converted.""" + # Direct dict annotation + if annotation is dict: + return True + + # Union types: Optional[dict], Union[dict, None], dict | None + origin = get_origin(annotation) + if origin is Union or ( + hasattr(annotation, "__class__") + and annotation.__class__.__name__ == "UnionType" + ): + args = get_args(annotation) + return dict in args + + # Typed dict annotations: Dict[K, V], dict[str, Any] + return bool(hasattr(annotation, "__origin__") and annotation.__origin__ is dict) + + +def _wrap_with_json_conversion(func: Callable) -> Callable: + """ + Wrapper that automatically converts JSON string arguments to dictionaries. + + This addresses the common issue where MCP clients pass dictionary arguments + as JSON strings instead of structured objects. The wrapper inspects the + function signature and attempts JSON parsing for parameters annotated as + dict types when they are received as strings. + + Additionally, this function modifies the type annotations to accept Union[dict, str] + for dict parameters to allow Pydantic validation to pass. + + This conversion is always applied to all registered tools to ensure compatibility + with various MCP clients that may serialize dict parameters differently. + + Args: + func: The function to wrap + + Returns: + Wrapped function that handles JSON string conversion with modified annotations + """ + sig = signature(func) + + def _should_convert_to_dict(annotation, value): + """Check if a parameter should be converted from JSON string to dict.""" + return isinstance(value, str) and _is_dict_compatible_annotation(annotation) + + def _add_string_to_annotation(annotation): + """Modify annotation to also accept strings for dict types.""" + # Direct dict annotation -> dict | str + if annotation is dict: + return dict | str + + # Union types: add str to existing union + origin = get_origin(annotation) + if origin is Union: + args = get_args(annotation) + if dict in args and str not in args: + return Union[(*tuple(args), str)] + return annotation + + # New Python 3.10+ union syntax: dict | None + if ( + hasattr(annotation, "__class__") + and annotation.__class__.__name__ == "UnionType" + ): + args = get_args(annotation) + if dict in args and str not in args: + # Reconstruct the union with str added + new_args = (*tuple(args), str) + # Create new union type + result = new_args[0] + for arg in new_args[1:]: + result = result | arg + return result + return annotation + + # Typed dict annotations -> annotation | str + if hasattr(annotation, "__origin__") and annotation.__origin__ is dict: + return annotation | str + + return annotation + + # Create new annotations that accept strings for dict parameters + new_annotations = {} + for param_name, param in sig.parameters.items(): + if param.annotation != inspect.Parameter.empty: + new_annotations[param_name] = _add_string_to_annotation(param.annotation) + else: + new_annotations[param_name] = param.annotation + + # Keep the return annotation unchanged + if hasattr(func, "__annotations__") and "return" in func.__annotations__: + new_annotations["return"] = func.__annotations__["return"] + + if iscoroutinefunction(func): + + @wraps(func) + async def async_wrapper(*args, **kwargs): + # Convert keyword arguments that should be dicts but are strings + converted_kwargs = {} + for param_name, param_value in kwargs.items(): + if param_name in sig.parameters: + param = sig.parameters[param_name] + if _should_convert_to_dict(param.annotation, param_value): + try: + converted_kwargs[param_name] = json.loads(param_value) + logger.debug( + f"Converted JSON string to dict for parameter '{param_name}': {param_value}" + ) + except json.JSONDecodeError: + # If it's not valid JSON, pass the string as-is + converted_kwargs[param_name] = param_value + else: + converted_kwargs[param_name] = param_value + else: + converted_kwargs[param_name] = param_value + + return await func(*args, **converted_kwargs) + + # Set the modified annotations on the wrapper + async_wrapper.__annotations__ = new_annotations + return async_wrapper + + @wraps(func) + def sync_wrapper(*args, **kwargs): + # Convert keyword arguments that should be dicts but are strings + converted_kwargs = {} + for param_name, param_value in kwargs.items(): + if param_name in sig.parameters: + param = sig.parameters[param_name] + if _should_convert_to_dict(param.annotation, param_value): + try: + converted_kwargs[param_name] = json.loads(param_value) + logger.debug( + f"Converted JSON string to dict for parameter '{param_name}': {param_value}" + ) + except json.JSONDecodeError: + # If it's not valid JSON, pass the string as-is + converted_kwargs[param_name] = param_value + else: + converted_kwargs[param_name] = param_value + else: + converted_kwargs[param_name] = param_value + + return func(*args, **converted_kwargs) + + # Set the modified annotations on the wrapper + sync_wrapper.__annotations__ = new_annotations + return sync_wrapper + + +def _update_schema_for_json_args(func: Callable, tool) -> None: + """ + Modify the tool's JSON schema to accept strings for dict parameters. + + This function updates the input schema to allow JSON strings in addition to objects for + parameters that are annotated as dict types, enabling MCP clients to pass JSON strings + that will be automatically converted to dicts. + + This modification is always applied to ensure compatibility with various MCP clients. + + Args: + func: The original function + tool: The FastMCP tool object + """ + try: + sig = signature(func) + + # Get the MCP tool representation to modify its schema + mcp_tool_dict = tool.to_mcp_tool().model_dump() + input_schema = mcp_tool_dict.get("inputSchema", {}) + properties = input_schema.get("properties", {}) + + # Check each parameter in the function signature + for param_name, param in sig.parameters.items(): + if param_name in properties: + param_schema = properties[param_name] + + # Check if this parameter should support JSON string conversion + annotation = param.annotation + should_support_string = _is_dict_compatible_annotation(annotation) + + if should_support_string: + # Modify the schema to also accept strings + if "anyOf" in param_schema: + # For Optional[dict] - add string to the anyOf list + existing_schemas = param_schema["anyOf"] + # Check if string is already in the schema + has_string = any( + s.get("type") == "string" for s in existing_schemas + ) + if not has_string: + existing_schemas.append( + { + "type": "string", + "description": "JSON string that will be parsed to object", + } + ) + elif param_schema.get("type") == "object": + # For dict - convert to anyOf with object and string + original_schema = param_schema.copy() + properties[param_name] = { + "anyOf": [ + original_schema, + { + "type": "string", + "description": "JSON string that will be parsed to object", + }, + ], + "title": param_schema.get("title", param_name.title()), + } + # Preserve default if it exists + if "default" in param_schema: + properties[param_name]["default"] = param_schema["default"] + + # Update the tool's parameters with the modified schema + tool.parameters = input_schema + + logger.debug( + f"Modified schema for tool '{tool.name}' to support JSON strings for dict parameters" + ) + + except Exception as e: + logger.warning(f"Could not modify schema for JSON string support: {e}") + + class MCPServer(LoggingConfigurable): """Simple MCP server that allows registering Python functions as tools.""" @@ -28,10 +257,6 @@ class MCPServer(LoggingConfigurable): default_value="localhost", help="Host for the MCP server to listen on" ).tag(config=True) - enable_debug_logging = Bool( - default_value=False, help="Enable debug logging for MCP operations" - ).tag(config=True) - def __init__(self, **kwargs): """Initialize the MCP server. @@ -65,14 +290,24 @@ def register_tool( tool_description = description or func.__doc__ or f"Tool: {tool_name}" self.log.info(f"Registering tool: {tool_name}") - if self.enable_debug_logging: - self.log.debug( - f"Tool details - Name: {tool_name}, " - f"Description: {tool_description}, Async: {iscoroutinefunction(func)}" - ) + self.log.debug( + f"Tool details - Name: {tool_name}, " + f"Description: {tool_description}, Async: {iscoroutinefunction(func)}" + ) + + # Apply auto-conversion wrapper (always enabled) + registered_func = _wrap_with_json_conversion(func) + self.log.debug(f"Applied JSON argument auto-conversion wrapper to {tool_name}") # Register with FastMCP - self.mcp.tool(func) + tool = self.mcp.tool(registered_func) + + # Modify schema to support JSON strings for dict parameters + if tool: + _update_schema_for_json_args(func, tool) + self.log.debug( + f"Modified schema for tool '{tool_name}' to accept JSON strings for dict parameters" + ) # Keep track for listing self._registered_tools[tool_name] = { @@ -115,11 +350,7 @@ async def start_server(self, host: str | None = None): self.log.info(f"Starting MCP server '{self.name}' on {server_host}:{self.port}") self.log.info(f"Registered tools: {list(self._registered_tools.keys())}") - - if self.enable_debug_logging: - self.log.debug( - f"Server configuration - Host: {server_host}, Port: {self.port}" - ) + self.log.debug(f"Server configuration - Host: {server_host}, Port: {self.port}") # Start FastMCP server with HTTP transport await self.mcp.run_http_async(host=server_host, port=self.port) diff --git a/tests/test_mcp_server.py b/tests/test_mcp_server.py index 204d67e..6dfa781 100644 --- a/tests/test_mcp_server.py +++ b/tests/test_mcp_server.py @@ -4,7 +4,7 @@ import pytest -from jupyter_server_mcp.mcp_server import MCPServer +from jupyter_server_mcp.mcp_server import MCPServer, _wrap_with_json_conversion def simple_function(x: int, y: int) -> int: @@ -37,7 +37,6 @@ def test_server_creation(self): assert server.name == "Jupyter MCP Server" assert server.port == 3001 assert server.host == "localhost" - assert server.enable_debug_logging is False assert server.mcp is not None assert len(server._registered_tools) == 0 @@ -213,3 +212,234 @@ def test_server_with_multiple_tools(self): assert server._registered_tools["simple_function"]["is_async"] is False assert server._registered_tools["async_function"]["is_async"] is True assert server._registered_tools["printer"]["is_async"] is False + + +class TestJSONArgumentConversion: + """Test JSON argument conversion functionality.""" + + def test_simple_dict_conversion(self): + """Test basic JSON string to dict conversion.""" + + def func_with_dict(data: dict) -> dict: + """Function that expects a dict.""" + return {"received": data, "type": type(data).__name__} + + wrapped_func = _wrap_with_json_conversion(func_with_dict) + + # Test with actual dict (should pass through) + result = wrapped_func(data={"key": "value"}) + assert result["received"] == {"key": "value"} + assert result["type"] == "dict" + + # Test with JSON string (should be converted) + result = wrapped_func(data='{"key": "value"}') + assert result["received"] == {"key": "value"} + assert result["type"] == "dict" + + def test_optional_dict_conversion(self): + """Test JSON conversion with Optional[dict] annotation.""" + + def func_with_optional_dict(data: dict | None = None) -> dict: + """Function with optional dict parameter.""" + return { + "received": data, + "type": type(data).__name__ if data else "NoneType", + } + + wrapped_func = _wrap_with_json_conversion(func_with_optional_dict) + + # Test with None (should pass through) + result = wrapped_func(data=None) + assert result["received"] is None + assert result["type"] == "NoneType" + + # Test with JSON string (should be converted) + result = wrapped_func(data='{"optional": true}') + assert result["received"] == {"optional": True} + assert result["type"] == "dict" + + def test_union_dict_conversion(self): + """Test JSON conversion with Union type annotations.""" + + def func_with_union_dict(data: dict | None) -> dict: + """Function with Union[dict, None] parameter.""" + return { + "received": data, + "type": type(data).__name__ if data else "NoneType", + } + + wrapped_func = _wrap_with_json_conversion(func_with_union_dict) + + # Test with JSON string (should be converted) + result = wrapped_func(data='{"union": "test"}') + assert result["received"] == {"union": "test"} + assert result["type"] == "dict" + + def test_typed_dict_conversion(self): + """Test JSON conversion with typed dict annotations.""" + + def func_with_typed_dict(config: dict[str, str]) -> dict: + """Function with Dict[str, str] annotation.""" + return {"received": config, "type": type(config).__name__} + + wrapped_func = _wrap_with_json_conversion(func_with_typed_dict) + + # Test with JSON string (should be converted) + result = wrapped_func(config='{"name": "test", "value": "data"}') + assert result["received"] == {"name": "test", "value": "data"} + assert result["type"] == "dict" + + def test_invalid_json_handling(self): + """Test handling of invalid JSON strings.""" + + def func_with_dict(data: dict) -> dict: + """Function that expects a dict.""" + return {"received": data, "type": type(data).__name__} + + wrapped_func = _wrap_with_json_conversion(func_with_dict) + + # Test with invalid JSON (should pass string as-is) + result = wrapped_func(data="invalid json {") + assert result["received"] == "invalid json {" + assert result["type"] == "str" + + # Test with empty string (should pass as-is) + result = wrapped_func(data="") + assert result["received"] == "" + assert result["type"] == "str" + + def test_non_dict_parameters_unchanged(self): + """Test that non-dict parameters are not affected.""" + + def mixed_func(name: str, count: int, data: dict) -> dict: + """Function with mixed parameter types.""" + return { + "name": name, + "name_type": type(name).__name__, + "count": count, + "count_type": type(count).__name__, + "data": data, + "data_type": type(data).__name__, + } + + wrapped_func = _wrap_with_json_conversion(mixed_func) + + # Only the dict parameter should be converted + result = wrapped_func(name="test", count=42, data='{"converted": true}') + + assert result["name"] == "test" + assert result["name_type"] == "str" + assert result["count"] == 42 + assert result["count_type"] == "int" + assert result["data"] == {"converted": True} + assert result["data_type"] == "dict" + + @pytest.mark.asyncio + async def test_async_function_conversion(self): + """Test JSON conversion with async functions.""" + + async def async_func_with_dict(config: dict) -> dict: + """Async function that expects a dict.""" + await asyncio.sleep(0.001) # Small delay + return {"async_result": config, "type": type(config).__name__} + + wrapped_func = _wrap_with_json_conversion(async_func_with_dict) + + # Test with JSON string (should be converted) + result = await wrapped_func(config='{"async": true, "value": 123}') + assert result["async_result"] == {"async": True, "value": 123} + assert result["type"] == "dict" + + def test_complex_nested_json(self): + """Test conversion of complex nested JSON structures.""" + + def func_with_nested_dict(data: dict) -> dict: + """Function that processes nested dict data.""" + return {"processed": data} + + wrapped_func = _wrap_with_json_conversion(func_with_nested_dict) + + complex_json = """{ + "users": [ + {"name": "Alice", "age": 30}, + {"name": "Bob", "age": 25} + ], + "metadata": { + "version": "1.0", + "created": "2024-01-01" + } + }""" + + result = wrapped_func(data=complex_json) + expected = { + "users": [{"name": "Alice", "age": 30}, {"name": "Bob", "age": 25}], + "metadata": {"version": "1.0", "created": "2024-01-01"}, + } + assert result["processed"] == expected + + def test_annotation_modification(self): + """Test that function annotations are properly modified.""" + + def original_func(data: dict) -> dict: + """Original function with dict annotation.""" + return data + + wrapped_func = _wrap_with_json_conversion(original_func) + + # Check that annotations were modified to accept strings + annotations = wrapped_func.__annotations__ + assert "data" in annotations + + # The annotation should now be dict | str (or Union equivalent) + data_annotation = annotations["data"] + # We can check this works by ensuring both dict and str are acceptable + assert hasattr(data_annotation, "__args__") or data_annotation == (dict | str) + + +class TestJSONSchemaModification: + """Test JSON schema modification for MCP tools.""" + + def test_schema_modification_applied(self): + """Test that schema modification is applied during tool registration.""" + server = MCPServer() + + def func_with_dict_param(config: dict) -> str: + """Function with dict parameter.""" + return f"Received config: {config}" + + # Register the function - schema should be automatically modified + server.register_tool(func_with_dict_param) + + # Verify the tool was registered + assert "func_with_dict_param" in server._registered_tools + tool_info = server._registered_tools["func_with_dict_param"] + assert tool_info["name"] == "func_with_dict_param" + + def test_multiple_dict_parameters(self): + """Test conversion with multiple dict parameters.""" + + def func_multiple_dicts(config: dict, metadata: dict, name: str) -> dict: + """Function with multiple dict parameters.""" + return { + "config": config, + "metadata": metadata, + "name": name, + "types": { + "config": type(config).__name__, + "metadata": type(metadata).__name__, + "name": type(name).__name__, + }, + } + + wrapped_func = _wrap_with_json_conversion(func_multiple_dicts) + + result = wrapped_func( + config='{"key1": "value1"}', metadata='{"version": 2}', name="test_function" + ) + + assert result["config"] == {"key1": "value1"} + assert result["metadata"] == {"version": 2} + assert result["name"] == "test_function" + assert result["types"]["config"] == "dict" + assert result["types"]["metadata"] == "dict" + assert result["types"]["name"] == "str"