Skip to content

Commit 84a60a2

Browse files
committed
feat: replace global _mcp_servers with toolsets
1 parent 5dba311 commit 84a60a2

File tree

5 files changed

+128
-62
lines changed

5 files changed

+128
-62
lines changed

mcp-run-python/src/main.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,13 +106,19 @@ The tools are injected into the global namespace automatically - no discovery fu
106106
.array(z.string())
107107
.optional()
108108
.describe('List of available tools for injection (enables tool injection when provided)'),
109+
tool_name_mapping: z
110+
.record(z.string())
111+
.optional()
112+
.describe('Mapping of python_name -> original_mcp_name for tool name conversion'),
109113
},
110114
async ({
111115
python_code,
112116
tools = [],
117+
tool_name_mapping = {},
113118
}: {
114119
python_code: string
115120
tools?: string[]
121+
tool_name_mapping?: Record<string, string>
116122
}) => {
117123
const logPromises: Promise<void>[] = []
118124

@@ -179,6 +185,7 @@ The tools are injected into the global namespace automatically - no discovery fu
179185
{
180186
enableToolInjection: true,
181187
availableTools: tools,
188+
toolNameMapping: tool_name_mapping,
182189
timeoutSeconds: 30,
183190
elicitationCallback,
184191
} as ToolInjectionConfig,

mcp-run-python/src/runCode.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ export interface CodeFile {
1313
export interface ToolInjectionConfig {
1414
enableToolInjection: boolean
1515
availableTools: string[]
16+
toolNameMapping?: Record<string, string> // python_name -> original_mcp_name
1617
timeoutSeconds: number
1718
// deno-lint-ignore no-explicit-any
1819
elicitationCallback?: (request: any) => Promise<any>
@@ -157,6 +158,7 @@ function injectToolFunctions(
157158
globals,
158159
config.availableTools,
159160
tool_callback,
161+
config.toolNameMapping,
160162
)
161163

162164
log('info', `Tool injection complete. Available tools: ${config.availableTools.join(', ')}`)

mcp-run-python/src/tool_injection.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -58,38 +58,52 @@ def _handle_tool_callback_result(result: Any, tool_name: str) -> Any:
5858
return result
5959

6060

61-
def _create_tool_function(tool_name: str, tool_callback: Callable[[Any], Any]) -> Callable[..., Any]:
61+
def _create_tool_function(
62+
tool_name: str, tool_callback: Callable[[Any], Any], globals_dict: dict[str, Any]
63+
) -> Callable[..., Any]:
6264
"""Create a tool function that can be called from Python."""
6365

6466
def tool_function(*args: Any, **kwargs: Any) -> Any:
6567
"""Synchronous tool function that handles the async callback properly."""
66-
# Note: tool_callback is guaranteed to be not None due to check in inject_tool_functions
6768

68-
elicitation_request = _create_elicitation_request(tool_name, args, kwargs)
69+
# Get the actual MCP tool name from the stored mapping
70+
tool_mapping = globals_dict.get('__tool_name_mapping__', {})
71+
actual_tool_name = tool_mapping.get(tool_name, tool_name)
72+
73+
elicitation_request = _create_elicitation_request(actual_tool_name, args, kwargs)
6974

7075
try:
7176
result = tool_callback(elicitation_request)
72-
return _handle_tool_callback_result(result, tool_name)
77+
return _handle_tool_callback_result(result, actual_tool_name)
7378
except Exception as e:
74-
raise Exception(f'Tool {tool_name} failed: {str(e)}')
79+
raise Exception(f'Tool {actual_tool_name} failed: {str(e)}')
7580

7681
return tool_function
7782

7883

7984
def inject_tool_functions(
80-
globals_dict: dict[str, Any], available_tools: list[str], tool_callback: Callable[[Any], Any] | None = None
85+
globals_dict: dict[str, Any],
86+
available_tools: list[str],
87+
tool_callback: Callable[[Any], Any] | None = None,
88+
tool_name_mapping: dict[str, str] | None = None,
8189
) -> None:
8290
"""Inject tool functions into the global namespace.
8391
8492
Args:
8593
globals_dict: Global namespace to inject tools into
86-
available_tools: List of available tool names
94+
available_tools: List of available tool names (should be Python-valid identifiers)
8795
tool_callback: Optional callback for tool execution
96+
tool_name_mapping: Optional mapping of python_name -> original_mcp_name
8897
"""
8998
if not available_tools:
9099
return
91100

92-
# Inject tool functions into globals
101+
# Store the tool name mapping globally for elicitation callback to use
102+
if tool_name_mapping:
103+
globals_dict['__tool_name_mapping__'] = tool_name_mapping
104+
105+
# Inject tool functions into globals using Python-valid names
93106
for tool_name in available_tools:
94107
if tool_callback is not None:
95-
globals_dict[tool_name] = _create_tool_function(tool_name, tool_callback)
108+
# tool_name should already be a valid Python identifier from agent.py
109+
globals_dict[tool_name] = _create_tool_function(tool_name, tool_callback, globals_dict)

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 73 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from types import FrameType
1313
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, cast, final, overload
1414

15+
from mcp import types as mcp_types
1516
from opentelemetry.trace import NoOpTracer, use_span
1617
from pydantic.json_schema import GenerateJsonSchema
1718
from typing_extensions import Literal, Never, Self, TypeIs, TypeVar, deprecated
@@ -70,7 +71,6 @@
7071
from fasta2a.broker import Broker
7172
from fasta2a.schema import AgentProvider, Skill
7273
from fasta2a.storage import Storage
73-
from mcp import types as mcp_types
7474
from starlette.middleware import Middleware
7575
from starlette.routing import BaseRoute, Route
7676
from starlette.types import ExceptionHandler, Lifespan
@@ -1795,6 +1795,20 @@ async def __aenter__(self) -> Self:
17951795
async with self._enter_lock:
17961796
if self._entered_count == 0:
17971797
self._exit_stack = AsyncExitStack()
1798+
1799+
for toolset in self._user_toolsets:
1800+
if isinstance(toolset, MCPServer):
1801+
if (
1802+
hasattr(toolset, 'allow_elicitation')
1803+
and toolset.allow_elicitation
1804+
and toolset.elicitation_callback is None
1805+
):
1806+
toolset.elicitation_callback = self._create_elicitation_callback()
1807+
1808+
# Also setup auto-tool-injection for run_python_code if not already set
1809+
if toolset.process_tool_call is None:
1810+
toolset.process_tool_call = self._create_auto_tool_injection_callback()
1811+
17981812
toolset = self._get_toolset()
17991813
await self._exit_stack.enter_async_context(toolset)
18001814
self._entered_count += 1
@@ -1847,23 +1861,36 @@ async def elicitation_callback(context: Any, params: Any) -> Any:
18471861

18481862
return mcp_types.ElicitResult(action='accept', content={'result': str(result)})
18491863

1850-
# Try MCP tools with name mapping
1851-
actual_tool_name = tool_name.replace('_', '-')
1852-
1864+
# Find the MCP server that has this tool
1865+
target_server = None
18531866
for toolset in self._user_toolsets:
18541867
if not isinstance(toolset, MCPServer):
18551868
continue
1856-
mcp_server = toolset
1857-
if 'mcp-run-python' in str(mcp_server):
1869+
if 'mcp-run-python' in str(toolset):
18581870
continue
18591871

1872+
# Check if this server has the tool
18601873
try:
1861-
result = await mcp_server.direct_call_tool(actual_tool_name, tool_arguments)
1862-
return mcp_types.ElicitResult(action='accept', content={'result': str(result)})
1874+
server_tools = await toolset.list_tools()
1875+
for tool_def in server_tools:
1876+
if tool_def.name == tool_name:
1877+
target_server = toolset
1878+
break
1879+
if target_server:
1880+
break
18631881
except Exception:
18641882
continue
18651883

1866-
return mcp_types.ErrorData(code=mcp_types.INVALID_PARAMS, message=f'Tool {tool_name} not found')
1884+
if target_server:
1885+
try:
1886+
result = await target_server.direct_call_tool(tool_name, tool_arguments)
1887+
return mcp_types.ElicitResult(action='accept', content={'result': str(result)})
1888+
except Exception as e:
1889+
return mcp_types.ErrorData(
1890+
code=mcp_types.INTERNAL_ERROR, message=f'Tool execution failed: {str(e)}'
1891+
)
1892+
else:
1893+
return mcp_types.ErrorData(code=mcp_types.INVALID_PARAMS, message=f'Tool {tool_name} not found')
18671894

18681895
except Exception as e:
18691896
return mcp_types.ErrorData(code=mcp_types.INTERNAL_ERROR, message=f'Tool execution failed: {str(e)}')
@@ -1881,30 +1908,37 @@ async def auto_inject_tools_callback(
18811908
) -> Any:
18821909
"""Auto-inject available tools into run_python_code calls."""
18831910
if tool_name == 'run_python_code':
1884-
# Auto-inject available tools if not already provided
1885-
if 'tools' not in arguments or not arguments['tools']:
1886-
available_tools: list[str] = []
1887-
1888-
# Add function tools
1889-
available_tools.extend(list(self._function_toolset.tools.keys()))
1890-
1891-
for toolset in self._user_toolsets:
1892-
if not isinstance(toolset, MCPServer):
1893-
continue
1894-
mcp_server = toolset
1895-
if 'mcp-run-python' in str(mcp_server):
1896-
continue
1897-
1898-
try:
1899-
server_tools = await mcp_server.list_tools()
1900-
for tool_def in server_tools:
1901-
python_name = tool_def.name.replace('-', '_')
1902-
available_tools.append(python_name)
1903-
except Exception:
1904-
# Silently continue if we can't get tools from a server
1905-
pass
1906-
1907-
arguments['tools'] = available_tools
1911+
# Always auto-inject all available tools for Python code execution
1912+
available_tools: list[str] = []
1913+
tool_name_mapping: dict[str, str] = {}
1914+
1915+
# Add function tools
1916+
function_tools = list(self._function_toolset.tools.keys())
1917+
available_tools.extend(function_tools)
1918+
for func_tool_name in function_tools:
1919+
tool_name_mapping[func_tool_name] = func_tool_name
1920+
1921+
# Add MCP server tools with proper name conversion
1922+
for toolset in self._user_toolsets:
1923+
if not isinstance(toolset, MCPServer):
1924+
continue
1925+
if 'mcp-run-python' in str(toolset):
1926+
continue
1927+
1928+
try:
1929+
server_tools = await toolset.list_tools()
1930+
for tool_def in server_tools:
1931+
original_name = tool_def.name
1932+
python_name = original_name.replace('-', '_')
1933+
available_tools.append(python_name)
1934+
tool_name_mapping[python_name] = original_name
1935+
except Exception:
1936+
# Silently continue if we can't get tools from a server
1937+
pass
1938+
1939+
# Always provide all available tools and mapping
1940+
arguments['tools'] = available_tools
1941+
arguments['tool_name_mapping'] = tool_name_mapping
19081942

19091943
# Continue with normal processing
19101944
return await call_tool_func(tool_name, arguments, None)
@@ -1935,17 +1969,16 @@ async def run_mcp_servers(
19351969

19361970
for toolset in self._user_toolsets:
19371971
if isinstance(toolset, MCPServer):
1938-
mcp_server = toolset
19391972
if (
1940-
hasattr(mcp_server, 'allow_elicitation')
1941-
and mcp_server.allow_elicitation
1942-
and mcp_server.elicitation_callback is None
1973+
hasattr(toolset, 'allow_elicitation')
1974+
and toolset.allow_elicitation
1975+
and toolset.elicitation_callback is None
19431976
):
1944-
mcp_server.elicitation_callback = self._create_elicitation_callback()
1977+
toolset.elicitation_callback = self._create_elicitation_callback()
19451978

19461979
# Also setup auto-tool-injection for run_python_code if not already set
1947-
if mcp_server.process_tool_call is None:
1948-
mcp_server.process_tool_call = self._create_auto_tool_injection_callback()
1980+
if toolset.process_tool_call is None:
1981+
toolset.process_tool_call = self._create_auto_tool_injection_callback()
19491982

19501983
async with self:
19511984
yield

tests/test_mcp_elicitation.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,15 @@ async def mock_elicitation(
8282
)
8383

8484
model = TestModel(custom_output_text='Test response')
85-
agent = Agent(model, mcp_servers=[server])
85+
agent = Agent(model, toolsets=[server])
8686

8787
# Verify the server is properly configured
88-
assert len(agent._mcp_servers) == 1 # type: ignore
89-
assert agent._mcp_servers[0].elicitation_callback is mock_elicitation # type: ignore
88+
toolsets = getattr(agent, '_user_toolsets', [])
89+
mcp_servers = [ts for ts in toolsets if hasattr(ts, '__class__') and 'MCPServer' in ts.__class__.__name__]
90+
assert len(mcp_servers) == 1
91+
# Use getattr to safely check elicitation_callback
92+
callback = getattr(mcp_servers[0], 'elicitation_callback', None)
93+
assert callback is mock_elicitation
9094

9195
async def test_elicitation_callback_error_handling(self):
9296
"""Test error handling in elicitation callback."""
@@ -483,14 +487,19 @@ async def agent_tool_callback(
483487

484488
# Create agent with the MCP server
485489
model = TestModel(custom_output_text='Tool injection test completed')
486-
agent = Agent(model, mcp_servers=[mcp_server])
490+
agent = Agent(model, toolsets=[mcp_server])
487491

488492
# Verify the agent has the MCP server with elicitation callback
489-
assert len(agent._mcp_servers) == 1 # type: ignore
490-
assert agent._mcp_servers[0].elicitation_callback is agent_tool_callback # type: ignore
493+
# Note: Using getattr to safely access toolsets for testing
494+
toolsets = getattr(agent, '_user_toolsets', [])
495+
mcp_servers = [ts for ts in toolsets if hasattr(ts, '__class__') and 'MCPServer' in ts.__class__.__name__]
496+
assert len(mcp_servers) == 1
497+
# Use getattr to safely check elicitation_callback
498+
callback = getattr(mcp_servers[0], 'elicitation_callback', None)
499+
assert callback is agent_tool_callback
491500

492501
# Test running agent with MCP servers
493-
async with agent.run_mcp_servers():
502+
async with agent:
494503
# Verify the MCP server is properly integrated
495504
tools = await mcp_server.list_tools()
496505
assert len(tools) == 1
@@ -780,7 +789,7 @@ async def test_mcp_run_python_code_execution(self):
780789

781790
async with server:
782791
# Test basic Python execution
783-
result = await server.call_tool(
792+
result = await server.direct_call_tool(
784793
'run_python_code', {'python_code': 'print("Hello, World!")\n"Hello from Python"'}
785794
)
786795

@@ -827,7 +836,7 @@ async def python_code_callback(
827836
async with server:
828837
# Test Python code execution with tool injection
829838
# This should trigger the elicitation callback when tools are called
830-
result = await server.call_tool(
839+
result = await server.direct_call_tool(
831840
'run_python_code',
832841
{'python_code': 'print("Testing tool injection")', 'tools': ['web_search', 'calculate']},
833842
)
@@ -854,7 +863,7 @@ async def test_mcp_run_python_error_handling(self):
854863

855864
async with server:
856865
# Test Python code with syntax error
857-
result = await server.call_tool('run_python_code', {'python_code': 'print("Missing closing quote)'})
866+
result = await server.direct_call_tool('run_python_code', {'python_code': 'print("Missing closing quote)'})
858867

859868
# Should return error status
860869
assert isinstance(result, str)
@@ -1009,7 +1018,7 @@ async def test_mcp_run_python_with_dependencies(self):
10091018

10101019
async with server:
10111020
# Test code with dependencies
1012-
result = await server.call_tool(
1021+
result = await server.direct_call_tool(
10131022
'run_python_code',
10141023
{
10151024
'python_code': """
@@ -1049,7 +1058,8 @@ async def test_mcp_run_python_with_tool_prefix(self):
10491058
async with server:
10501059
tools = await server.list_tools()
10511060
assert len(tools) == 1
1052-
assert tools[0].name == 'python_run_python_code'
1061+
# list_tools() returns original tool names without prefix
1062+
assert tools[0].name == 'run_python_code'
10531063

10541064
async def test_mcp_run_python_timeout_setting(self):
10551065
"""Test mcp-run-python server with timeout setting."""
@@ -1071,7 +1081,7 @@ async def test_mcp_run_python_timeout_setting(self):
10711081

10721082
async with server:
10731083
# Test basic execution still works
1074-
result = await server.call_tool('run_python_code', {'python_code': 'print("Timeout test")'})
1084+
result = await server.direct_call_tool('run_python_code', {'python_code': 'print("Timeout test")'})
10751085

10761086
assert isinstance(result, str)
10771087
assert '<status>success</status>' in result

0 commit comments

Comments
 (0)