Skip to content

Commit 67e06e5

Browse files
committed
Fix: Ensure ASGI compliance in app.py SSE handling
1 parent 27dc4d6 commit 67e06e5

File tree

1 file changed

+119
-3
lines changed

1 file changed

+119
-3
lines changed

src/mcpm/router/app.py

Lines changed: 119 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,21 @@
11
import logging
22
import os
3+
import asyncio
4+
import re
5+
from contextlib import asynccontextmanager
6+
7+
from starlette.applications import Starlette
8+
from starlette.middleware import Middleware
9+
from starlette.middleware.cors import CORSMiddleware
10+
from starlette.requests import Request
11+
from starlette.responses import JSONResponse, Response
12+
from starlette.routing import Mount, Route
313

414
from mcpm.monitor.event import monitor
15+
from mcpm.monitor.base import AccessEventType
516
from mcpm.router.router import MCPRouter
617
from mcpm.router.router_config import RouterConfig
18+
from mcpm.router.transport import RouterSseTransport
719
from mcpm.utils.config import ConfigManager
820
from mcpm.utils.platform import get_log_directory
921

@@ -23,10 +35,114 @@
2335
api_key = config.get("api_key")
2436
auth_enabled = config.get("auth_enabled", False)
2537

38+
router_instance = MCPRouter(reload_server=True, router_config=RouterConfig(api_key=api_key, auth_enabled=auth_enabled))
39+
sse_transport = RouterSseTransport("/messages/", api_key=api_key if auth_enabled else None)
40+
41+
class NoOpsResponse(Response):
42+
def __init__(self):
43+
super().__init__(content=b"", status_code=204)
44+
45+
async def __call__(self, scope, receive, send):
46+
await send(
47+
{
48+
"type": "http.response.start",
49+
"status": self.status_code,
50+
"headers": self.render_headers(),
51+
}
52+
)
53+
await send({"type": "http.response.body", "body": b"", "more_body": False})
54+
55+
async def handle_sse(request: Request):
56+
try:
57+
async with sse_transport.connect_sse(
58+
request.scope,
59+
request.receive,
60+
request._send,
61+
) as (read_stream, write_stream):
62+
await router_instance.aggregated_server.run(
63+
read_stream,
64+
write_stream,
65+
router_instance.aggregated_server.initialization_options,
66+
)
67+
while not await request.is_disconnected():
68+
await asyncio.sleep(0.1)
69+
except asyncio.CancelledError:
70+
raise
71+
except Exception as e:
72+
logger.error(f"Unexpected error in app.py handle_sse: {e}", exc_info=True)
73+
finally:
74+
return NoOpsResponse()
75+
76+
async def handle_query_events(request: Request) -> Response:
77+
try:
78+
offset = request.query_params.get("offset")
79+
page = int(request.query_params.get("page", 1))
80+
limit = int(request.query_params.get("limit", 10))
81+
event_type_str = request.query_params.get("event_type", None)
2682

27-
allow_origins = None
83+
if offset is None:
84+
return JSONResponse(
85+
{"error": "Missing required parameter", "detail": "The 'offset' parameter is required."},
86+
status_code=400,
87+
)
88+
89+
offset_pattern = r"^(\d+)([hdwm])$"
90+
match = re.match(offset_pattern, offset.lower())
91+
if not match:
92+
return JSONResponse(
93+
{"error": "Invalid offset format", "detail": "Offset must be e.g., '24h', '7d', '2w', '1m'."},
94+
status_code=400,
95+
)
96+
97+
if page < 1:
98+
page = 1
99+
event_type = None
100+
if event_type_str:
101+
try:
102+
event_type = AccessEventType[event_type_str.upper()].name
103+
except (KeyError, ValueError):
104+
return JSONResponse(
105+
{"error": "Invalid event type", "detail": f"Valid types: {', '.join([e.name for e in AccessEventType])}"},
106+
status_code=400,
107+
)
108+
109+
if monitor:
110+
response_data = await monitor.query_events(offset, page, limit, event_type)
111+
return JSONResponse(response_data.model_dump(), status_code=200)
112+
else:
113+
logger.warning("monitor object not available for /events route")
114+
return JSONResponse({"error": "Monitoring not available"}, status_code=503)
115+
116+
except Exception as e:
117+
logger.error(f"Error handling query events request: {e}", exc_info=True)
118+
return JSONResponse({"error": str(e)}, status_code=500)
119+
120+
@asynccontextmanager
121+
async def lifespan(app_starlette: Starlette):
122+
logger.info("Starting MCPRouter (via app.py)...")
123+
await router_instance.initialize_router()
124+
if monitor:
125+
await monitor.initialize_storage()
126+
yield
127+
logger.info("Shutting down MCPRouter (via app.py)...")
128+
await router_instance.shutdown()
129+
if monitor:
130+
await monitor.close()
131+
132+
middlewares = []
28133
if CORS_ENABLED:
29134
allow_origins = os.environ.get("MCPM_ROUTER_CORS", "").split(",")
135+
middlewares.append(
136+
Middleware(CORSMiddleware, allow_origins=allow_origins, allow_methods=["*"], allow_headers=["*"])
137+
)
30138

31-
router = MCPRouter(reload_server=True, router_config=RouterConfig(api_key=api_key, auth_enabled=auth_enabled))
32-
app = router.get_remote_server_app(allow_origins=allow_origins, include_lifespan=True, monitor=monitor)
139+
app = Starlette(
140+
debug=os.environ.get("MCPM_DEBUG") == "true",
141+
middleware=middlewares,
142+
routes=[
143+
Route("/sse", endpoint=handle_sse),
144+
Route("/events", endpoint=handle_query_events, methods=["GET"]),
145+
Mount("/messages/", app=sse_transport.handle_post_message),
146+
],
147+
lifespan=lifespan,
148+
)

0 commit comments

Comments
 (0)