Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 53 additions & 1 deletion langchain_mcp_adapters/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
tools, handle tool execution, and manage tool conversion between the two formats.
"""

from copy import deepcopy
from typing import Any, cast, get_args

from langchain_core.runnables import RunnableConfig
from langchain_core.tools import (
BaseTool,
InjectedToolArg,
Expand Down Expand Up @@ -97,6 +99,38 @@ async def _list_all_tools(session: ClientSession) -> list[MCPTool]:
return all_tools


def _merge_headers_into_connections(conn: Connection, headers: dict[str, str]) -> Connection:
"""
Returns a copy of 'coon' with 'headers' merged when applicable.
Supports 'sse' aand 'streamable_http' transports, which already accept 'headers' in their factories

Args:
conn: A connection config
headers: Headers to merge into the connection

Returns:
A new connection config with merged headers, if applicable
"""
# If no headers provided, returns conn
if not headers:
return conn

new_conn = deepcopy(conn)
transport = conn.get("transport")

# Check transport type
if transport in ["sse", "streamable_http"]:
# Get current headers in connection
current_headers = new_conn.get("headers")
merged = {
**current_headers,
**{key: str(value) for key, value in headers.items()}
}
new_conn["headers"] = merged

return new_conn


def convert_mcp_tool_to_langchain_tool(
session: ClientSession | None,
tool: MCPTool,
Expand All @@ -122,11 +156,29 @@ def convert_mcp_tool_to_langchain_tool(
raise ValueError(msg)

async def call_tool(
config: RunnableConfig = None,
**arguments: dict[str, Any],
) -> tuple[str | list[str], list[NonTextContent] | None]:
configurable: dict = config.get("configurable", {})

dynamic_headers: dict[str, str] = {}
if configurable:
# By default, 'jwt' is injected into Authorization header
jwt = configurable.get("jwt", None)

if isinstance(jwt, str) and jwt.strip():
dynamic_headers["Authorization"] = f"Bearer {jwt}"

# All other headers
extra = configurable.get("mcp_headers")
if isinstance(extra, dict):
for key, value in extra.items():
dynamic_headers[key] = str(value)

if session is None:
# If a session is not provided, we will create one on the fly
async with create_session(connection) as tool_session:
conn_with_headers = _merge_headers_into_connections(connection, dynamic_headers)
async with create_session(conn_with_headers) as tool_session:
await tool_session.initialize()
call_tool_result = await cast("ClientSession", tool_session).call_tool(
tool.name,
Expand Down