diff --git a/langchain_mcp_adapters/tools.py b/langchain_mcp_adapters/tools.py index f9f62e9..5dbe1cd 100644 --- a/langchain_mcp_adapters/tools.py +++ b/langchain_mcp_adapters/tools.py @@ -1,4 +1,4 @@ -from typing import Any, cast, get_args +from typing import Any, Callable, cast, get_args from langchain_core.tools import BaseTool, InjectedToolArg, StructuredTool, ToolException from langchain_core.tools.base import get_all_basemodel_annotations @@ -73,6 +73,7 @@ def convert_mcp_tool_to_langchain_tool( tool: MCPTool, *, connection: Connection | None = None, + prepare_arguments: Callable[[MCPTool, dict], dict] | None = None, ) -> BaseTool: """Convert an MCP tool to a LangChain tool. @@ -83,6 +84,9 @@ def convert_mcp_tool_to_langchain_tool( tool: MCP tool to convert connection: Optional connection config to use to create a new session if a `session` is not provided + prepare_arguments: Optional hook to modify the arguments before calling the MCP + tool. It should accept the tool and the arguments as parameters and return + the modified arguments. Returns: a LangChain tool @@ -93,12 +97,15 @@ def convert_mcp_tool_to_langchain_tool( async def call_tool( **arguments: dict[str, Any], ) -> tuple[str | list[str], list[NonTextContent] | None]: + """Call the MCP tool with the provided arguments.""" + arguments_ = prepare_arguments(tool, arguments) if prepare_arguments else arguments + if session is None: # If a session is not provided, we will create one on the fly async with create_session(connection) as tool_session: await tool_session.initialize() call_tool_result = await cast(ClientSession, tool_session).call_tool( - tool.name, arguments + tool.name, arguments_ ) else: call_tool_result = await session.call_tool(tool.name, arguments) @@ -118,6 +125,7 @@ async def load_mcp_tools( session: ClientSession | None, *, connection: Connection | None = None, + prepare_arguments: Callable[[MCPTool, dict], dict] | None = None, ) -> list[BaseTool]: """Load all available MCP tools and convert them to LangChain tools. @@ -137,7 +145,10 @@ async def load_mcp_tools( tools = await _list_all_tools(session) converted_tools = [ - convert_mcp_tool_to_langchain_tool(session, tool, connection=connection) for tool in tools + convert_mcp_tool_to_langchain_tool( + session, tool, connection=connection, prepare_arguments=prepare_arguments + ) + for tool in tools ] return converted_tools