Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 26 additions & 5 deletions langchain_mcp_adapters/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@
"""

import asyncio
from collections.abc import AsyncIterator
from collections.abc import AsyncIterator, Callable
from contextlib import asynccontextmanager
from types import TracebackType
from typing import Any

from langchain_core.documents.base import Blob
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.tools import BaseTool
from langchain_core.tools import BaseTool, ToolException
from mcp import ClientSession
from pydantic import ValidationError

from langchain_mcp_adapters.prompts import load_mcp_prompt
from langchain_mcp_adapters.resources import load_mcp_resources
Expand Down Expand Up @@ -125,12 +126,22 @@ async def session(
await session.initialize()
yield session

async def get_tools(self, *, server_name: str | None = None) -> list[BaseTool]:
async def get_tools(
self,
*,
server_name: str | None = None,
handle_tool_error: bool | str | Callable[[ToolException], str] | None = False,
handle_validation_error: (
bool | str | Callable[[ValidationError], str] | None
) = False,
) -> list[BaseTool]:
"""Get a list of all tools from all connected servers.

Args:
server_name: Optional name of the server to get tools from.
If None, all tools from all servers will be returned (default).
handle_tool_error: Optional error handler for tool execution errors.
handle_validation_error: Optional error handler for validation errors.

NOTE: a new session will be created for each tool call

Expand All @@ -145,13 +156,23 @@ async def get_tools(self, *, server_name: str | None = None) -> list[BaseTool]:
f"expected one of '{list(self.connections.keys())}'"
)
raise ValueError(msg)
return await load_mcp_tools(None, connection=self.connections[server_name])
return await load_mcp_tools(
None,
connection=self.connections[server_name],
handle_tool_error=handle_tool_error,
handle_validation_error=handle_validation_error,
)

all_tools: list[BaseTool] = []
load_mcp_tool_tasks = []
for connection in self.connections.values():
load_mcp_tool_task = asyncio.create_task(
load_mcp_tools(None, connection=connection)
load_mcp_tools(
None,
connection=connection,
handle_tool_error=handle_tool_error,
handle_validation_error=handle_validation_error,
)
)
load_mcp_tool_tasks.append(load_mcp_tool_task)
tools_list = await asyncio.gather(*load_mcp_tool_tasks)
Expand Down
25 changes: 23 additions & 2 deletions langchain_mcp_adapters/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
tools, handle tool execution, and manage tool conversion between the two formats.
"""

from collections.abc import Callable
from typing import Any, cast, get_args

from langchain_core.tools import (
Expand All @@ -18,7 +19,7 @@
from mcp.server.fastmcp.utilities.func_metadata import ArgModelBase, FuncMetadata
from mcp.types import CallToolResult, EmbeddedResource, ImageContent, TextContent
from mcp.types import Tool as MCPTool
from pydantic import BaseModel, create_model
from pydantic import BaseModel, ValidationError, create_model

from langchain_mcp_adapters.sessions import Connection, create_session

Expand Down Expand Up @@ -102,6 +103,10 @@ def convert_mcp_tool_to_langchain_tool(
tool: MCPTool,
*,
connection: Connection | None = None,
handle_tool_error: bool | str | Callable[[ToolException], str] | None = False,
handle_validation_error: (
bool | str | Callable[[ValidationError], str] | None
) = False,
) -> BaseTool:
"""Convert an MCP tool to a LangChain tool.

Expand All @@ -112,6 +117,8 @@ def convert_mcp_tool_to_langchain_tool(
tool: MCP tool to convert
connection: Optional connection config to use to create a new session
if a `session` is not provided
handle_tool_error: Optional error handler for tool execution errors.
handle_validation_error: Optional error handler for validation errors.

Returns:
a LangChain tool
Expand Down Expand Up @@ -143,19 +150,27 @@ async def call_tool(
coroutine=call_tool,
response_format="content_and_artifact",
metadata=tool.annotations.model_dump() if tool.annotations else None,
handle_tool_error=handle_tool_error,
handle_validation_error=handle_validation_error,
)


async def load_mcp_tools(
session: ClientSession | None,
*,
connection: Connection | None = None,
handle_tool_error: bool | str | Callable[[ToolException], str] | None = False,
handle_validation_error: (
bool | str | Callable[[ValidationError], str] | None
) = False,
) -> list[BaseTool]:
"""Load all available MCP tools and convert them to LangChain tools.

Args:
session: The MCP client session. If None, connection must be provided.
connection: Connection config to create a new session if session is None.
handle_tool_error: Optional error handler for tool execution errors.
handle_validation_error: Optional error handler for validation errors.

Returns:
List of LangChain tools. Tool annotations are returned as part
Expand All @@ -177,7 +192,13 @@ async def load_mcp_tools(
tools = await _list_all_tools(session)

return [
convert_mcp_tool_to_langchain_tool(session, tool, connection=connection)
convert_mcp_tool_to_langchain_tool(
session,
tool,
connection=connection,
handle_tool_error=handle_tool_error,
handle_validation_error=handle_validation_error,
)
for tool in tools
]

Expand Down