diff --git a/langchain_mcp_adapters/tools.py b/langchain_mcp_adapters/tools.py index 5f33881..c9187fe 100644 --- a/langchain_mcp_adapters/tools.py +++ b/langchain_mcp_adapters/tools.py @@ -4,8 +4,10 @@ tools, handle tool execution, and manage tool conversion between the two formats. """ +from copy import deepcopy from typing import Any, cast, get_args +from langchain_core.runnables import RunnableConfig from langchain_core.tools import ( BaseTool, InjectedToolArg, @@ -97,6 +99,38 @@ async def _list_all_tools(session: ClientSession) -> list[MCPTool]: return all_tools +def _merge_headers_into_connections(conn: Connection, headers: dict[str, str]) -> Connection: + """ + Returns a copy of 'coon' with 'headers' merged when applicable. + Supports 'sse' aand 'streamable_http' transports, which already accept 'headers' in their factories + + Args: + conn: A connection config + headers: Headers to merge into the connection + + Returns: + A new connection config with merged headers, if applicable + """ + # If no headers provided, returns conn + if not headers: + return conn + + new_conn = deepcopy(conn) + transport = conn.get("transport") + + # Check transport type + if transport in ["sse", "streamable_http"]: + # Get current headers in connection + current_headers = new_conn.get("headers") + merged = { + **current_headers, + **{key: str(value) for key, value in headers.items()} + } + new_conn["headers"] = merged + + return new_conn + + def convert_mcp_tool_to_langchain_tool( session: ClientSession | None, tool: MCPTool, @@ -122,11 +156,29 @@ def convert_mcp_tool_to_langchain_tool( raise ValueError(msg) async def call_tool( + config: RunnableConfig = None, **arguments: dict[str, Any], ) -> tuple[str | list[str], list[NonTextContent] | None]: + configurable: dict = config.get("configurable", {}) + + dynamic_headers: dict[str, str] = {} + if configurable: + # By default, 'jwt' is injected into Authorization header + jwt = configurable.get("jwt", None) + + if isinstance(jwt, str) and jwt.strip(): + dynamic_headers["Authorization"] = f"Bearer {jwt}" + + # All other headers + extra = configurable.get("mcp_headers") + if isinstance(extra, dict): + for key, value in extra.items(): + dynamic_headers[key] = str(value) + if session is None: # If a session is not provided, we will create one on the fly - async with create_session(connection) as tool_session: + conn_with_headers = _merge_headers_into_connections(connection, dynamic_headers) + async with create_session(conn_with_headers) as tool_session: await tool_session.initialize() call_tool_result = await cast("ClientSession", tool_session).call_tool( tool.name,