Skip to content

Commit 849576e

Browse files
Merge pull request #165 from jon-fox/feature/mcp-input-schema-fix-164
Feature/mcp input schema fix 164
2 parents d9f96eb + 183e769 commit 849576e

File tree

11 files changed

+277
-68
lines changed

11 files changed

+277
-68
lines changed

atomic-agents/atomic_agents/connectors/mcp/schema_transformer.py

Lines changed: 95 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Module for transforming JSON schemas to Pydantic models."""
22

33
import logging
4-
from typing import Any, Dict, List, Optional, Type, Tuple, Literal, cast
4+
from typing import Any, Dict, List, Optional, Type, Tuple, Literal, Union, cast
55

66
from pydantic import Field, create_model
77

@@ -24,33 +24,113 @@ class SchemaTransformer:
2424
"""Class for transforming JSON schemas to Pydantic models."""
2525

2626
@staticmethod
27-
def json_to_pydantic_field(prop_schema: Dict[str, Any], required: bool) -> Tuple[Type, Field]:
27+
def _resolve_ref(ref_path: str, root_schema: Dict[str, Any], model_cache: Dict[str, Type]) -> Type:
28+
"""Resolve a $ref to a Pydantic model."""
29+
# Extract ref name from path like "#/$defs/MyObject" or "#/definitions/ANode"
30+
ref_name = ref_path.split("/")[-1]
31+
32+
if ref_name in model_cache:
33+
return model_cache[ref_name]
34+
35+
# Look for the referenced schema in $defs or definitions
36+
defs = root_schema.get("$defs", root_schema.get("definitions", {}))
37+
if ref_name in defs:
38+
ref_schema = defs[ref_name]
39+
# Create model for the referenced schema
40+
model_name = ref_schema.get("title", ref_name)
41+
# Avoid infinite recursion by adding placeholder first
42+
model_cache[ref_name] = Any
43+
model = SchemaTransformer._create_nested_model(ref_schema, model_name, root_schema, model_cache)
44+
model_cache[ref_name] = model
45+
return model
46+
47+
logger.warning(f"Could not resolve $ref: {ref_path}")
48+
return Any
49+
50+
@staticmethod
51+
def _create_nested_model(
52+
schema: Dict[str, Any], model_name: str, root_schema: Dict[str, Any], model_cache: Dict[str, Type]
53+
) -> Type:
54+
"""Create a nested Pydantic model from a schema."""
55+
fields = {}
56+
required_fields = set(schema.get("required", []))
57+
properties = schema.get("properties", {})
58+
59+
for prop_name, prop_schema in properties.items():
60+
is_required = prop_name in required_fields
61+
fields[prop_name] = SchemaTransformer.json_to_pydantic_field(prop_schema, is_required, root_schema, model_cache)
62+
63+
return create_model(model_name, **fields)
64+
65+
@staticmethod
66+
def json_to_pydantic_field(
67+
prop_schema: Dict[str, Any],
68+
required: bool,
69+
root_schema: Optional[Dict[str, Any]] = None,
70+
model_cache: Optional[Dict[str, Type]] = None,
71+
) -> Tuple[Type, Field]:
2872
"""
2973
Convert a JSON schema property to a Pydantic field.
3074
3175
Args:
3276
prop_schema: JSON schema for the property
3377
required: Whether the field is required
78+
root_schema: Full root schema for resolving $refs
79+
model_cache: Cache for resolved models
3480
3581
Returns:
3682
Tuple of (type, Field)
3783
"""
38-
json_type = prop_schema.get("type")
84+
if root_schema is None:
85+
root_schema = {}
86+
if model_cache is None:
87+
model_cache = {}
88+
3989
description = prop_schema.get("description")
4090
default = prop_schema.get("default")
4191
python_type: Any = Any
4292

43-
if json_type in JSON_TYPE_MAP:
44-
python_type = JSON_TYPE_MAP[json_type]
45-
if json_type == "array":
46-
items_schema = prop_schema.get("items", {})
47-
item_type_str = items_schema.get("type")
48-
if item_type_str in JSON_TYPE_MAP:
49-
python_type = List[JSON_TYPE_MAP[item_type_str]]
93+
# Handle $ref
94+
if "$ref" in prop_schema:
95+
python_type = SchemaTransformer._resolve_ref(prop_schema["$ref"], root_schema, model_cache)
96+
# Handle oneOf/anyOf (unions)
97+
elif "oneOf" in prop_schema or "anyOf" in prop_schema:
98+
union_schemas = prop_schema.get("oneOf", prop_schema.get("anyOf", []))
99+
if union_schemas:
100+
union_types = []
101+
for union_schema in union_schemas:
102+
if "$ref" in union_schema:
103+
union_types.append(SchemaTransformer._resolve_ref(union_schema["$ref"], root_schema, model_cache))
104+
else:
105+
# Recursively resolve the union member
106+
member_type, _ = SchemaTransformer.json_to_pydantic_field(union_schema, True, root_schema, model_cache)
107+
union_types.append(member_type)
108+
109+
if len(union_types) == 1:
110+
python_type = union_types[0]
50111
else:
51-
python_type = List[Any]
52-
elif json_type == "object":
53-
python_type = Dict[str, Any]
112+
python_type = Union[tuple(union_types)]
113+
# Handle regular types
114+
else:
115+
json_type = prop_schema.get("type")
116+
if json_type in JSON_TYPE_MAP:
117+
python_type = JSON_TYPE_MAP[json_type]
118+
119+
if json_type == "array":
120+
items_schema = prop_schema.get("items", {})
121+
if "$ref" in items_schema:
122+
item_type = SchemaTransformer._resolve_ref(items_schema["$ref"], root_schema, model_cache)
123+
elif "oneOf" in items_schema or "anyOf" in items_schema:
124+
# Handle arrays of unions
125+
item_type, _ = SchemaTransformer.json_to_pydantic_field(items_schema, True, root_schema, model_cache)
126+
elif items_schema.get("type") in JSON_TYPE_MAP:
127+
item_type = JSON_TYPE_MAP[items_schema["type"]]
128+
else:
129+
item_type = Any
130+
python_type = List[item_type]
131+
132+
elif json_type == "object":
133+
python_type = Dict[str, Any]
54134

55135
field_kwargs = {"description": description}
56136
if required:
@@ -85,11 +165,12 @@ def create_model_from_schema(
85165
fields = {}
86166
required_fields = set(schema.get("required", []))
87167
properties = schema.get("properties")
168+
model_cache: Dict[str, Type] = {}
88169

89170
if properties:
90171
for prop_name, prop_schema in properties.items():
91172
is_required = prop_name in required_fields
92-
fields[prop_name] = SchemaTransformer.json_to_pydantic_field(prop_schema, is_required)
173+
fields[prop_name] = SchemaTransformer.json_to_pydantic_field(prop_schema, is_required, schema, model_cache)
93174
elif schema.get("type") == "object" and not properties:
94175
pass
95176
elif schema:

atomic-agents/tests/connectors/mcp/test_schema_transformer.py

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import pytest
2-
from typing import Any, Dict, List, Optional
2+
from typing import Any, Dict, List, Optional, Union
33

44
from atomic_agents import BaseIOSchema
55
from atomic_agents.connectors.mcp import SchemaTransformer
@@ -134,3 +134,76 @@ def test_tool_name_field(self):
134134
# Test that an invalid tool_name raises an error
135135
with pytest.raises(ValueError):
136136
model(tool_name="wrong_tool")
137+
138+
def test_union_type_oneof(self):
139+
"""Test oneOf creates Union types."""
140+
prop_schema = {"oneOf": [{"type": "string"}, {"type": "integer"}], "description": "A union field"}
141+
result = SchemaTransformer.json_to_pydantic_field(prop_schema, True)
142+
# Should create Union[str, int]
143+
assert result[0] == Union[str, int]
144+
assert result[1].description == "A union field"
145+
146+
def test_union_type_anyof(self):
147+
"""Test anyOf creates Union types."""
148+
prop_schema = {"anyOf": [{"type": "boolean"}, {"type": "number"}], "description": "Another union field"}
149+
result = SchemaTransformer.json_to_pydantic_field(prop_schema, True)
150+
# Should create Union[bool, float]
151+
assert result[0] == Union[bool, float]
152+
153+
def test_array_with_ref_items(self):
154+
"""Test arrays with $ref items are resolved."""
155+
root_schema = {
156+
"$defs": {"MyObject": {"type": "object", "properties": {"name": {"type": "string"}}, "title": "MyObject"}}
157+
}
158+
prop_schema = {"type": "array", "items": {"$ref": "#/$defs/MyObject"}, "description": "Array of MyObject"}
159+
result = SchemaTransformer.json_to_pydantic_field(prop_schema, True, root_schema)
160+
# Should be List[MyObject] not List[Any]
161+
assert hasattr(result[0], "__origin__") and result[0].__origin__ is list
162+
# The inner type should be the created model, not Any
163+
inner_type = result[0].__args__[0]
164+
assert inner_type != Any
165+
assert hasattr(inner_type, "model_fields")
166+
167+
def test_array_with_union_items(self):
168+
"""Test arrays with oneOf items."""
169+
prop_schema = {
170+
"type": "array",
171+
"items": {"oneOf": [{"type": "string"}, {"type": "integer"}]},
172+
"description": "Array of union items",
173+
}
174+
result = SchemaTransformer.json_to_pydantic_field(prop_schema, True)
175+
# Should be List[Union[str, int]]
176+
assert hasattr(result[0], "__origin__") and result[0].__origin__ is list
177+
inner_type = result[0].__args__[0]
178+
assert inner_type == Union[str, int]
179+
180+
def test_model_with_complex_types(self):
181+
"""Test create_model_from_schema with complex types."""
182+
schema = {
183+
"type": "object",
184+
"properties": {
185+
"expr": {"oneOf": [{"$ref": "#/$defs/ANode"}, {"$ref": "#/$defs/BNode"}], "description": "Expression node"},
186+
"objects": {"type": "array", "items": {"$ref": "#/$defs/MyObject"}, "description": "List of objects"},
187+
},
188+
"required": ["expr", "objects"],
189+
"$defs": {
190+
"ANode": {"type": "object", "properties": {"a_value": {"type": "string"}}, "title": "ANode"},
191+
"BNode": {"type": "object", "properties": {"b_value": {"type": "integer"}}, "title": "BNode"},
192+
"MyObject": {"type": "object", "properties": {"name": {"type": "string"}}, "title": "MyObject"},
193+
},
194+
}
195+
196+
model = SchemaTransformer.create_model_from_schema(schema, "ComplexModel", "complex_tool")
197+
198+
# Check that expr is a Union, not Any
199+
expr_field = model.model_fields["expr"]
200+
assert expr_field.annotation != Any
201+
# Should be Union[ANode, BNode]
202+
assert hasattr(expr_field.annotation, "__origin__") and expr_field.annotation.__origin__ is Union
203+
204+
# Check that objects is List[MyObject], not List[Any]
205+
objects_field = model.model_fields["objects"]
206+
assert objects_field.annotation != List[Any]
207+
assert hasattr(objects_field.annotation, "__origin__") and objects_field.annotation.__origin__ is list
208+
inner_type = objects_field.annotation.__args__[0]
209+
assert inner_type != Any

atomic-examples/mcp-agent/example-client/example_client/main_fastapi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class MCPConfig:
2020
"""Configuration for the MCP Agent system using HTTP Stream transport."""
2121

2222
mcp_server_url: str = "http://localhost:6969"
23-
openai_model: str = "gpt-4o"
23+
openai_model: str = "gpt-5-mini"
2424
openai_api_key: str = os.getenv("OPENAI_API_KEY") or ""
2525

2626
def __post_init__(self):

atomic-examples/mcp-agent/example-client/example_client/main_http.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class MCPConfig:
2323
"""Configuration for the MCP Agent system using HTTP Stream transport."""
2424

2525
mcp_server_url: str = "http://localhost:6969"
26-
openai_model: str = "gpt-4o"
26+
openai_model: str = "gpt-5-mini"
2727
openai_api_key: str = os.getenv("OPENAI_API_KEY")
2828

2929
def __post_init__(self):

atomic-examples/mcp-agent/example-client/example_client/main_sse.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@ class MCPConfig:
2121

2222
mcp_server_url: str = "http://localhost:6969"
2323

24-
# NOTE: In contrast to other examples, we use gpt-4o and not gpt-4o-mini here.
25-
# In my tests, gpt-4o-mini was not smart enough to deal with multiple tools like that
24+
# NOTE: In contrast to other examples, we use gpt-5-mini and not gpt-4o-mini here.
25+
# In my tests, gpt-4o-mini was not smart enough to deal with multiple tools like that
2626
# and at the moment MCP does not yet allow for adding sufficient metadata to
2727
# clarify tools even more and introduce more constraints.
28-
openai_model: str = "gpt-4o"
28+
openai_model: str = "gpt-5-mini"
2929
openai_api_key: str = os.getenv("OPENAI_API_KEY")
3030

3131
def __post_init__(self):

atomic-examples/mcp-agent/example-client/example_client/main_stdio.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@
2222
class MCPConfig:
2323
"""Configuration for the MCP Agent system using STDIO transport."""
2424

25-
# NOTE: In contrast to other examples, we use gpt-4o and not gpt-4o-mini here.
26-
# In my tests, gpt-4o-mini was not smart enough to deal with multiple tools like that
25+
# NOTE: In contrast to other examples, we use gpt-5-mini and not gpt-4o-mini here.
26+
# In my tests, gpt-4o-mini was not smart enough to deal with multiple tools like that
2727
# and at the moment MCP does not yet allow for adding sufficient metadata to
2828
# clarify tools even more and introduce more constraints.
29-
openai_model: str = "gpt-4o"
29+
openai_model: str = "gpt-5-mini"
3030
openai_api_key: str = os.getenv("OPENAI_API_KEY")
3131

3232
# Command to run the STDIO server.

atomic-examples/mcp-agent/example-client/example_client/main_stdio_async.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@
2222
class MCPConfig:
2323
"""Configuration for the MCP Agent system using STDIO transport."""
2424

25-
# NOTE: In contrast to other examples, we use gpt-4o and not gpt-4o-mini here.
26-
# In my tests, gpt-4o-mini was not smart enough to deal with multiple tools like that
25+
# NOTE: In contrast to other examples, we use gpt-5-mini and not gpt-4o-mini here.
26+
# In my tests, gpt-4o-mini was not smart enough to deal with multiple tools like that
2727
# and at the moment MCP does not yet allow for adding sufficient metadata to
2828
# clarify tools even more and introduce more constraints.
29-
openai_model: str = "gpt-4o"
29+
openai_model: str = "gpt-5-mini"
3030
openai_api_key: str = os.getenv("OPENAI_API_KEY")
3131

3232
# Command to run the STDIO server.

atomic-examples/mcp-agent/example-mcp-server/example_mcp_server/server_http.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
SubtractNumbersTool,
1818
MultiplyNumbersTool,
1919
DivideNumbersTool,
20+
BatchCalculatorTool,
2021
)
2122

2223

@@ -27,6 +28,7 @@ def get_available_tools() -> List[Tool]:
2728
SubtractNumbersTool(),
2829
MultiplyNumbersTool(),
2930
DivideNumbersTool(),
31+
BatchCalculatorTool(),
3032
]
3133

3234

atomic-examples/mcp-agent/example-mcp-server/example_mcp_server/services/tool_service.py

Lines changed: 17 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from typing import Dict, List, Any
44
from mcp.server.fastmcp import FastMCP
55
from example_mcp_server.interfaces.tool import Tool, ToolResponse, ToolContent
6-
from pydantic import Field
76

87

98
class ToolService:
@@ -43,8 +42,8 @@ async def execute_tool(self, tool_name: str, input_data: Dict[str, Any]) -> Tool
4342
"""
4443
tool = self.get_tool(tool_name)
4544

