Skip to content

Commit b5cfb19

Browse files
committed
refactor: decouple transport and session
1 parent 0874ccc commit b5cfb19

File tree

5 files changed

+362
-1
lines changed

5 files changed

+362
-1
lines changed

src/mcpm/core/router/app.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import asyncio
2+
import logging
3+
from contextlib import asynccontextmanager
4+
5+
from starlette.applications import Starlette
6+
from starlette.middleware import Middleware
7+
from starlette.requests import Request
8+
from starlette.responses import Response
9+
from starlette.routing import Mount, Route
10+
from starlette.types import Receive, Scope, Send
11+
12+
from mcpm.monitor.event import monitor
13+
from mcpm.router.router import MCPRouter
14+
15+
from .middleware import SessionMiddleware
16+
from .session import SessionManager
17+
from .transport import SseTransport
18+
19+
logger = logging.getLogger("mcpm.router")
20+
21+
session_manager = SessionManager()
22+
transport = SseTransport(endpoint="/messages/", session_manager=session_manager)
23+
24+
router = MCPRouter(reload_server=False)
25+
26+
class NoOpsResponse(Response):
27+
async def __call__(self, scope: Scope, receive: Receive, send: Send):
28+
# To comply with Starlette's ASGI application design, this method must return a response.
29+
# Since no further client interaction is needed after server shutdown, we provide a no-operation response
30+
# that allows the application to exit gracefully when cancelled by Uvicorn.
31+
# No content is sent back to the client as EventSourceResponse has already returned a 200 status code.
32+
pass
33+
34+
35+
async def handle_sse(request: Request):
36+
try:
37+
async with transport.connect_sse(request.scope, request.receive, request._send) as (read, write):
38+
await router.aggregated_server.run(read, write, router.aggregated_server.initialization_options) # type: ignore
39+
except asyncio.CancelledError:
40+
return NoOpsResponse()
41+
42+
43+
@asynccontextmanager
44+
async def lifespan(app):
45+
logger.info("Starting MCPRouter...")
46+
await router.initialize_router()
47+
await monitor.initialize_storage()
48+
49+
yield
50+
51+
logger.info("Shutting down MCPRouter...")
52+
await router.shutdown()
53+
await monitor.close()
54+
55+
app = Starlette(
56+
debug=True,
57+
routes=[
58+
Route("/sse", endpoint=handle_sse),
59+
Mount("/messages/", app=transport.handle_post_message)
60+
],
61+
middleware=[Middleware(SessionMiddleware, session_manager=session_manager)],
62+
lifespan=lifespan
63+
)

src/mcpm/core/router/middleware.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import logging
2+
from typing import Any
3+
from uuid import UUID
4+
5+
from starlette.requests import Request
6+
from starlette.responses import Response
7+
from starlette.types import ASGIApp, Receive, Scope, Send
8+
9+
from .session import SessionManager
10+
11+
logger = logging.getLogger(__name__)
12+
13+
META_DATA_KEYS = ["profile", "client"]
14+
15+
def get_meta(request: Request) -> dict[str, Any]:
16+
meta: dict[str, Any] = {}
17+
for key in META_DATA_KEYS:
18+
value = request.query_params.get(key)
19+
if not value:
20+
value = request.headers.get(key)
21+
22+
if value:
23+
meta[key] = value
24+
25+
return meta
26+
27+
class SessionMiddleware:
28+
29+
def __init__(self, app: ASGIApp, session_manager: SessionManager) -> None:
30+
self.app = app
31+
self.session_manager = session_manager
32+
33+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
34+
# we related metadata with session through this middleware, so that in the transport layer we only need to handle
35+
# session_id and dispatch message to the correct memory stream
36+
37+
if scope["type"] != "http":
38+
await self.app(scope, receive, send)
39+
return
40+
41+
request = Request(scope)
42+
43+
if scope["path"] == "/sse":
44+
# retrieve metadata from query params or header
45+
session = await self.session_manager.create_session(meta=get_meta(request))
46+
logger.debug(f"Created new session with ID: {session['id']}")
47+
48+
scope["session_id"] = session["id"].hex
49+
50+
if scope["path"] == "/messages/":
51+
session_id_param = request.query_params.get("session_id")
52+
if not session_id_param:
53+
logger.debug("Missing session_id")
54+
response = Response("session_id is required", status_code=400)
55+
await response(scope, receive, send)
56+
return
57+
58+
# validate session_id
59+
try:
60+
session_id = UUID(hex=session_id_param)
61+
except ValueError:
62+
logger.warning(f"Received invalid session ID: {session_id_param}")
63+
response = Response("invalid session ID", status_code=400)
64+
await response(scope, receive, send)
65+
return
66+
67+
# if session_id is not in session manager, return 404
68+
if not self.session_manager.exist(session_id):
69+
logger.debug(f"session {session_id} not found")
70+
response = Response("session not found", status_code=404)
71+
await response(scope, receive, send)
72+
return
73+
74+
scope["session_id"] = session_id.hex
75+
76+
await self.app(scope, receive, send)

