Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 74 additions & 1 deletion langchain_mcp_adapters/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,14 @@
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Literal, Protocol

import anyio
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamablehttp_client
from mcp.server import Server
from mcp.server.fastmcp import FastMCP as FastMCP1Server
from mcp.shared.memory import create_client_server_memory_streams
from typing_extensions import NotRequired, TypedDict

if TYPE_CHECKING:
Expand Down Expand Up @@ -183,8 +187,31 @@ class WebsocketConnection(TypedDict):
"""Additional keyword arguments to pass to the ClientSession"""


class InMemoryConnection(TypedDict):
"""Configuration for In-memory transport connections to MCP servers."""

transport: Literal["in_memory"]

server: Server[Any] | FastMCP1Server
"""The Server instance to connect to."""

raise_exceptions: NotRequired[bool]
"""When False, exceptions are returned as messages to the client.
When True, exceptions are raised, which will cause the server to shut down
but also make tracing exceptions much easier during testing and when using
in-process servers.
"""

session_kwargs: NotRequired[dict[str, Any] | None]
"""Additional keyword arguments to pass to the ClientSession"""


Connection = (
StdioConnection | SSEConnection | StreamableHttpConnection | WebsocketConnection
StdioConnection
| SSEConnection
| StreamableHttpConnection
| WebsocketConnection
| InMemoryConnection
)


Expand Down Expand Up @@ -234,6 +261,46 @@ async def _create_stdio_session(
yield session


@asynccontextmanager
async def _create_inmemory_session(
*,
server: Server[Any] | FastMCP1Server,
raise_exceptions: bool = False,
session_kwargs: dict[str, Any] | None = None,
) -> AsyncIterator[ClientSession]:
async with create_client_server_memory_streams() as (
client_streams,
server_streams,
):
if isinstance(server, FastMCP1Server):
server = server._mcp_server # type: ignore[reportPrivateUsage]

# https://github.com/jlowin/fastmcp/pull/758
client_read, client_write = client_streams
server_read, server_write = server_streams

# Create a cancel scope for the server task
async with anyio.create_task_group() as tg:
tg.start_soon(
lambda: server.run(
server_read,
server_write,
server.create_initialization_options(),
raise_exceptions=raise_exceptions,
)
)

try:
async with ClientSession(
client_read,
client_write,
**(session_kwargs or {}),
) as client_session:
yield client_session
finally:
tg.cancel_scope.cancel()


@asynccontextmanager
async def _create_sse_session(
*,
Expand Down Expand Up @@ -423,6 +490,12 @@ async def create_session(
raise ValueError(msg)
async with _create_websocket_session(**params) as session:
yield session
elif transport == "in_memory":
if "server" not in params:
msg = "'server' parameter is required for In-memory connection"
raise ValueError(msg)
async with _create_inmemory_session(**params) as session:
yield session
else:
msg = (
f"Unsupported transport: {transport}. "
Expand Down
114 changes: 114 additions & 0 deletions tests/test_inmemory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import importlib.util
import os
from pathlib import Path

from langchain_core.messages import AIMessage
from langchain_core.tools import BaseTool

from langchain_mcp_adapters.client import MultiServerMCPClient


def _load_module(module_name: str, server_path: str) -> any:
module_spec = importlib.util.spec_from_file_location(module_name, server_path)
assert module_spec is not None

module = importlib.util.module_from_spec(module_spec)
module_spec.loader.exec_module(module)
return module


async def test_multi_server_mcp_client(
socket_enabled,
websocket_server,
websocket_server_port: int,
):
"""Test that MultiServerMCPClient can connect to multiple servers and load tools."""
# Get the absolute path to the server scripts
current_dir = Path(__file__).parent
math_server_path = os.path.join(current_dir, "servers/math_server.py")
weather_server_path = os.path.join(current_dir, "servers/weather_server.py")
# import weather_server
weather_server_module = _load_module("weather_server", weather_server_path)

client = MultiServerMCPClient(
{
"math": {
"command": "python3",
"args": [math_server_path],
"transport": "stdio",
},
"weather": {
"server": weather_server_module.mcp,
"transport": "in_memory",
},
},
)
# Check that we have tools from both servers
all_tools = await client.get_tools()

# Should have 3 tools (add, multiply, get_weather)
assert len(all_tools) == 3

# Check that tools are BaseTool instances
for tool in all_tools:
assert isinstance(tool, BaseTool)

# Verify tool names
tool_names = {tool.name for tool in all_tools}
assert tool_names == {"add", "multiply", "get_weather"}

# Check math server tools
math_tools = await client.get_tools(server_name="math")
assert len(math_tools) == 2
math_tool_names = {tool.name for tool in math_tools}
assert math_tool_names == {"add", "multiply"}

# Check weather server tools
weather_tools = await client.get_tools(server_name="weather")
assert len(weather_tools) == 1
assert weather_tools[0].name == "get_weather"

# Test that we can call a math tool
add_tool = next(tool for tool in all_tools if tool.name == "add")
result = await add_tool.ainvoke({"a": 2, "b": 3})
assert result == "5"

# Test that we can call a weather tool
weather_tool = next(tool for tool in all_tools if tool.name == "get_weather")
result = await weather_tool.ainvoke({"location": "London"})
assert result == "It's always sunny in London"

# Test the multiply tool
multiply_tool = next(tool for tool in all_tools if tool.name == "multiply")
result = await multiply_tool.ainvoke({"a": 4, "b": 5})
assert result == "20"


async def test_get_prompt():
"""Test retrieving prompts from MCP servers."""
# Get the absolute path to the server scripts
current_dir = Path(__file__).parent
math_server_path = os.path.join(current_dir, "servers/math_server.py")
# import weather_server
math_server_module = _load_module("math_server", math_server_path)

client = MultiServerMCPClient(
{
"math": {
"server": math_server_module.mcp,
"transport": "in_memory",
}
},
)
# Test getting a prompt from the math server
messages = await client.get_prompt(
"math",
"configure_assistant",
arguments={"skills": "math, addition, multiplication"},
)

# Check that we got an AIMessage back
assert len(messages) == 1
assert isinstance(messages[0], AIMessage)
assert "You are a helpful assistant" in messages[0].content
assert "math, addition, multiplication" in messages[0].content