Skip to content

Commit ae7b464

Browse files
committed
refactor: simplify MCP elicitation by removing tool name mapping complexity
1 parent 84a60a2 commit ae7b464

File tree

4 files changed

+435
-163
lines changed

4 files changed

+435
-163
lines changed

mcp-run-python/src/tool_injection.py

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -64,19 +64,14 @@ def _create_tool_function(
6464
"""Create a tool function that can be called from Python."""
6565

6666
def tool_function(*args: Any, **kwargs: Any) -> Any:
67-
"""Synchronous tool function that handles the async callback properly."""
68-
69-
# Get the actual MCP tool name from the stored mapping
70-
tool_mapping = globals_dict.get('__tool_name_mapping__', {})
71-
actual_tool_name = tool_mapping.get(tool_name, tool_name)
72-
73-
elicitation_request = _create_elicitation_request(actual_tool_name, args, kwargs)
67+
"""Tool function that calls the MCP elicitation callback."""
68+
elicitation_request = _create_elicitation_request(tool_name=tool_name, args=args, kwargs=kwargs)
7469

7570
try:
7671
result = tool_callback(elicitation_request)
77-
return _handle_tool_callback_result(result, actual_tool_name)
72+
return _handle_tool_callback_result(result, tool_name)
7873
except Exception as e:
79-
raise Exception(f'Tool {actual_tool_name} failed: {str(e)}')
74+
raise Exception(f'Tool {tool_name} failed: {str(e)}')
8075

8176
return tool_function
8277

@@ -85,25 +80,18 @@ def inject_tool_functions(
8580
globals_dict: dict[str, Any],
8681
available_tools: list[str],
8782
tool_callback: Callable[[Any], Any] | None = None,
88-
tool_name_mapping: dict[str, str] | None = None,
8983
) -> None:
9084
"""Inject tool functions into the global namespace.
9185
9286
Args:
9387
globals_dict: Global namespace to inject tools into
94-
available_tools: List of available tool names (should be Python-valid identifiers)
88+
available_tools: List of available tool names
9589
tool_callback: Optional callback for tool execution
96-
tool_name_mapping: Optional mapping of python_name -> original_mcp_name
9790
"""
9891
if not available_tools:
9992
return
10093

101-
# Store the tool name mapping globally for elicitation callback to use
102-
if tool_name_mapping:
103-
globals_dict['__tool_name_mapping__'] = tool_name_mapping
104-
105-
# Inject tool functions into globals using Python-valid names
10694
for tool_name in available_tools:
10795
if tool_callback is not None:
108-
# tool_name should already be a valid Python identifier from agent.py
109-
globals_dict[tool_name] = _create_tool_function(tool_name, tool_callback, globals_dict)
96+
python_name = tool_name.replace('-', '_')
97+
globals_dict[python_name] = _create_tool_function(tool_name, tool_callback, globals_dict)

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 44 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from types import FrameType
1313
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, cast, final, overload
1414

15-
from mcp import types as mcp_types
1615
from opentelemetry.trace import NoOpTracer, use_span
1716
from pydantic.json_schema import GenerateJsonSchema
1817
from typing_extensions import Literal, Never, Self, TypeIs, TypeVar, deprecated
@@ -34,7 +33,7 @@
3433
from ._agent_graph import HistoryProcessor
3534
from ._output import OutputToolset
3635
from ._tool_manager import ToolManager
37-
from .mcp import MCPServer
36+
from .mcp import MCPServer, create_auto_tool_injection_callback, create_tool_elicitation_callback
3837
from .models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model
3938
from .output import OutputDataT, OutputSpec
4039
from .profiles import ModelProfile
@@ -1697,7 +1696,6 @@ def _get_toolset(
16971696
if self._prepare_output_tools:
16981697
output_toolset = PreparedToolset(output_toolset, self._prepare_output_tools)
16991698
all_toolsets = [output_toolset, *all_toolsets]
1700-
17011699
return CombinedToolset(all_toolsets)
17021700

17031701
def _infer_name(self, function_frame: FrameType | None) -> None:
@@ -1796,19 +1794,6 @@ async def __aenter__(self) -> Self:
17961794
if self._entered_count == 0:
17971795
self._exit_stack = AsyncExitStack()
17981796

1799-
for toolset in self._user_toolsets:
1800-
if isinstance(toolset, MCPServer):
1801-
if (
1802-
hasattr(toolset, 'allow_elicitation')
1803-
and toolset.allow_elicitation
1804-
and toolset.elicitation_callback is None
1805-
):
1806-
toolset.elicitation_callback = self._create_elicitation_callback()
1807-
1808-
# Also setup auto-tool-injection for run_python_code if not already set
1809-
if toolset.process_tool_call is None:
1810-
toolset.process_tool_call = self._create_auto_tool_injection_callback()
1811-
18121797
toolset = self._get_toolset()
18131798
await self._exit_stack.enter_async_context(toolset)
18141799
self._entered_count += 1
@@ -1837,113 +1822,50 @@ def _set_sampling_model(toolset: AbstractToolset[AgentDepsT]) -> None:
18371822

18381823
self._get_toolset().apply(_set_sampling_model)
18391824

1840-
def _create_elicitation_callback(self):
1841-
"""Create an elicitation callback that routes to this agent's tools."""
1825+
def set_mcp_elicitation_toolset(self, toolset_for_elicitation: AbstractToolset[Any] | None = None) -> None:
1826+
"""Set the toolset to use for MCP elicitation callbacks.
18421827
1843-
async def elicitation_callback(context: Any, params: Any) -> Any:
1844-
"""Handle elicitation requests by delegating to the agent's tools."""
1845-
try:
1846-
tool_execution_data = json.loads(params.message)
1847-
tool_name = tool_execution_data.get('tool_name')
1848-
tool_arguments = tool_execution_data.get('arguments', {})
1849-
1850-
# Try function tools first
1851-
function_tools = self._function_toolset.tools
1852-
if tool_name in function_tools:
1853-
tool_func = function_tools[tool_name].function_schema.function
1854-
1855-
# Handle both sync and async functions
1856-
1857-
if inspect.iscoroutinefunction(tool_func):
1858-
result = await tool_func(**tool_arguments)
1859-
else:
1860-
result = tool_func(**tool_arguments)
1861-
1862-
return mcp_types.ElicitResult(action='accept', content={'result': str(result)})
1863-
1864-
# Find the MCP server that has this tool
1865-
target_server = None
1866-
for toolset in self._user_toolsets:
1867-
if not isinstance(toolset, MCPServer):
1868-
continue
1869-
if 'mcp-run-python' in str(toolset):
1870-
continue
1871-
1872-
# Check if this server has the tool
1873-
try:
1874-
server_tools = await toolset.list_tools()
1875-
for tool_def in server_tools:
1876-
if tool_def.name == tool_name:
1877-
target_server = toolset
1878-
break
1879-
if target_server:
1880-
break
1881-
except Exception:
1882-
continue
1883-
1884-
if target_server:
1885-
try:
1886-
result = await target_server.direct_call_tool(tool_name, tool_arguments)
1887-
return mcp_types.ElicitResult(action='accept', content={'result': str(result)})
1888-
except Exception as e:
1889-
return mcp_types.ErrorData(
1890-
code=mcp_types.INTERNAL_ERROR, message=f'Tool execution failed: {str(e)}'
1891-
)
1892-
else:
1893-
return mcp_types.ErrorData(code=mcp_types.INVALID_PARAMS, message=f'Tool {tool_name} not found')
1894-
1895-
except Exception as e:
1896-
return mcp_types.ErrorData(code=mcp_types.INTERNAL_ERROR, message=f'Tool execution failed: {str(e)}')
1897-
1898-
return elicitation_callback
1899-
1900-
def _create_auto_tool_injection_callback(self):
1901-
"""Create a callback that auto-injects available tools into run_python_code calls."""
1902-
1903-
async def auto_inject_tools_callback(
1904-
ctx: RunContext[Any],
1905-
call_tool_func: Callable[[str, dict[str, Any], dict[str, Any] | None], Awaitable[Any]],
1906-
tool_name: str,
1907-
arguments: dict[str, Any],
1908-
) -> Any:
1909-
"""Auto-inject available tools into run_python_code calls."""
1910-
if tool_name == 'run_python_code':
1911-
# Always auto-inject all available tools for Python code execution
1912-
available_tools: list[str] = []
1913-
tool_name_mapping: dict[str, str] = {}
1914-
1915-
# Add function tools
1916-
function_tools = list(self._function_toolset.tools.keys())
1917-
available_tools.extend(function_tools)
1918-
for func_tool_name in function_tools:
1919-
tool_name_mapping[func_tool_name] = func_tool_name
1920-
1921-
# Add MCP server tools with proper name conversion
1922-
for toolset in self._user_toolsets:
1923-
if not isinstance(toolset, MCPServer):
1924-
continue
1925-
if 'mcp-run-python' in str(toolset):
1926-
continue
1927-
1928-
try:
1929-
server_tools = await toolset.list_tools()
1930-
for tool_def in server_tools:
1931-
original_name = tool_def.name
1932-
python_name = original_name.replace('-', '_')
1933-
available_tools.append(python_name)
1934-
tool_name_mapping[python_name] = original_name
1935-
except Exception:
1936-
# Silently continue if we can't get tools from a server
1937-
pass
1938-
1939-
# Always provide all available tools and mapping
1940-
arguments['tools'] = available_tools
1941-
arguments['tool_name_mapping'] = tool_name_mapping
1942-
1943-
# Continue with normal processing
1944-
return await call_tool_func(tool_name, arguments, None)
1945-
1946-
return auto_inject_tools_callback
1828+
This method configures all MCP servers in the agent's toolsets to use the provided
1829+
toolset for handling elicitation requests (tool injection). This enables Python code
1830+
executed via mcp-run-python to call back to the agent's tools.
1831+
1832+
Args:
1833+
toolset_for_elicitation: Toolset to use for tool injection via elicitation.
1834+
If None, uses the agent's complete toolset.
1835+
1836+
Example:
1837+
```python
1838+
agent = Agent('openai:gpt-4o')
1839+
agent.tool(web_search)
1840+
agent.tool(send_email)
1841+
1842+
mcp_server = MCPServerStdio(command='deno', args=[...], allow_elicitation=True)
1843+
agent.add_toolset(mcp_server)
1844+
1845+
# Enable tool injection with all agent tools
1846+
agent.set_mcp_elicitation_toolset()
1847+
1848+
# Or use specific toolset
1849+
custom_toolset = FunctionToolset(web_search)
1850+
agent.set_mcp_elicitation_toolset(custom_toolset)
1851+
```
1852+
"""
1853+
if toolset_for_elicitation is None:
1854+
# Use complete toolset for both elicitation and injection
1855+
toolset_for_elicitation = self._get_toolset()
1856+
1857+
# Set up callbacks for all MCP servers
1858+
def _set_elicitation_toolset(toolset: AbstractToolset[Any]) -> None:
1859+
if isinstance(toolset, MCPServer) and toolset.allow_elicitation:
1860+
# Set up elicitation callback
1861+
if toolset.elicitation_callback is None:
1862+
toolset.elicitation_callback = create_tool_elicitation_callback(toolset=toolset_for_elicitation)
1863+
1864+
# Set up tool injection callback
1865+
if toolset.process_tool_call is None:
1866+
toolset.process_tool_call = create_auto_tool_injection_callback(toolset=toolset_for_elicitation)
1867+
1868+
self._get_toolset().apply(_set_elicitation_toolset)
19471869

19481870
@asynccontextmanager
19491871
@deprecated(
@@ -1965,21 +1887,6 @@ async def run_mcp_servers(
19651887
if model is not None:
19661888
raise
19671889

1968-
# Auto-setup elicitation callback if allow_elicitation is True and no callback is set
1969-
1970-
for toolset in self._user_toolsets:
1971-
if isinstance(toolset, MCPServer):
1972-
if (
1973-
hasattr(toolset, 'allow_elicitation')
1974-
and toolset.allow_elicitation
1975-
and toolset.elicitation_callback is None
1976-
):
1977-
toolset.elicitation_callback = self._create_elicitation_callback()
1978-
1979-
# Also setup auto-tool-injection for run_python_code if not already set
1980-
if toolset.process_tool_call is None:
1981-
toolset.process_tool_call = self._create_auto_tool_injection_callback()
1982-
19831890
async with self:
19841891
yield
19851892

0 commit comments

Comments
 (0)