diff --git a/langchain_mcp_adapters/sessions.py b/langchain_mcp_adapters/sessions.py index 1c67b8c..25dc8fb 100644 --- a/langchain_mcp_adapters/sessions.py +++ b/langchain_mcp_adapters/sessions.py @@ -10,10 +10,14 @@ from datetime import timedelta from typing import TYPE_CHECKING, Any, Literal, Protocol +import anyio from mcp import ClientSession, StdioServerParameters from mcp.client.sse import sse_client from mcp.client.stdio import stdio_client from mcp.client.streamable_http import streamablehttp_client +from mcp.server import Server +from mcp.server.fastmcp import FastMCP as FastMCP1Server +from mcp.shared.memory import create_client_server_memory_streams from typing_extensions import NotRequired, TypedDict if TYPE_CHECKING: @@ -183,8 +187,31 @@ class WebsocketConnection(TypedDict): """Additional keyword arguments to pass to the ClientSession""" +class InMemoryConnection(TypedDict): + """Configuration for In-memory transport connections to MCP servers.""" + + transport: Literal["in_memory"] + + server: Server[Any] | FastMCP1Server + """The Server instance to connect to.""" + + raise_exceptions: NotRequired[bool] + """When False, exceptions are returned as messages to the client. + When True, exceptions are raised, which will cause the server to shut down + but also make tracing exceptions much easier during testing and when using + in-process servers. + """ + + session_kwargs: NotRequired[dict[str, Any] | None] + """Additional keyword arguments to pass to the ClientSession""" + + Connection = ( - StdioConnection | SSEConnection | StreamableHttpConnection | WebsocketConnection + StdioConnection + | SSEConnection + | StreamableHttpConnection + | WebsocketConnection + | InMemoryConnection ) @@ -234,6 +261,46 @@ async def _create_stdio_session( yield session +@asynccontextmanager +async def _create_inmemory_session( + *, + server: Server[Any] | FastMCP1Server, + raise_exceptions: bool = False, + session_kwargs: dict[str, Any] | None = None, +) -> AsyncIterator[ClientSession]: + async with create_client_server_memory_streams() as ( + client_streams, + server_streams, + ): + if isinstance(server, FastMCP1Server): + server = server._mcp_server # type: ignore[reportPrivateUsage] + + # https://github.com/jlowin/fastmcp/pull/758 + client_read, client_write = client_streams + server_read, server_write = server_streams + + # Create a cancel scope for the server task + async with anyio.create_task_group() as tg: + tg.start_soon( + lambda: server.run( + server_read, + server_write, + server.create_initialization_options(), + raise_exceptions=raise_exceptions, + ) + ) + + try: + async with ClientSession( + client_read, + client_write, + **(session_kwargs or {}), + ) as client_session: + yield client_session + finally: + tg.cancel_scope.cancel() + + @asynccontextmanager async def _create_sse_session( *, @@ -423,6 +490,12 @@ async def create_session( raise ValueError(msg) async with _create_websocket_session(**params) as session: yield session + elif transport == "in_memory": + if "server" not in params: + msg = "'server' parameter is required for In-memory connection" + raise ValueError(msg) + async with _create_inmemory_session(**params) as session: + yield session else: msg = ( f"Unsupported transport: {transport}. " diff --git a/tests/test_inmemory.py b/tests/test_inmemory.py new file mode 100644 index 0000000..6ed7289 --- /dev/null +++ b/tests/test_inmemory.py @@ -0,0 +1,114 @@ +import importlib.util +import os +from pathlib import Path + +from langchain_core.messages import AIMessage +from langchain_core.tools import BaseTool + +from langchain_mcp_adapters.client import MultiServerMCPClient + + +def _load_module(module_name: str, server_path: str) -> any: + module_spec = importlib.util.spec_from_file_location(module_name, server_path) + assert module_spec is not None + + module = importlib.util.module_from_spec(module_spec) + module_spec.loader.exec_module(module) + return module + + +async def test_multi_server_mcp_client( + socket_enabled, + websocket_server, + websocket_server_port: int, +): + """Test that MultiServerMCPClient can connect to multiple servers and load tools.""" + # Get the absolute path to the server scripts + current_dir = Path(__file__).parent + math_server_path = os.path.join(current_dir, "servers/math_server.py") + weather_server_path = os.path.join(current_dir, "servers/weather_server.py") + # import weather_server + weather_server_module = _load_module("weather_server", weather_server_path) + + client = MultiServerMCPClient( + { + "math": { + "command": "python3", + "args": [math_server_path], + "transport": "stdio", + }, + "weather": { + "server": weather_server_module.mcp, + "transport": "in_memory", + }, + }, + ) + # Check that we have tools from both servers + all_tools = await client.get_tools() + + # Should have 3 tools (add, multiply, get_weather) + assert len(all_tools) == 3 + + # Check that tools are BaseTool instances + for tool in all_tools: + assert isinstance(tool, BaseTool) + + # Verify tool names + tool_names = {tool.name for tool in all_tools} + assert tool_names == {"add", "multiply", "get_weather"} + + # Check math server tools + math_tools = await client.get_tools(server_name="math") + assert len(math_tools) == 2 + math_tool_names = {tool.name for tool in math_tools} + assert math_tool_names == {"add", "multiply"} + + # Check weather server tools + weather_tools = await client.get_tools(server_name="weather") + assert len(weather_tools) == 1 + assert weather_tools[0].name == "get_weather" + + # Test that we can call a math tool + add_tool = next(tool for tool in all_tools if tool.name == "add") + result = await add_tool.ainvoke({"a": 2, "b": 3}) + assert result == "5" + + # Test that we can call a weather tool + weather_tool = next(tool for tool in all_tools if tool.name == "get_weather") + result = await weather_tool.ainvoke({"location": "London"}) + assert result == "It's always sunny in London" + + # Test the multiply tool + multiply_tool = next(tool for tool in all_tools if tool.name == "multiply") + result = await multiply_tool.ainvoke({"a": 4, "b": 5}) + assert result == "20" + + +async def test_get_prompt(): + """Test retrieving prompts from MCP servers.""" + # Get the absolute path to the server scripts + current_dir = Path(__file__).parent + math_server_path = os.path.join(current_dir, "servers/math_server.py") + # import weather_server + math_server_module = _load_module("math_server", math_server_path) + + client = MultiServerMCPClient( + { + "math": { + "server": math_server_module.mcp, + "transport": "in_memory", + } + }, + ) + # Test getting a prompt from the math server + messages = await client.get_prompt( + "math", + "configure_assistant", + arguments={"skills": "math, addition, multiplication"}, + ) + + # Check that we got an AIMessage back + assert len(messages) == 1 + assert isinstance(messages[0], AIMessage) + assert "You are a helpful assistant" in messages[0].content + assert "math, addition, multiplication" in messages[0].content