Skip to content

Commit f29df0f

Browse files
committed
streamable http server wise connection
1 parent 6f9db42 commit f29df0f

File tree

3 files changed

+183
-79
lines changed

3 files changed

+183
-79
lines changed

mcpgateway/main.py

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,8 @@
105105
)
106106
from mcpgateway.transports.sse_transport import SSETransport
107107
from mcpgateway.transports.streamablehttp_transport import (
108-
JWTAuthMiddlewareStreamableHttp,
109108
SessionManagerWrapper,
109+
streamable_http_auth,
110110
)
111111
from mcpgateway.types import (
112112
InitializeRequest,
@@ -197,7 +197,7 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
197197
await logging_service.initialize()
198198
await sampling_handler.initialize()
199199
await resource_cache.initialize()
200-
await streamable_http_session.start()
200+
await streamable_http_session.initialize()
201201

202202
logger.info("All services initialized successfully")
203203
yield
@@ -262,6 +262,54 @@ async def dispatch(self, request: Request, call_next):
262262
return await call_next(request)
263263

264264

265+
class MCPPathRewriteMiddleware:
266+
"""
267+
Supports requests like '/servers/<server_id>/mcp' by rewriting the path to '/mcp'.
268+
269+
- Only rewrites paths ending with '/mcp' but not exactly '/mcp'.
270+
- Performs authentication before rewriting.
271+
- Passes rewritten requests to `streamable_http_session`.
272+
- All other requests are passed through unchanged.
273+
"""
274+
275+
def __init__(self, app):
276+
"""
277+
Initialize the middleware with the ASGI application.
278+
279+
Args:
280+
app (Callable): The next ASGI application in the middleware stack.
281+
"""
282+
self.app = app
283+
284+
async def __call__(self, scope, receive, send):
285+
"""
286+
Intercept and potentially rewrite the incoming HTTP request path.
287+
288+
Args:
289+
scope (dict): The ASGI connection scope.
290+
receive (Callable): Awaitable that yields events from the client.
291+
send (Callable): Awaitable used to send events to the client.
292+
"""
293+
# Only handle HTTP requests, HTTPS uses scope["type"] == "http" in ASGI
294+
if scope["type"] != "http":
295+
await self.app(scope, receive, send)
296+
return
297+
298+
# Call auth check first
299+
auth_ok = await streamable_http_auth(scope, receive, send)
300+
if not auth_ok:
301+
return
302+
303+
original_path = scope.get("path", "")
304+
scope["modified_path"] = original_path
305+
if (original_path.endswith("/mcp") and original_path != "/mcp") or (original_path.endswith("/mcp/") and original_path != "/mcp/"):
306+
# Rewrite path so mounted app at /mcp handles it
307+
scope["path"] = "/mcp"
308+
await streamable_http_session.handle_streamable_http(scope, receive, send)
309+
return
310+
await self.app(scope, receive, send)
311+
312+
265313
# Configure CORS
266314
app.add_middleware(
267315
CORSMiddleware,
@@ -276,8 +324,9 @@ async def dispatch(self, request: Request, call_next):
276324
# Add custom DocsAuthMiddleware
277325
app.add_middleware(DocsAuthMiddleware)
278326

279-
# Add streamable HTTP middleware for JWT auth
280-
app.add_middleware(JWTAuthMiddlewareStreamableHttp)
327+
# Add streamable HTTP middleware for /mcp routes
328+
app.add_middleware(MCPPathRewriteMiddleware)
329+
281330

282331
# Set up Jinja2 templates and store in app state for later use
283332
templates = Jinja2Templates(directory=str(settings.templates_dir))

mcpgateway/transports/streamablehttp_transport.py

Lines changed: 124 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,16 @@
99
1010
Key components include:
1111
- SessionManagerWrapper: Manages the lifecycle of streamable HTTP sessions
12-
- JWTAuthMiddlewareStreamableHttp: Middleware for JWT authentication
1312
- Configuration options for:
1413
1. stateful/stateless operation
1514
2. JSON response mode or SSE streams
1615
- InMemoryEventStore: A simple in-memory event storage system for maintaining session state
1716
1817
"""
1918

19+
import contextvars
2020
import logging
21+
import re
2122
from collections import deque
2223
from contextlib import AsyncExitStack, asynccontextmanager
2324
from dataclasses import dataclass
@@ -38,10 +39,9 @@
3839
from mcp.types import JSONRPCMessage
3940
from starlette.datastructures import Headers
4041
from starlette.middleware.base import BaseHTTPMiddleware
41-
from starlette.requests import Request
4242
from starlette.responses import JSONResponse
4343
from starlette.status import HTTP_401_UNAUTHORIZED
44-
from starlette.types import ASGIApp, Receive, Scope, Send
44+
from starlette.types import Receive, Scope, Send
4545

4646
from mcpgateway.config import settings
4747
from mcpgateway.db import SessionLocal
@@ -55,6 +55,8 @@
5555
tool_service = ToolService()
5656
mcp_app = Server("mcp-streamable-http-stateless")
5757

58+
server_id_var: contextvars.ContextVar[str] = contextvars.ContextVar("server_id", default=None)
59+
5860
## ------------------------------ Event store ------------------------------
5961

6062

@@ -191,13 +193,24 @@ async def list_tools() -> List[types.Tool]:
191193
A list of Tool objects containing metadata such as name, description, and input schema.
192194
Logs and returns an empty list on failure.
193195
"""
194-
try:
195-
async with get_db() as db:
196-
tools = await tool_service.list_tools(db)
197-
return [types.Tool(name=tool.name, description=tool.description, inputSchema=tool.input_schema) for tool in tools]
198-
except Exception as e:
199-
logger.exception("Error listing tools")
200-
return []
196+
server_id = server_id_var.get()
197+
198+
if server_id:
199+
try:
200+
async with get_db() as db:
201+
tools = await tool_service.list_server_tools(db, server_id)
202+
return [types.Tool(name=tool.name, description=tool.description, inputSchema=tool.input_schema) for tool in tools]
203+
except Exception as e:
204+
logger.exception("Error listing tools")
205+
return []
206+
else:
207+
try:
208+
async with get_db() as db:
209+
tools = await tool_service.list_tools(db)
210+
return [types.Tool(name=tool.name, description=tool.description, inputSchema=tool.input_schema) for tool in tools]
211+
except Exception as e:
212+
logger.exception("Error listing tools")
213+
return []
201214

202215

203216
class SessionManagerWrapper:
@@ -226,7 +239,7 @@ def __init__(self) -> None:
226239
)
227240
self.stack = AsyncExitStack()
228241

229-
async def start(self) -> None:
242+
async def initialize(self) -> None:
230243
"""
231244
Starts the Streamable HTTP session manager context.
232245
"""
@@ -250,80 +263,122 @@ async def handle_streamable_http(self, scope: Scope, receive: Receive, send: Sen
250263
send (Send): ASGI send callable.
251264
Logs any exceptions that occur during request handling.
252265
"""
266+
267+
path = scope["modified_path"]
268+
match = re.search(r"/servers/(?P<server_id>\d+)/mcp", path)
269+
270+
if match:
271+
server_id = match.group("server_id")
272+
server_id_var.set(server_id)
273+
253274
try:
254275
await self.session_manager.handle_request(scope, receive, send)
255276
except Exception as e:
256277
logger.exception("Error handling streamable HTTP request")
257278
raise
258279

259280

260-
## ------------------------- FastAPI Middleware for Authentication ------------------------------
281+
## ------------------------- Authentication for /mcp routes ------------------------------
261282

283+
# async def streamable_http_auth(scope, receive, send):
284+
# """
285+
# Perform authentication check in middleware context (ASGI scope).
262286

263-
class JWTAuthMiddlewareStreamableHttp(BaseHTTPMiddleware):
264-
"""
265-
Middleware for handling JWT authentication in an ASGI application.
266-
This middleware checks for JWT tokens in the authorization header or cookies
267-
and verifies the credentials before allowing access to protected routes.
287+
# If path does not end with "/mcp", just continue (return True).
288+
289+
# If auth fails, sends 401 JSONResponse and returns False.
290+
291+
# If auth succeeds or not required, returns True.
292+
# """
293+
294+
# path = scope.get("path", "")
295+
# if not path.endswith("/mcp"):
296+
# # No auth needed for other paths in this middleware usage
297+
# return True
298+
299+
# headers = Headers(scope=scope)
300+
# authorization = headers.get("authorization")
301+
# cookie_header = headers.get("cookie", "")
302+
303+
# token = None
304+
# if authorization:
305+
# scheme, credentials = get_authorization_scheme_param(authorization)
306+
# if scheme.lower() == "bearer" and credentials:
307+
# token = credentials
308+
309+
# if not token:
310+
# # parse cookie header manually
311+
# for cookie in cookie_header.split(";"):
312+
# if cookie.strip().startswith("jwt_token="):
313+
# token = cookie.strip().split("=", 1)[1]
314+
# break
315+
316+
# if settings.auth_required and not token:
317+
# response = JSONResponse(
318+
# {"detail": "Not authenticated"},
319+
# status_code=HTTP_401_UNAUTHORIZED,
320+
# headers={"WWW-Authenticate": "Bearer"},
321+
# )
322+
# await response(scope, receive, send)
323+
# return False
324+
325+
# if token:
326+
# try:
327+
# await verify_credentials(token)
328+
# except Exception:
329+
# response = JSONResponse(
330+
# {"detail": "Authentication failed"},
331+
# status_code=HTTP_401_UNAUTHORIZED,
332+
# headers={"WWW-Authenticate": "Bearer"},
333+
# )
334+
# await response(scope, receive, send)
335+
# return False
336+
337+
# return True
338+
339+
340+
async def streamable_http_auth(scope, receive, send):
268341
"""
342+
Perform authentication check in middleware context (ASGI scope).
269343
270-
def __init__(self, app: ASGIApp):
271-
"""
272-
Initialize the middleware with the given ASGI application.
344+
If path does not end with "/mcp", just continue (return True).
273345
274-
Args:
275-
app (ASGIApp): The ASGI application to wrap.
276-
"""
277-
super().__init__(app)
346+
Only check Authorization header for Bearer token.
347+
If no Bearer token provided, allow (return True).
278348
279-
async def dispatch(self, request: Request, call_next):
280-
"""
281-
Dispatch the request to the appropriate handler after performing JWT authentication.
349+
If auth_required is True and Bearer token provided, verify it.
350+
If verification fails, send 401 JSONResponse and return False.
351+
"""
282352

283-
Args:
284-
request (Request): The incoming request.
285-
call_next: The next middleware or route handler in the chain.
353+
path = scope.get("path", "")
354+
if not path.endswith("/mcp") and not path.endswith("/mcp/"):
355+
# No auth needed for other paths in this middleware usage
356+
return True
286357

287-
Returns:
288-
JSONResponse: A response indicating authentication failure if the token is invalid or missing.
289-
Response: The response from the next middleware or route handler if authentication is successful.
290-
"""
291-
# Only apply auth to /mcp path
292-
if not request.url.path.startswith("/mcp"):
293-
return await call_next(request)
294-
295-
headers = Headers(scope=request.scope)
296-
authorization = headers.get("authorization")
297-
cookie_header = headers.get("cookie", "")
298-
299-
token = None
300-
if authorization:
301-
scheme, credentials = get_authorization_scheme_param(authorization)
302-
if scheme.lower() == "bearer" and credentials:
303-
token = credentials
304-
305-
if not token:
306-
for cookie in cookie_header.split(";"):
307-
if cookie.strip().startswith("jwt_token="):
308-
token = cookie.strip().split("=", 1)[1]
309-
break
358+
headers = Headers(scope=scope)
359+
authorization = headers.get("authorization")
310360

311-
try:
312-
if settings.auth_required and not token:
313-
return JSONResponse(
314-
{"detail": "Not authenticated"},
315-
status_code=HTTP_401_UNAUTHORIZED,
316-
headers={"WWW-Authenticate": "Bearer"},
317-
)
361+
token = None
362+
if authorization:
363+
scheme, credentials = get_authorization_scheme_param(authorization)
364+
if scheme.lower() == "bearer" and credentials:
365+
token = credentials
318366

319-
if token:
320-
await verify_credentials(token)
367+
# # If no Bearer token in Authorization header, just allow (no auth)
368+
# if not token:
369+
# return True
321370

322-
return await call_next(request)
371+
# If token is present, verify it
372+
print("TOKEN::::", token)
373+
try:
374+
await verify_credentials(token)
375+
except Exception:
376+
response = JSONResponse(
377+
{"detail": "Authentication failed"},
378+
status_code=HTTP_401_UNAUTHORIZED,
379+
headers={"WWW-Authenticate": "Bearer"},
380+
)
381+
await response(scope, receive, send)
382+
return False
323383

324-
except Exception as e:
325-
return JSONResponse(
326-
{"detail": "Authentication failed"},
327-
status_code=HTTP_401_UNAUTHORIZED,
328-
headers={"WWW-Authenticate": "Bearer"},
329-
)
384+
return True

tests/unit/mcpgateway/test_admin.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -109,18 +109,15 @@ def mock_request():
109109
)
110110

111111
# Basic template rendering stub
112-
request.app = MagicMock() # ensure .app exists
113-
request.app.state = MagicMock() # ensure .app.state exists
112+
request.app = MagicMock() # ensure .app exists
113+
request.app.state = MagicMock() # ensure .app.state exists
114114
request.app.state.templates = MagicMock()
115-
request.app.state.templates.TemplateResponse.return_value = HTMLResponse(
116-
content="<html></html>"
117-
)
115+
request.app.state.templates.TemplateResponse.return_value = HTMLResponse(content="<html></html>")
118116

119117
request.query_params = {"include_inactive": "false"}
120118
return request
121119

122120

123-
124121
class TestAdminServerRoutes:
125122
"""Test admin routes for server management."""
126123

@@ -217,6 +214,7 @@ async def test_admin_delete_server(self, mock_delete_server, mock_request, mock_
217214
assert result.status_code == 303
218215
assert "/admin#catalog" in result.headers["location"]
219216

217+
220218
class TestAdminToolRoutes:
221219
"""Test admin routes for tool management."""
222220

@@ -602,6 +600,7 @@ async def test_admin_delete_gateway(self, mock_delete_gateway, mock_request, moc
602600
assert result.status_code == 303
603601
assert "/admin#gateways" in result.headers["location"]
604602

603+
605604
class TestAdminRootRoutes:
606605
"""Test admin routes for root management."""
607606

@@ -626,6 +625,7 @@ async def test_admin_delete_root(self, mock_remove_root, mock_request):
626625
assert result.status_code == 303
627626
assert "/admin#roots" in result.headers["location"]
628627

628+
629629
class TestAdminMetricsRoutes:
630630
"""Test admin routes for metrics management."""
631631

0 commit comments

Comments
 (0)