Skip to content

Commit 314b72f

Browse files
authored
Merge pull request #9 from DiTo97/patch-1
support for MCP servers
2 parents ce93bc1 + a6592c7 commit 314b72f

File tree

16 files changed

+2010
-772
lines changed

16 files changed

+2010
-772
lines changed

coverage-badge.svg

Lines changed: 1 addition & 1 deletion
Loading

pyproject.toml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,27 @@ dependencies = [
2323
"absl-py",
2424
"asyncio",
2525
"click",
26+
"click",
27+
"click",
2628
"datasets",
2729
"docstring-parser",
2830
"graphviz",
2931
"inquirer",
32+
"docstring-parser",
33+
"graphviz",
34+
"inquirer",
3035
"jinja2",
3136
"kuzu",
3237
"litellm",
3338
"matplotlib",
39+
"mcp>1.9.2",
40+
"matplotlib",
41+
"mcp>1.9.2",
3442
"namex",
3543
"neo4j",
3644
"nest-asyncio",
45+
"neo4j",
46+
"nest-asyncio",
3747
"numpy",
3848
"optree",
3949
"pydantic",
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import contextlib
2+
import multiprocessing
3+
import socket
4+
import time
5+
from collections.abc import Generator
6+
7+
import uvicorn
8+
from mcp.server.fastmcp import FastMCP
9+
10+
11+
def run_streamable_server(server: FastMCP, server_port: int) -> None:
12+
"""Run a FastMCP server in a separate process exposing a streamable HTTP endpoint."""
13+
app = server.streamable_http_app()
14+
uvicorn_server = uvicorn.Server(
15+
config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")
16+
)
17+
uvicorn_server.run()
18+
19+
20+
@contextlib.contextmanager
21+
def run_streamable_server_multiprocessing(server: FastMCP) -> Generator[None, None, None]:
22+
"""Run the server in a separate process exposing a streamable HTTP endpoint.
23+
24+
The endpoint will be available at `http://localhost:{server.settings.port}/mcp/`.
25+
"""
26+
proc = multiprocessing.Process(
27+
target=run_streamable_server,
28+
kwargs={"server": server, "server_port": server.settings.port},
29+
daemon=True,
30+
)
31+
proc.start()
32+
33+
# Wait for server to be running
34+
max_attempts = 20
35+
attempt = 0
36+
37+
while attempt < max_attempts:
38+
try:
39+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
40+
s.connect(("127.0.0.1", server.settings.port))
41+
break
42+
except ConnectionRefusedError:
43+
time.sleep(0.1)
44+
attempt += 1
45+
else:
46+
raise RuntimeError(f"Server failed to start after {max_attempts} attempts")
47+
48+
try:
49+
yield
50+
finally:
51+
# Signal the server to stop
52+
proc.kill()
53+
proc.join(timeout=2)
54+
if proc.is_alive():
55+
raise RuntimeError("Server process is still alive after attempting to terminate it")

synalinks/src/utils/mcp/client.py

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
import asyncio
2+
from contextlib import asynccontextmanager
3+
from types import TracebackType
4+
from typing import Any, AsyncIterator
5+
6+
from mcp import ClientSession
7+
8+
from synalinks import ChatMessages, GenericOutputs
9+
from synalinks.src.api_export import synalinks_export
10+
from synalinks.src.utils.mcp.prompts import load_mcp_prompt
11+
from synalinks.src.utils.mcp.resources import load_mcp_resources
12+
from synalinks.src.utils.mcp.sessions import (
13+
ClientSession,
14+
Connection,
15+
McpHttpClientFactory,
16+
SSEConnection,
17+
StdioConnection,
18+
StreamableHttpConnection,
19+
WebsocketConnection,
20+
create_session,
21+
)
22+
from synalinks.src.utils.mcp.tools import load_mcp_tools
23+
from synalinks.src.utils.tool_utils import Tool
24+
25+
26+
ASYNC_CONTEXT_MANAGER_ERROR = (
27+
"MultiServerMCPClient cannot be used as a context manager (e.g., async with MultiServerMCPClient(...)). "
28+
"Instead, you can do one of the following:\n"
29+
"1. client = MultiServerMCPClient(...)\n"
30+
" tools = await client.get_tools()\n"
31+
"2. client = MultiServerMCPClient(...)\n"
32+
" async with client.session(server_name) as session:\n"
33+
" tools = await load_mcp_tools(session)"
34+
)
35+
36+
37+
@synalinks_export(
38+
[
39+
"synalinks.MultiServerMCPClient",
40+
]
41+
)
42+
class MultiServerMCPClient:
43+
"""Client for connecting to multiple MCP servers and loading Synalinks-compatible tools, prompts and resources from them."""
44+
45+
def __init__(
46+
self,
47+
connections: dict[str, Connection] | None = None,
48+
) -> None:
49+
"""Initialize a MultiServerMCPClient with MCP servers connections.
50+
51+
Args:
52+
connections: A dictionary mapping server names to connection configurations.
53+
If None, no initial connections are established.
54+
55+
Example: basic usage (starting a new session on each tool call)
56+
57+
```python
58+
import synalinks
59+
60+
client = synalinks.MultiServerMCPClient(
61+
{
62+
"math": {
63+
"command": "python",
64+
# Make sure to update to the full absolute path to your math_server.py file
65+
"args": ["/path/to/math_server.py"],
66+
"transport": "stdio",
67+
},
68+
"weather": {
69+
# Make sure you start your weather server on port 8000
70+
"url": "http://localhost:8000/mcp",
71+
"transport": "streamable_http",
72+
}
73+
}
74+
)
75+
all_tools = await client.get_tools()
76+
```
77+
78+
Example: explicitly starting a session
79+
80+
```python
81+
import synalinks
82+
from synalinks.src.utils.mcp.tools import load_mcp_tools
83+
84+
client = synalinks.MultiServerMCPClient({...})
85+
async with client.session("math") as session:
86+
tools = await load_mcp_tools(session)
87+
```
88+
"""
89+
connections = connections or {}
90+
91+
if connections:
92+
assert len(set(connections.keys())) == len(connections), (
93+
"MCP server names in the connections mapping must be unique."
94+
)
95+
96+
self.connections: dict[str, Connection] = connections
97+
98+
@asynccontextmanager
99+
async def session(
100+
self,
101+
server_name: str,
102+
*,
103+
auto_initialize: bool = True,
104+
) -> AsyncIterator[ClientSession]:
105+
"""Connect to an MCP server and initialize a session.
106+
107+
Args:
108+
server_name: Name to identify this server connection
109+
auto_initialize: Whether to automatically initialize the session
110+
111+
Raises:
112+
ValueError: If the server name is not found in the connections
113+
114+
Yields:
115+
An initialized ClientSession
116+
"""
117+
if server_name not in self.connections:
118+
raise ValueError(
119+
f"Couldn't find a server with name '{server_name}', expected one of '{list(self.connections.keys())}'"
120+
)
121+
122+
async with create_session(self.connections[server_name]) as session:
123+
if auto_initialize:
124+
await session.initialize()
125+
yield session
126+
127+
async def get_tools(self, *, server_name: str | None = None) -> list[Tool]:
128+
"""Get a list of all tools from all connected servers.
129+
130+
Args:
131+
server_name: Optional name of the server to get tools from.
132+
If None, all tools from all servers will be returned (default).
133+
134+
NOTE: a new session will be created for each tool call
135+
136+
Returns:
137+
A list of Synalinks tools
138+
"""
139+
if server_name is not None:
140+
if server_name not in self.connections:
141+
raise ValueError(
142+
f"Couldn't find a server with name '{server_name}', expected one of '{list(self.connections.keys())}'"
143+
)
144+
return await load_mcp_tools(None, connection=self.connections[server_name])
145+
146+
all_tools: list[Tool] = []
147+
load_mcp_tool_tasks = []
148+
for namespace, connection in self.connections.items():
149+
load_mcp_tool_task = asyncio.create_task(load_mcp_tools(None, connection=connection, namespace=namespace))
150+
load_mcp_tool_tasks.append(load_mcp_tool_task)
151+
tools_list = await asyncio.gather(*load_mcp_tool_tasks)
152+
for tools in tools_list:
153+
all_tools.extend(tools)
154+
return all_tools
155+
156+
async def get_prompt(
157+
self, server_name: str, prompt_name: str, *, arguments: dict[str, Any] | None = None
158+
) -> ChatMessages:
159+
"""Get a prompt from a given MCP server."""
160+
async with self.session(server_name) as session:
161+
prompt = await load_mcp_prompt(session, prompt_name, arguments=arguments)
162+
return prompt
163+
164+
async def get_resources(
165+
self, server_name: str, *, uris: str | list[str] | None = None
166+
) -> list[GenericOutputs]:
167+
"""Get resources from a given MCP server.
168+
169+
Args:
170+
server_name: Name of the server to get resources from
171+
uris: Optional resource URI or list of URIs to load. If not provided, all resources will be loaded.
172+
173+
Returns:
174+
A list of Synalinks GenericOutputs resources
175+
"""
176+
async with self.session(server_name) as session:
177+
resources = await load_mcp_resources(session, uris=uris)
178+
return resources
179+
180+
async def __aenter__(self) -> "MultiServerMCPClient":
181+
raise NotImplementedError(ASYNC_CONTEXT_MANAGER_ERROR)
182+
183+
def __aexit__(
184+
self,
185+
exc_type: type[BaseException] | None,
186+
exc_val: BaseException | None,
187+
exc_tb: TracebackType | None,
188+
) -> None:
189+
raise NotImplementedError(ASYNC_CONTEXT_MANAGER_ERROR)
190+
191+
192+
__all__ = [
193+
"MultiServerMCPClient",
194+
"McpHttpClientFactory",
195+
"SSEConnection",
196+
"StdioConnection",
197+
"StreamableHttpConnection",
198+
"WebsocketConnection",
199+
]

0 commit comments

Comments
 (0)