Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
267 changes: 249 additions & 18 deletions jupyter_server_mcp/mcp_server.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,246 @@
"""Simple MCP server for registering Python functions as tools."""

import inspect
import json
import logging
from collections.abc import Callable
from inspect import iscoroutinefunction
from typing import Any
from functools import wraps
from inspect import iscoroutinefunction, signature
from typing import Any, Union, get_args, get_origin

from fastmcp import FastMCP
from traitlets import Bool, Int, Unicode
from traitlets import Int, Unicode
from traitlets.config.configurable import LoggingConfigurable

logger = logging.getLogger(__name__)


def _is_dict_compatible_annotation(annotation) -> bool:
"""Check if an annotation expects dict values that can be JSON-converted."""
# Direct dict annotation
if annotation is dict:
return True

# Union types: Optional[dict], Union[dict, None], dict | None
origin = get_origin(annotation)
if origin is Union or (
hasattr(annotation, "__class__")
and annotation.__class__.__name__ == "UnionType"
):
args = get_args(annotation)
return dict in args

# Typed dict annotations: Dict[K, V], dict[str, Any]
return bool(hasattr(annotation, "__origin__") and annotation.__origin__ is dict)


def _wrap_with_json_conversion(func: Callable) -> Callable:
"""
Wrapper that automatically converts JSON string arguments to dictionaries.

This addresses the common issue where MCP clients pass dictionary arguments
as JSON strings instead of structured objects. The wrapper inspects the
function signature and attempts JSON parsing for parameters annotated as
dict types when they are received as strings.

Additionally, this function modifies the type annotations to accept Union[dict, str]
for dict parameters to allow Pydantic validation to pass.

This conversion is always applied to all registered tools to ensure compatibility
with various MCP clients that may serialize dict parameters differently.

Args:
func: The function to wrap

Returns:
Wrapped function that handles JSON string conversion with modified annotations
"""
sig = signature(func)

def _should_convert_to_dict(annotation, value):
"""Check if a parameter should be converted from JSON string to dict."""
return isinstance(value, str) and _is_dict_compatible_annotation(annotation)

def _add_string_to_annotation(annotation):
"""Modify annotation to also accept strings for dict types."""
# Direct dict annotation -> dict | str
if annotation is dict:
return dict | str

# Union types: add str to existing union
origin = get_origin(annotation)
if origin is Union:
args = get_args(annotation)
if dict in args and str not in args:
return Union[(*tuple(args), str)]
return annotation

# New Python 3.10+ union syntax: dict | None
if (
hasattr(annotation, "__class__")
and annotation.__class__.__name__ == "UnionType"
):
args = get_args(annotation)
if dict in args and str not in args:
# Reconstruct the union with str added
new_args = (*tuple(args), str)
# Create new union type
result = new_args[0]
for arg in new_args[1:]:
result = result | arg
return result
return annotation

# Typed dict annotations -> annotation | str
if hasattr(annotation, "__origin__") and annotation.__origin__ is dict:
return annotation | str

return annotation

# Create new annotations that accept strings for dict parameters
new_annotations = {}
for param_name, param in sig.parameters.items():
if param.annotation != inspect.Parameter.empty:
new_annotations[param_name] = _add_string_to_annotation(param.annotation)
else:
new_annotations[param_name] = param.annotation

# Keep the return annotation unchanged
if hasattr(func, "__annotations__") and "return" in func.__annotations__:
new_annotations["return"] = func.__annotations__["return"]

if iscoroutinefunction(func):

@wraps(func)
async def async_wrapper(*args, **kwargs):
# Convert keyword arguments that should be dicts but are strings
converted_kwargs = {}
for param_name, param_value in kwargs.items():
if param_name in sig.parameters:
param = sig.parameters[param_name]
if _should_convert_to_dict(param.annotation, param_value):
try:
converted_kwargs[param_name] = json.loads(param_value)
logger.debug(
f"Converted JSON string to dict for parameter '{param_name}': {param_value}"
)
except json.JSONDecodeError:
# If it's not valid JSON, pass the string as-is
converted_kwargs[param_name] = param_value
else:
converted_kwargs[param_name] = param_value
else:
converted_kwargs[param_name] = param_value

return await func(*args, **converted_kwargs)

# Set the modified annotations on the wrapper
async_wrapper.__annotations__ = new_annotations
return async_wrapper

@wraps(func)
def sync_wrapper(*args, **kwargs):
# Convert keyword arguments that should be dicts but are strings
converted_kwargs = {}
for param_name, param_value in kwargs.items():
if param_name in sig.parameters:
param = sig.parameters[param_name]
if _should_convert_to_dict(param.annotation, param_value):
try:
converted_kwargs[param_name] = json.loads(param_value)
logger.debug(
f"Converted JSON string to dict for parameter '{param_name}': {param_value}"
)
except json.JSONDecodeError:
# If it's not valid JSON, pass the string as-is
converted_kwargs[param_name] = param_value
else:
converted_kwargs[param_name] = param_value
else:
converted_kwargs[param_name] = param_value

