diff --git a/langchain_mcp_adapters/client.py b/langchain_mcp_adapters/client.py index 1fcbdd9..a3a609a 100644 --- a/langchain_mcp_adapters/client.py +++ b/langchain_mcp_adapters/client.py @@ -5,15 +5,16 @@ """ import asyncio -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Callable from contextlib import asynccontextmanager from types import TracebackType from typing import Any from langchain_core.documents.base import Blob from langchain_core.messages import AIMessage, HumanMessage -from langchain_core.tools import BaseTool +from langchain_core.tools import BaseTool, ToolException from mcp import ClientSession +from pydantic import ValidationError from langchain_mcp_adapters.prompts import load_mcp_prompt from langchain_mcp_adapters.resources import load_mcp_resources @@ -125,12 +126,22 @@ async def session( await session.initialize() yield session - async def get_tools(self, *, server_name: str | None = None) -> list[BaseTool]: + async def get_tools( + self, + *, + server_name: str | None = None, + handle_tool_error: bool | str | Callable[[ToolException], str] | None = False, + handle_validation_error: ( + bool | str | Callable[[ValidationError], str] | None + ) = False, + ) -> list[BaseTool]: """Get a list of all tools from all connected servers. Args: server_name: Optional name of the server to get tools from. If None, all tools from all servers will be returned (default). + handle_tool_error: Optional error handler for tool execution errors. + handle_validation_error: Optional error handler for validation errors. NOTE: a new session will be created for each tool call @@ -145,13 +156,23 @@ async def get_tools(self, *, server_name: str | None = None) -> list[BaseTool]: f"expected one of '{list(self.connections.keys())}'" ) raise ValueError(msg) - return await load_mcp_tools(None, connection=self.connections[server_name]) + return await load_mcp_tools( + None, + connection=self.connections[server_name], + handle_tool_error=handle_tool_error, + handle_validation_error=handle_validation_error, + ) all_tools: list[BaseTool] = [] load_mcp_tool_tasks = [] for connection in self.connections.values(): load_mcp_tool_task = asyncio.create_task( - load_mcp_tools(None, connection=connection) + load_mcp_tools( + None, + connection=connection, + handle_tool_error=handle_tool_error, + handle_validation_error=handle_validation_error, + ) ) load_mcp_tool_tasks.append(load_mcp_tool_task) tools_list = await asyncio.gather(*load_mcp_tool_tasks) diff --git a/langchain_mcp_adapters/tools.py b/langchain_mcp_adapters/tools.py index 0c6fa45..8f76eef 100644 --- a/langchain_mcp_adapters/tools.py +++ b/langchain_mcp_adapters/tools.py @@ -4,6 +4,7 @@ tools, handle tool execution, and manage tool conversion between the two formats. """ +from collections.abc import Callable from typing import Any, cast, get_args from langchain_core.tools import ( @@ -18,7 +19,7 @@ from mcp.server.fastmcp.utilities.func_metadata import ArgModelBase, FuncMetadata from mcp.types import CallToolResult, EmbeddedResource, ImageContent, TextContent from mcp.types import Tool as MCPTool -from pydantic import BaseModel, create_model +from pydantic import BaseModel, ValidationError, create_model from langchain_mcp_adapters.sessions import Connection, create_session @@ -102,6 +103,10 @@ def convert_mcp_tool_to_langchain_tool( tool: MCPTool, *, connection: Connection | None = None, + handle_tool_error: bool | str | Callable[[ToolException], str] | None = False, + handle_validation_error: ( + bool | str | Callable[[ValidationError], str] | None + ) = False, ) -> BaseTool: """Convert an MCP tool to a LangChain tool. @@ -112,6 +117,8 @@ def convert_mcp_tool_to_langchain_tool( tool: MCP tool to convert connection: Optional connection config to use to create a new session if a `session` is not provided + handle_tool_error: Optional error handler for tool execution errors. + handle_validation_error: Optional error handler for validation errors. Returns: a LangChain tool @@ -143,6 +150,8 @@ async def call_tool( coroutine=call_tool, response_format="content_and_artifact", metadata=tool.annotations.model_dump() if tool.annotations else None, + handle_tool_error=handle_tool_error, + handle_validation_error=handle_validation_error, ) @@ -150,12 +159,18 @@ async def load_mcp_tools( session: ClientSession | None, *, connection: Connection | None = None, + handle_tool_error: bool | str | Callable[[ToolException], str] | None = False, + handle_validation_error: ( + bool | str | Callable[[ValidationError], str] | None + ) = False, ) -> list[BaseTool]: """Load all available MCP tools and convert them to LangChain tools. Args: session: The MCP client session. If None, connection must be provided. connection: Connection config to create a new session if session is None. + handle_tool_error: Optional error handler for tool execution errors. + handle_validation_error: Optional error handler for validation errors. Returns: List of LangChain tools. Tool annotations are returned as part @@ -177,7 +192,13 @@ async def load_mcp_tools( tools = await _list_all_tools(session) return [ - convert_mcp_tool_to_langchain_tool(session, tool, connection=connection) + convert_mcp_tool_to_langchain_tool( + session, + tool, + connection=connection, + handle_tool_error=handle_tool_error, + handle_validation_error=handle_validation_error, + ) for tool in tools ]