46-
# Convert input dictionary to the tool's input model
47-
input_model = tool.input_model(**input_data)
45+
# Use model_validate to handle complex nested objects properly
46+
input_model = tool.input_model.model_validate(input_data)
4847

4948
# Execute the tool with validated input
5049
return await tool.execute(input_model)
@@ -90,41 +89,19 @@ def _serialize_response(self, response: ToolResponse) -> Any:
9089
def register_mcp_handlers(self, mcp: FastMCP) -> None:
9190
"""Register all tools as MCP handlers."""
9291
for tool in self._tools.values():
93-
# Get the tool's schema
94-
schema = tool.input_model.model_json_schema()
95-
properties = schema.get("properties", {})
96-
97-
# Create a function signature that matches the schema with parameter descriptions
98-
params = []
99-
100-
for name, info in properties.items():
101-
type_hint = "str" # Default to str
102-
if info.get("type") == "integer":
103-
type_hint = "int"
104-
elif info.get("type") == "number":
105-
type_hint = "float"
106-
elif info.get("type") == "boolean":
107-
type_hint = "bool"
108-
109-
default = info.get("default", "...")
110-
description = info.get("description", "")
111-
112-
# Create parameter string for function definition with Field for descriptions
113-
if default == "...":
114-
params.append(f"{name}: {type_hint} = Field(description='{description}')")
115-
else:
116-
params.append(f"{name}: {type_hint} = Field(description='{description}', default={repr(default)})")
117-
118-
# Create the function definition
119-
fn_def = f"async def {tool.name}({', '.join(params)}):\n"
120-
fn_def += f' """{tool.description}"""\n'
121-
fn_def += " result = await self.execute_tool(tool.name, locals())\n"
122-
fn_def += " return self._serialize_response(result)"
123-
124-
# Create the function
125-
namespace = {"self": self, "tool": tool, "Field": Field}
126-
exec(fn_def, namespace)
127-
handler = namespace[tool.name]
128-
129-
# Register the handler
92+
# Create a handler that uses the tool's input model directly for schema generation
93+
def create_handler(tool_instance):
94+
# Use the actual Pydantic model as the function parameter
95+
# This ensures FastMCP gets the complete schema including nested objects
96+
async def handler(input_data: tool_instance.input_model):
97+
f'"""{tool_instance.description}"""'
98+
result = await self.execute_tool(tool_instance.name, input_data.model_dump())
99+
return self._serialize_response(result)
100+
101+
return handler
102+
103+
# Create the handler
104+
handler = create_handler(tool)
105+
106+
# Register with FastMCP - it should auto-detect the schema from the type annotation
130107
mcp.tool(name=tool.name, description=tool.description)(handler)

0 commit comments

Comments
 (0)