return func(*args, **converted_kwargs)

# Set the modified annotations on the wrapper
sync_wrapper.__annotations__ = new_annotations
return sync_wrapper


def _update_schema_for_json_args(func: Callable, tool) -> None:
"""
Modify the tool's JSON schema to accept strings for dict parameters.

This function updates the input schema to allow JSON strings in addition to objects for
parameters that are annotated as dict types, enabling MCP clients to pass JSON strings
that will be automatically converted to dicts.

This modification is always applied to ensure compatibility with various MCP clients.

Args:
func: The original function
tool: The FastMCP tool object
"""
try:
sig = signature(func)

# Get the MCP tool representation to modify its schema
mcp_tool_dict = tool.to_mcp_tool().model_dump()
input_schema = mcp_tool_dict.get("inputSchema", {})
properties = input_schema.get("properties", {})

# Check each parameter in the function signature
for param_name, param in sig.parameters.items():
if param_name in properties:
param_schema = properties[param_name]

# Check if this parameter should support JSON string conversion
annotation = param.annotation
should_support_string = _is_dict_compatible_annotation(annotation)

if should_support_string:
# Modify the schema to also accept strings
if "anyOf" in param_schema:
# For Optional[dict] - add string to the anyOf list
existing_schemas = param_schema["anyOf"]
# Check if string is already in the schema
has_string = any(
s.get("type") == "string" for s in existing_schemas
)
if not has_string:
existing_schemas.append(
{
"type": "string",
"description": "JSON string that will be parsed to object",
}
)
elif param_schema.get("type") == "object":
# For dict - convert to anyOf with object and string
original_schema = param_schema.copy()
properties[param_name] = {
"anyOf": [
original_schema,
{
"type": "string",
"description": "JSON string that will be parsed to object",
},
],
"title": param_schema.get("title", param_name.title()),
}
# Preserve default if it exists
if "default" in param_schema:
properties[param_name]["default"] = param_schema["default"]

# Update the tool's parameters with the modified schema
tool.parameters = input_schema

logger.debug(
f"Modified schema for tool '{tool.name}' to support JSON strings for dict parameters"
)

except Exception as e:
logger.warning(f"Could not modify schema for JSON string support: {e}")


class MCPServer(LoggingConfigurable):
"""Simple MCP server that allows registering Python functions as tools."""

Expand All @@ -28,10 +257,6 @@ class MCPServer(LoggingConfigurable):
default_value="localhost", help="Host for the MCP server to listen on"
).tag(config=True)

enable_debug_logging = Bool(
default_value=False, help="Enable debug logging for MCP operations"
).tag(config=True)

def __init__(self, **kwargs):
"""Initialize the MCP server.

Expand Down Expand Up @@ -65,14 +290,24 @@ def register_tool(
tool_description = description or func.__doc__ or f"Tool: {tool_name}"

self.log.info(f"Registering tool: {tool_name}")
if self.enable_debug_logging:
self.log.debug(
f"Tool details - Name: {tool_name}, "
f"Description: {tool_description}, Async: {iscoroutinefunction(func)}"
)
self.log.debug(
f"Tool details - Name: {tool_name}, "
f"Description: {tool_description}, Async: {iscoroutinefunction(func)}"
)

# Apply auto-conversion wrapper (always enabled)
registered_func = _wrap_with_json_conversion(func)
self.log.debug(f"Applied JSON argument auto-conversion wrapper to {tool_name}")

# Register with FastMCP
self.mcp.tool(func)
tool = self.mcp.tool(registered_func)

# Modify schema to support JSON strings for dict parameters
if tool:
_update_schema_for_json_args(func, tool)
self.log.debug(
f"Modified schema for tool '{tool_name}' to accept JSON strings for dict parameters"
)

# Keep track for listing
self._registered_tools[tool_name] = {
Expand Down Expand Up @@ -115,11 +350,7 @@ async def start_server(self, host: str | None = None):

self.log.info(f"Starting MCP server '{self.name}' on {server_host}:{self.port}")
self.log.info(f"Registered tools: {list(self._registered_tools.keys())}")

if self.enable_debug_logging:
self.log.debug(
f"Server configuration - Host: {server_host}, Port: {self.port}"
)
self.log.debug(f"Server configuration - Host: {server_host}, Port: {self.port}")

# Start FastMCP server with HTTP transport
await self.mcp.run_http_async(host=server_host, port=self.port)
Loading
Loading