Skip to content

wip: prepare arguments for tool call #218

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions langchain_mcp_adapters/tools.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, cast, get_args
from typing import Any, Callable, cast, get_args

from langchain_core.tools import BaseTool, InjectedToolArg, StructuredTool, ToolException
from langchain_core.tools.base import get_all_basemodel_annotations
Expand Down Expand Up @@ -73,6 +73,7 @@ def convert_mcp_tool_to_langchain_tool(
tool: MCPTool,
*,
connection: Connection | None = None,
prepare_arguments: Callable[[MCPTool, dict], dict] | None = None,
) -> BaseTool:
"""Convert an MCP tool to a LangChain tool.

Expand All @@ -83,6 +84,9 @@ def convert_mcp_tool_to_langchain_tool(
tool: MCP tool to convert
connection: Optional connection config to use to create a new session
if a `session` is not provided
prepare_arguments: Optional hook to modify the arguments before calling the MCP
tool. It should accept the tool and the arguments as parameters and return
the modified arguments.

Returns:
a LangChain tool
Expand All @@ -93,12 +97,15 @@ def convert_mcp_tool_to_langchain_tool(
async def call_tool(
**arguments: dict[str, Any],
) -> tuple[str | list[str], list[NonTextContent] | None]:
"""Call the MCP tool with the provided arguments."""
arguments_ = prepare_arguments(tool, arguments) if prepare_arguments else arguments

if session is None:
# If a session is not provided, we will create one on the fly
async with create_session(connection) as tool_session:
await tool_session.initialize()
call_tool_result = await cast(ClientSession, tool_session).call_tool(
tool.name, arguments
tool.name, arguments_
)
else:
call_tool_result = await session.call_tool(tool.name, arguments)
Expand All @@ -118,6 +125,7 @@ async def load_mcp_tools(
session: ClientSession | None,
*,
connection: Connection | None = None,
prepare_arguments: Callable[[MCPTool, dict], dict] | None = None,
) -> list[BaseTool]:
"""Load all available MCP tools and convert them to LangChain tools.

Expand All @@ -137,7 +145,10 @@ async def load_mcp_tools(
tools = await _list_all_tools(session)

converted_tools = [
convert_mcp_tool_to_langchain_tool(session, tool, connection=connection) for tool in tools
convert_mcp_tool_to_langchain_tool(
session, tool, connection=connection, prepare_arguments=prepare_arguments
)
for tool in tools
]
return converted_tools

Expand Down