Skip to content

Commit d25914f

Browse files
committed
fix: sevefal important fixes to MCP compiler
1 parent a15d7ef commit d25914f

File tree

3 files changed

+152
-21
lines changed

3 files changed

+152
-21
lines changed

nerve/tools/mcp/client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111

1212
class Client:
13+
name: str = ""
14+
1315
def __init__(self, name: str, server: Configuration.MCPServer):
1416
self.name = name
1517
self.server = server

nerve/tools/mcp/compiler.py

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import typing as t
44

55
import jinja2
6+
from mcp import Tool
67
from pydantic import create_model
78

89
from nerve.models import Configuration
@@ -81,36 +82,44 @@ def _get_python_type(
8182
return None, type_mapping.get(schema_type, t.Any) # type: ignore
8283

8384

85+
async def create_function_body(client: Client, mcp_tool: Tool) -> tuple[str, dict[str, t.Any]]:
86+
typed_args = []
87+
client_name = f"nerve_mcp_{client.name}_client"
88+
type_defs = {client_name: client}
89+
90+
for name, arg_props in mcp_tool.inputSchema.get("properties", {}).items():
91+
args_def, arg_type = _get_python_type(name, arg_props)
92+
typed_args.append(
93+
{"name": name, "type": _stringify_type(arg_type), "description": arg_props.get("description", "")}
94+
)
95+
if args_def:
96+
type_defs.update(args_def)
97+
98+
# load the template from the same directory as this script
99+
template_path = os.path.join(os.path.dirname(__file__), "body.j2")
100+
with open(template_path) as f:
101+
template_content = f.read()
102+
103+
return (
104+
jinja2.Environment()
105+
.from_string(template_content)
106+
.render(client_name=client_name, tool=mcp_tool, arguments=typed_args),
107+
type_defs,
108+
)
109+
110+
84111
async def get_tools_from_mcp(name: str, server: Configuration.MCPServer) -> list[t.Callable[..., t.Any]]:
85112
# connect and list tools
86113
client = Client(name, server)
87114
mpc_tools = await client.tools()
88-
client_name = f"nerve_mcp_{name}_client"
89115
compiled_tools = []
90116

91117
for mcp_tool in mpc_tools:
92-
typed_args = []
93-
type_defs = {client_name: client}
94-
95-
for name, arg_props in mcp_tool.inputSchema.get("properties", {}).items():
96-
args_def, arg_type = _get_python_type(name, arg_props)
97-
typed_args.append({"name": name, "type": _stringify_type(arg_type)})
98-
if args_def:
99-
type_defs.update(args_def)
100-
101-
# load the template from the same directory as this script
102-
template_path = os.path.join(os.path.dirname(__file__), "body.j2")
103-
with open(template_path) as f:
104-
template_content = f.read()
105-
106-
func_body = (
107-
jinja2.Environment()
108-
.from_string(template_content)
109-
.render(client_name=client_name, tool=mcp_tool, arguments=typed_args)
110-
)
118+
func_body, type_defs = await create_function_body(client, mcp_tool)
119+
111120
# print(func_body)
112121
exec(func_body, type_defs)
113122

114-
tool_fn = wrap_tool_function(type_defs[mcp_tool.name]) # type: ignore
123+
tool_fn = wrap_tool_function(type_defs[mcp_tool.name])
115124
compiled_tools.append(tool_fn)
116125
return compiled_tools

nerve/tools/mcp/compilter_test.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
import unittest
2+
from unittest.mock import MagicMock
3+
4+
from mcp import Tool
5+
6+
from nerve.tools.mcp.client import Client
7+
from nerve.tools.mcp.compiler import create_function_body
8+
9+
10+
class TestCreateFunctionBody(unittest.IsolatedAsyncioTestCase):
11+
async def test_create_function_body_with_string_argument(self) -> None:
12+
# Mock the client
13+
client = MagicMock(spec=Client)
14+
15+
# Create a mock tool with a simple string argument
16+
mock_tool = Tool(
17+
name="hello_world",
18+
description="A simple hello world function",
19+
inputSchema={
20+
"type": "object",
21+
"properties": {"name": {"type": "string", "description": "The name to greet"}},
22+
"required": ["name"],
23+
},
24+
)
25+
26+
func_body, type_defs = await create_function_body(client, mock_tool)
27+
28+
self.assertIn(
29+
'''
30+
async def hello_world(name: Annotated[str, "The name to greet"]) -> Any:
31+
"""A simple hello world function"""
32+
'''.strip(),
33+
func_body,
34+
)
35+
36+
async def test_create_function_body_with_string_and_int_arguments(self) -> None:
37+
# Mock the client
38+
client = MagicMock(spec=Client)
39+
40+
# Create a mock tool with string and int arguments
41+
mock_tool = Tool(
42+
name="calculate",
43+
description="A function that performs a calculation",
44+
inputSchema={
45+
"type": "object",
46+
"properties": {
47+
"operation": {"type": "string", "description": "The operation to perform"},
48+
"value": {"type": "integer", "description": "The value to calculate with"},
49+
},
50+
"required": ["operation", "value"],
51+
},
52+
)
53+
54+
func_body, type_defs = await create_function_body(client, mock_tool)
55+
56+
self.assertIn(
57+
'''
58+
async def calculate(operation: Annotated[str, "The operation to perform"], value: Annotated[int, "The value to calculate with"]) -> Any:
59+
"""A function that performs a calculation"""
60+
'''.strip(),
61+
func_body,
62+
)
63+
64+
async def test_create_function_body_with_complex_nested_arguments(self) -> None:
65+
# Mock the client
66+
client = MagicMock(spec=Client)
67+
68+
# Create a mock tool with complex nested arguments
69+
mock_tool = Tool(
70+
name="process_data",
71+
description="A function that processes complex data structures",
72+
inputSchema={
73+
"type": "object",
74+
"properties": {
75+
"user": {
76+
"type": "object",
77+
"properties": {
78+
"name": {"type": "string", "description": "User's name"},
79+
"age": {"type": "integer", "description": "User's age"},
80+
"preferences": {
81+
"type": "array",
82+
"items": {"type": "string"},
83+
"description": "User's preferences",
84+
},
85+
},
86+
"description": "User information",
87+
},
88+
"settings": {
89+
"type": "object",
90+
"properties": {
91+
"enabled": {"type": "boolean", "description": "Whether processing is enabled"},
92+
"options": {
93+
"type": "object",
94+
"properties": {
95+
"mode": {"type": "string", "description": "Processing mode"},
96+
"priority": {"type": "integer", "description": "Processing priority"},
97+
},
98+
"description": "Processing options",
99+
},
100+
},
101+
"description": "Processing settings",
102+
},
103+
},
104+
"required": ["user", "settings"],
105+
},
106+
)
107+
108+
func_body, type_defs = await create_function_body(client, mock_tool)
109+
110+
print()
111+
print(func_body)
112+
print()
113+
114+
self.assertIn(
115+
'''
116+
async def process_data(user: Annotated[user_0, "User information"], settings: Annotated[settings_1, "Processing settings"]) -> Any:
117+
"""A function that processes complex data structures"""
118+
'''.strip(),
119+
func_body,
120+
)

0 commit comments

Comments
 (0)