src/mcpm/core/router/session.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
from typing import Any, Protocol, TypedDict
2+
from uuid import UUID, uuid4
3+
4+
import anyio
5+
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
6+
from mcp import types
7+
8+
9+
class Session(TypedDict):
10+
id: UUID
11+
# some read,write streams related with session
12+
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
13+
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
14+
15+
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
16+
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
17+
# any meta data is saved here
18+
meta: dict[str, Any]
19+
20+
21+
class SessionStore(Protocol):
22+
23+
def exist(self, session_id: UUID) -> bool:
24+
...
25+
26+
async def put(self, session: Session) -> None:
27+
...
28+
29+
async def get(self, session_id: UUID) -> Session:
30+
...
31+
32+
async def remove(self, session_id: UUID):
33+
...
34+
35+
async def cleanup(self):
36+
...
37+
38+
39+
class LocalSessionStore:
40+
41+
def __init__(self):
42+
self._store: dict[UUID, Session] = {}
43+
44+
def exist(self, session_id: UUID) -> bool:
45+
return session_id in self._store
46+
47+
async def put(self, session: Session) -> None:
48+
self._store[session["id"]] = session
49+
50+
async def get(self, session_id: UUID) -> Session:
51+
return self._store[session_id]
52+
53+
async def remove(self, session_id: UUID):
54+
session = self._store.pop(session_id, None)
55+
if session:
56+
await session["read_stream_writer"].aclose()
57+
await session["write_stream"].aclose()
58+
59+
async def cleanup(self):
60+
keys = list(self._store.keys())
61+
for session_id in keys:
62+
await self.remove(session_id)
63+
64+
65+
class SessionManager:
66+
67+
def __init__(self):
68+
self.session_store: SessionStore = LocalSessionStore()
69+
70+
def exist(self, session_id: UUID) -> bool:
71+
return self.session_store.exist(session_id)
72+
73+
async def create_session(self, meta: dict[str, Any] = {}) -> Session:
74+
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
75+
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
76+
session_id = uuid4()
77+
session = Session(
78+
id=session_id,
79+
read_stream=read_stream,
80+
read_stream_writer=read_stream_writer,
81+
write_stream=write_stream,
82+
write_stream_reader=write_stream_reader,
83+
meta=meta
84+
)
85+
await self.session_store.put(session)
86+
return session
87+
88+
async def get_session(self, session_id: UUID) -> Session:
89+
return await self.session_store.get(session_id)
90+
91+
async def close_session(self, session_id: UUID):
92+
await self.session_store.remove(session_id)
93+
94+
async def cleanup_resources(self):
95+
await self.session_store.cleanup()

