diff --git a/langchain_mcp_adapters/tools.py b/langchain_mcp_adapters/tools.py index 0c6fa45..7d041aa 100644 --- a/langchain_mcp_adapters/tools.py +++ b/langchain_mcp_adapters/tools.py @@ -150,8 +150,14 @@ async def load_mcp_tools( session: ClientSession | None, *, connection: Connection | None = None, + tool_names: list[str] | None = None, ) -> list[BaseTool]: - """Load all available MCP tools and convert them to LangChain tools. + """Load selected or all MCP tools and convert them to LangChain tools. + + Args: + session: MCP client session + connection: Optional connection config to use to create a new session if a `session` is not provided + tool_names: List of specific tool names to load. If empty or None, load all tools. Args: session: The MCP client session. If None, connection must be provided. @@ -176,6 +182,10 @@ async def load_mcp_tools( else: tools = await _list_all_tools(session) + # If tool_names is provided and not empty, filter tools by name + if tool_names: + tools = [tool for tool in tools if tool.name in tool_names] + return [ convert_mcp_tool_to_langchain_tool(session, tool, connection=connection) for tool in tools diff --git a/tests/test_tools.py b/tests/test_tools.py index ba73414..c8ee7eb 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -454,3 +454,71 @@ def custom_httpx_client_factory( # Expected to fail since server doesn't have SSE endpoint, # but the important thing is that httpx_client_factory was passed correctly pass + + +@pytest.mark.asyncio +async def test_load_mcp_tools_with_no_tool_names_e2e(socket_enabled) -> None: + """E2E: Load all MCP tools from a running server.""" + from mcp.server import FastMCP + + server = FastMCP(port=8184) + + @server.tool() + def tool1(param1: str, param2: int) -> str: + return f"tool1 result with {param1}, {param2}" + + @server.tool() + def tool2(param1: str, param2: int) -> str: + return f"tool2 result with {param1}, {param2}" + + @server.tool() + def tool3(param1: str, param2: int) -> str: + return f"tool3 result with {param1}, {param2}" + + with run_streamable_http(server): + client = MultiServerMCPClient( + { + "alltools": { + "url": "http://localhost:8184/mcp/", + "transport": "streamable_http", + }, + } + ) + async with client.session("alltools") as session: + tools = await load_mcp_tools(session) + assert len(tools) == 3 + assert {t.name for t in tools} == {"tool1", "tool2", "tool3"} + + +@pytest.mark.asyncio +async def test_load_mcp_tools_with_specific_tool_names_e2e(socket_enabled) -> None: + """E2E: Load only specific MCP tools from a running server.""" + from mcp.server import FastMCP + + server = FastMCP(port=8185) + + @server.tool() + def tool1(param1: str, param2: int) -> str: + return f"tool1 result with {param1}, {param2}" + + @server.tool() + def tool2(param1: str, param2: int) -> str: + return f"tool2 result with {param1}, {param2}" + + @server.tool() + def tool3(param1: str, param2: int) -> str: + return f"tool3 result with {param1}, {param2}" + + with run_streamable_http(server): + client = MultiServerMCPClient( + { + "sometools": { + "url": "http://localhost:8185/mcp/", + "transport": "streamable_http", + }, + } + ) + async with client.session("sometools") as session: + tools = await load_mcp_tools(session, tool_names=["tool1", "tool3"]) + assert len(tools) == 2 + assert {t.name for t in tools} == {"tool1", "tool3"}