Skip to content

Commit 797cd92

Browse files
committed
refactor: simplify MCP elicitation by removing tool name mapping complexity
1 parent 614b477 commit 797cd92

File tree

4 files changed

+435
-150
lines changed

4 files changed

+435
-150
lines changed

mcp-run-python/src/tool_injection.py

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -64,19 +64,14 @@ def _create_tool_function(
6464
"""Create a tool function that can be called from Python."""
6565

6666
def tool_function(*args: Any, **kwargs: Any) -> Any:
67-
"""Synchronous tool function that handles the async callback properly."""
68-
69-
# Get the actual MCP tool name from the stored mapping
70-
tool_mapping = globals_dict.get('__tool_name_mapping__', {})
71-
actual_tool_name = tool_mapping.get(tool_name, tool_name)
72-
73-
elicitation_request = _create_elicitation_request(actual_tool_name, args, kwargs)
67+
"""Tool function that calls the MCP elicitation callback."""
68+
elicitation_request = _create_elicitation_request(tool_name=tool_name, args=args, kwargs=kwargs)
7469

7570
try:
7671
result = tool_callback(elicitation_request)
77-
return _handle_tool_callback_result(result, actual_tool_name)
72+
return _handle_tool_callback_result(result, tool_name)
7873
except Exception as e:
79-
raise Exception(f'Tool {actual_tool_name} failed: {str(e)}')
74+
raise Exception(f'Tool {tool_name} failed: {str(e)}')
8075

8176
return tool_function
8277

@@ -85,25 +80,18 @@ def inject_tool_functions(
8580
globals_dict: dict[str, Any],
8681
available_tools: list[str],
8782
tool_callback: Callable[[Any], Any] | None = None,
88-
tool_name_mapping: dict[str, str] | None = None,
8983
) -> None:
9084
"""Inject tool functions into the global namespace.
9185
9286
Args:
9387
globals_dict: Global namespace to inject tools into
94-
available_tools: List of available tool names (should be Python-valid identifiers)
88+
available_tools: List of available tool names
9589
tool_callback: Optional callback for tool execution
96-
tool_name_mapping: Optional mapping of python_name -> original_mcp_name
9790
"""
9891
if not available_tools:
9992
return
10093

101-
# Store the tool name mapping globally for elicitation callback to use
102-
if tool_name_mapping:
103-
globals_dict['__tool_name_mapping__'] = tool_name_mapping
104-
105-
# Inject tool functions into globals using Python-valid names
10694
for tool_name in available_tools:
10795
if tool_callback is not None:
108-
# tool_name should already be a valid Python identifier from agent.py
109-
globals_dict[tool_name] = _create_tool_function(tool_name, tool_callback, globals_dict)
96+
python_name = tool_name.replace('-', '_')
97+
globals_dict[python_name] = _create_tool_function(tool_name, tool_callback, globals_dict)

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 44 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from types import FrameType
1313
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, cast, final, overload
1414

15-
from mcp import types as mcp_types
1615
from opentelemetry.trace import NoOpTracer, use_span
1716
from pydantic.json_schema import GenerateJsonSchema
1817
from typing_extensions import Literal, Never, Self, TypeIs, TypeVar, deprecated
@@ -34,7 +33,7 @@
3433
from ._agent_graph import HistoryProcessor
3534
from ._output import OutputToolset
3635
from ._tool_manager import ToolManager
37-
from .mcp import MCPServer
36+
from .mcp import MCPServer, create_auto_tool_injection_callback, create_tool_elicitation_callback
3837
from .models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model
3938
from .output import OutputDataT, OutputSpec
4039
from .profiles import ModelProfile
@@ -1697,7 +1696,6 @@ def _get_toolset(
16971696
if self._prepare_output_tools:
16981697
output_toolset = PreparedToolset(output_toolset, self._prepare_output_tools)
16991698
all_toolsets = [output_toolset, *all_toolsets]
1700-
17011699
return CombinedToolset(all_toolsets)
17021700

17031701
def _infer_name(self, function_frame: FrameType | None) -> None:
@@ -1825,113 +1823,50 @@ def _set_sampling_model(toolset: AbstractToolset[AgentDepsT]) -> None:
18251823

18261824
self._get_toolset().apply(_set_sampling_model)
18271825

1828-
def _create_elicitation_callback(self):
1829-
"""Create an elicitation callback that routes to this agent's tools."""
1826+
def set_mcp_elicitation_toolset(self, toolset_for_elicitation: AbstractToolset[Any] | None = None) -> None:
1827+
"""Set the toolset to use for MCP elicitation callbacks.
18301828
1831-
async def elicitation_callback(context: Any, params: Any) -> Any:
1832-
"""Handle elicitation requests by delegating to the agent's tools."""
1833-
try:
1834-
tool_execution_data = json.loads(params.message)
1835-
tool_name = tool_execution_data.get('tool_name')
1836-
tool_arguments = tool_execution_data.get('arguments', {})
1837-
1838-
# Try function tools first
1839-
function_tools = self._function_toolset.tools
1840-
if tool_name in function_tools:
1841-
tool_func = function_tools[tool_name].function_schema.function
1842-
1843-
# Handle both sync and async functions
1844-
1845-
if inspect.iscoroutinefunction(tool_func):
1846-
result = await tool_func(**tool_arguments)
1847-
else:
1848-
result = tool_func(**tool_arguments)
1849-
1850-
return mcp_types.ElicitResult(action='accept', content={'result': str(result)})
1851-
1852-
# Find the MCP server that has this tool
1853-
target_server = None
1854-
for toolset in self._user_toolsets:
1855-
if not isinstance(toolset, MCPServer):
1856-
continue
1857-
if 'mcp-run-python' in str(toolset):
1858-
continue
1859-
1860-
# Check if this server has the tool
1861-
try:
1862-
server_tools = await toolset.list_tools()
1863-
for tool_def in server_tools:
1864-
if tool_def.name == tool_name:
1865-
target_server = toolset
1866-
break
1867-
if target_server:
1868-
break
1869-
except Exception:
1870-
continue
1871-
1872-
if target_server:
1873-
try:
1874-
result = await target_server.direct_call_tool(tool_name, tool_arguments)
1875-
return mcp_types.ElicitResult(action='accept', content={'result': str(result)})
1876-
except Exception as e:
1877-
return mcp_types.ErrorData(
1878-
code=mcp_types.INTERNAL_ERROR, message=f'Tool execution failed: {str(e)}'
1879-
)
1880-
else:
1881-
return mcp_types.ErrorData(code=mcp_types.INVALID_PARAMS, message=f'Tool {tool_name} not found')
1882-
1883-
except Exception as e:
1884-
return mcp_types.ErrorData(code=mcp_types.INTERNAL_ERROR, message=f'Tool execution failed: {str(e)}')
1885-
1886-
return elicitation_callback
1887-
1888-
def _create_auto_tool_injection_callback(self):
1889-
"""Create a callback that auto-injects available tools into run_python_code calls."""
1890-
1891-
async def auto_inject_tools_callback(
1892-
ctx: RunContext[Any],
1893-
call_tool_func: Callable[[str, dict[str, Any], dict[str, Any] | None], Awaitable[Any]],
1894-
tool_name: str,
1895-
arguments: dict[str, Any],
1896-
) -> Any:
1897-
"""Auto-inject available tools into run_python_code calls."""
1898-
if tool_name == 'run_python_code':
1899-
# Always auto-inject all available tools for Python code execution
1900-
available_tools: list[str] = []
1901-
tool_name_mapping: dict[str, str] = {}
1902-
1903-
# Add function tools
1904-
function_tools = list(self._function_toolset.tools.keys())
1905-
available_tools.extend(function_tools)
1906-
for func_tool_name in function_tools:
1907-
tool_name_mapping[func_tool_name] = func_tool_name
1908-
1909-
# Add MCP server tools with proper name conversion
1910-
for toolset in self._user_toolsets:
1911-
if not isinstance(toolset, MCPServer):
1912-
continue
1913-
if 'mcp-run-python' in str(toolset):
1914-
continue
1915-
1916-
try:
1917-
server_tools = await toolset.list_tools()
1918-
for tool_def in server_tools:
1919-
original_name = tool_def.name
1920-
python_name = original_name.replace('-', '_')
1921-
available_tools.append(python_name)
1922-
tool_name_mapping[python_name] = original_name
1923-
except Exception:
1924-
# Silently continue if we can't get tools from a server
1925-
pass
1926-
1927-
# Always provide all available tools and mapping
1928-
arguments['tools'] = available_tools
1929-
arguments['tool_name_mapping'] = tool_name_mapping
1930-
1931-
# Continue with normal processing
1932-
return await call_tool_func(tool_name, arguments, None)
1933-
1934-
return auto_inject_tools_callback
1829+
This method configures all MCP servers in the agent's toolsets to use the provided
1830+
toolset for handling elicitation requests (tool injection). This enables Python code
1831+
executed via mcp-run-python to call back to the agent's tools.
1832+
1833+
Args:
1834+
toolset_for_elicitation: Toolset to use for tool injection via elicitation.
1835+
If None, uses the agent's complete toolset.
1836+
1837+
Example:
1838+
```python
1839+
agent = Agent('openai:gpt-4o')
1840+
agent.tool(web_search)
1841+
agent.tool(send_email)
1842+
1843+
mcp_server = MCPServerStdio(command='deno', args=[...], allow_elicitation=True)
1844+
agent.add_toolset(mcp_server)
1845+
1846+
# Enable tool injection with all agent tools
1847+
agent.set_mcp_elicitation_toolset()
1848+
1849+
# Or use specific toolset
1850+
custom_toolset = FunctionToolset(web_search)
1851+
agent.set_mcp_elicitation_toolset(custom_toolset)
1852+
```
1853+
"""
1854+
if toolset_for_elicitation is None:
1855+
# Use complete toolset for both elicitation and injection
1856+
toolset_for_elicitation = self._get_toolset()
1857+
1858+
# Set up callbacks for all MCP servers
1859+
def _set_elicitation_toolset(toolset: AbstractToolset[Any]) -> None:
1860+
if isinstance(toolset, MCPServer) and toolset.allow_elicitation:
1861+
# Set up elicitation callback
1862+
if toolset.elicitation_callback is None:
1863+
toolset.elicitation_callback = create_tool_elicitation_callback(toolset=toolset_for_elicitation)
1864+
1865+
# Set up tool injection callback
1866+
if toolset.process_tool_call is None:
1867+
toolset.process_tool_call = create_auto_tool_injection_callback(toolset=toolset_for_elicitation)
1868+
1869+
self._get_toolset().apply(_set_elicitation_toolset)
19351870

19361871
@asynccontextmanager
19371872
@deprecated(
@@ -1953,21 +1888,6 @@ async def run_mcp_servers(
19531888
if model is not None:
19541889
raise
19551890

1956-
# Auto-setup elicitation callback if allow_elicitation is True and no callback is set
1957-
1958-
for toolset in self._user_toolsets:
1959-
if isinstance(toolset, MCPServer):
1960-
if (
1961-
hasattr(toolset, 'allow_elicitation')
1962-
and toolset.allow_elicitation
1963-
and toolset.elicitation_callback is None
1964-
):
1965-
toolset.elicitation_callback = self._create_elicitation_callback()
1966-
1967-
# Also setup auto-tool-injection for run_python_code if not already set
1968-
if toolset.process_tool_call is None:
1969-
toolset.process_tool_call = self._create_auto_tool_injection_callback()
1970-
19711891
async with self:
19721892
yield
19731893

0 commit comments

Comments
 (0)