src/mcpm/core/router/transport.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
import asyncio
2+
import json
3+
import logging
4+
from contextlib import asynccontextmanager
5+
from typing import Any
6+
from urllib.parse import quote
7+
from uuid import UUID
8+
9+
import anyio
10+
from mcp import types
11+
from pydantic import ValidationError
12+
from sse_starlette import EventSourceResponse
13+
from starlette.background import BackgroundTask
14+
from starlette.requests import Request
15+
from starlette.responses import Response
16+
from starlette.types import Receive, Scope, Send
17+
18+
from .session import Session, SessionManager
19+
20+
logger = logging.getLogger(__name__)
21+
22+
def patch_meta_data(body: bytes, **kwargs) -> bytes:
23+
data = json.loads(body.decode("utf-8"))
24+
if "params" not in data:
25+
data["params"] = {}
26+
27+
for key, value in kwargs.items():
28+
data["params"].setdefault("_meta", {})[key] = value
29+
return json.dumps(data).encode("utf-8")
30+
31+
class SseTransport:
32+
33+
def __init__(self, endpoint: str, session_manager: SessionManager) -> None:
34+
self.session_manager = session_manager
35+
self._endpoint = endpoint
36+
37+
@asynccontextmanager
38+
async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
39+
session_id_hex = scope["session_id"]
40+
session_id: UUID = UUID(hex=session_id_hex)
41+
session = await self.session_manager.get_session(session_id)
42+
43+
session_uri = f"{quote(self._endpoint)}?session_id={session_id.hex}"
44+
45+
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, Any]](0)
46+
47+
async def sse_writer():
48+
logger.debug("Starting SSE writer")
49+
async with sse_stream_writer, session["write_stream_reader"]:
50+
await sse_stream_writer.send({"event": "endpoint", "data": session_uri})
51+
logger.debug(f"Sent endpoint event: {session_uri}")
52+
53+
async for message in session["write_stream_reader"]:
54+
logger.debug(f"Sending message via SSE: {message}")
55+
await sse_stream_writer.send(
56+
{
57+
"event": "message",
58+
"data": message.model_dump_json(by_alias=True, exclude_none=True),
59+
}
60+
)
61+
62+
async with anyio.create_task_group() as tg:
63+
async def on_client_disconnect():
64+
await self.session_manager.close_session(session_id)
65+
66+
try:
67+
response = EventSourceResponse(
68+
content=sse_stream_reader,
69+
data_sender_callable=sse_writer,
70+
background=BackgroundTask(on_client_disconnect),
71+
)
72+
logger.debug("Starting SSE response task")
73+
tg.start_soon(response, scope, receive, send)
74+
75+
logger.debug("Yielding read and write streams")
76+
# Due to limitations with interrupting the MCP server run operation,
77+
# this will always block here regardless of client disconnection status
78+
yield (session["read_stream"], session["write_stream"])
79+
except asyncio.CancelledError as exc:
80+
logger.warning(f"SSE connection for session {session_id} was cancelled")
81+
tg.cancel_scope.cancel()
82+
# raise the exception again so that to interrupt mcp server run operation
83+
raise exc
84+
finally:
85+
# for server shutdown
86+
await self.session_manager.cleanup_resources()
87+
88+
async def handle_post_message(self, scope: Scope, receive: Receive, send: Send):
89+
90+
session_id = scope["session_id"]
91+
session: Session = await self.session_manager.get_session(UUID(hex=session_id))
92+
93+
request = Request(scope, receive)
94+
body = await request.body()
95+
# patch meta data
96+
body = patch_meta_data(body, **session["meta"])
97+
98+
# send message to writer
99+
writer = session["read_stream_writer"]
100+
try:
101+
message = types.JSONRPCMessage.model_validate_json(body)
102+
logger.debug(f"Validated client message: {message}")
103+
except ValidationError as err:
104+
logger.error(f"Failed to parse message: {err}")
105+
response = Response("Could not parse message", status_code=400)
106+
await response(scope, receive, send)
107+
try:
108+
await writer.send(err)
109+
except (BrokenPipeError, ConnectionError, OSError) as pipe_err:
110+
logger.warning(f"Failed to send error due to pipe issue: {pipe_err}")
111+
return
112+
113+
logger.debug(f"Sending message to writer: {message}")
114+
response = Response("Accepted", status_code=202)
115+
await response(scope, receive, send)
116+
117+
# add error handling, catch possible pipe errors
118+
try:
119+
await writer.send(message)
120+
except (BrokenPipeError, ConnectionError, OSError) as e:
121+
# if it's EPIPE error or other connection error, log it but don't throw an exception
122+
if isinstance(e, OSError) and e.errno == 32: # EPIPE
123+
logger.warning(f"EPIPE error when sending message to session {session_id}, connection may be closing")
124+
else:
125+
logger.warning(f"Connection error when sending message to session {session_id}: {e}")
126+
await self.session_manager.close_session(session_id)

src/mcpm/utils/errlog_manager.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,6 @@ def close_errlog_file(self, server_id: str) -> None:
3030
del self._log_files[server_id]
3131

3232
def close_all(self) -> None:
33-
for server_id in self._log_files:
33+
keys = list(self._log_files.keys())
34+
for server_id in keys:
3435
self.close_errlog_file(server_id)

0 commit comments

Comments
 (0)