Skip to content

Commit 17bb99c

Browse files
authored
Merge pull request #118 from IBM/streamablehttp_server_wise_connection
Streamablehttp server wise connection
2 parents 8221586 + a4b868c commit 17bb99c

File tree

4 files changed

+126
-86
lines changed

4 files changed

+126
-86
lines changed

mcpgateway/main.py

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,8 @@
105105
)
106106
from mcpgateway.transports.sse_transport import SSETransport
107107
from mcpgateway.transports.streamablehttp_transport import (
108-
JWTAuthMiddlewareStreamableHttp,
109108
SessionManagerWrapper,
109+
streamable_http_auth,
110110
)
111111
from mcpgateway.types import (
112112
InitializeRequest,
@@ -197,7 +197,7 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
197197
await logging_service.initialize()
198198
await sampling_handler.initialize()
199199
await resource_cache.initialize()
200-
await streamable_http_session.start()
200+
await streamable_http_session.initialize()
201201

202202
logger.info("All services initialized successfully")
203203
yield
@@ -262,6 +262,54 @@ async def dispatch(self, request: Request, call_next):
262262
return await call_next(request)
263263

264264

265+
class MCPPathRewriteMiddleware:
266+
"""
267+
Supports requests like '/servers/<server_id>/mcp' by rewriting the path to '/mcp'.
268+
269+
- Only rewrites paths ending with '/mcp' but not exactly '/mcp'.
270+
- Performs authentication before rewriting.
271+
- Passes rewritten requests to `streamable_http_session`.
272+
- All other requests are passed through unchanged.
273+
"""
274+
275+
def __init__(self, app):
276+
"""
277+
Initialize the middleware with the ASGI application.
278+
279+
Args:
280+
app (Callable): The next ASGI application in the middleware stack.
281+
"""
282+
self.app = app
283+
284+
async def __call__(self, scope, receive, send):
285+
"""
286+
Intercept and potentially rewrite the incoming HTTP request path.
287+
288+
Args:
289+
scope (dict): The ASGI connection scope.
290+
receive (Callable): Awaitable that yields events from the client.
291+
send (Callable): Awaitable used to send events to the client.
292+
"""
293+
# Only handle HTTP requests, HTTPS uses scope["type"] == "http" in ASGI
294+
if scope["type"] != "http":
295+
await self.app(scope, receive, send)
296+
return
297+
298+
# Call auth check first
299+
auth_ok = await streamable_http_auth(scope, receive, send)
300+
if not auth_ok:
301+
return
302+
303+
original_path = scope.get("path", "")
304+
scope["modified_path"] = original_path
305+
if (original_path.endswith("/mcp") and original_path != "/mcp") or (original_path.endswith("/mcp/") and original_path != "/mcp/"):
306+
# Rewrite path so mounted app at /mcp handles it
307+
scope["path"] = "/mcp"
308+
await streamable_http_session.handle_streamable_http(scope, receive, send)
309+
return
310+
await self.app(scope, receive, send)
311+
312+
265313
# Configure CORS
266314
app.add_middleware(
267315
CORSMiddleware,
@@ -276,8 +324,9 @@ async def dispatch(self, request: Request, call_next):
276324
# Add custom DocsAuthMiddleware
277325
app.add_middleware(DocsAuthMiddleware)
278326

279-
# Add streamable HTTP middleware for JWT auth
280-
app.add_middleware(JWTAuthMiddlewareStreamableHttp)
327+
# Add streamable HTTP middleware for /mcp routes
328+
app.add_middleware(MCPPathRewriteMiddleware)
329+
281330

282331
# Set up Jinja2 templates and store in app state for later use
283332
templates = Jinja2Templates(directory=str(settings.templates_dir))

mcpgateway/schemas.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def assemble_auth(cls, values: Dict[str, Any]) -> Dict[str, Any]:
319319
auth_type = values.get("auth_type")
320320
if auth_type:
321321
if auth_type.lower() == "basic":
322-
creds = base64.b64encode(f'{values.get("auth_username", "")}:{values.get("auth_password", "")}'.encode("utf-8")).decode()
322+
creds = base64.b64encode(f"{values.get('auth_username', '')}:{values.get('auth_password', '')}".encode("utf-8")).decode()
323323
encoded_auth = encode_auth({"Authorization": f"Basic {creds}"})
324324
values["auth"] = {"auth_type": "basic", "auth_value": encoded_auth}
325325
elif auth_type.lower() == "bearer":
@@ -378,7 +378,7 @@ def assemble_auth(cls, values: Dict[str, Any]) -> Dict[str, Any]:
378378
auth_type = values.get("auth_type")
379379
if auth_type:
380380
if auth_type.lower() == "basic":
381-
creds = base64.b64encode(f'{values.get("auth_username", "")}:{values.get("auth_password", "")}'.encode("utf-8")).decode()
381+
creds = base64.b64encode(f"{values.get('auth_username', '')}:{values.get('auth_password', '')}".encode("utf-8")).decode()
382382
encoded_auth = encode_auth({"Authorization": f"Basic {creds}"})
383383
values["auth"] = {"auth_type": "basic", "auth_value": encoded_auth}
384384
elif auth_type.lower() == "bearer":

mcpgateway/transports/streamablehttp_transport.py

Lines changed: 65 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -9,22 +9,23 @@
99
1010
Key components include:
1111
- SessionManagerWrapper: Manages the lifecycle of streamable HTTP sessions
12-
- JWTAuthMiddlewareStreamableHttp: Middleware for JWT authentication
1312
- Configuration options for:
1413
1. stateful/stateless operation
1514
2. JSON response mode or SSE streams
1615
- InMemoryEventStore: A simple in-memory event storage system for maintaining session state
1716
1817
"""
1918

19+
import contextvars
2020
import logging
21+
import re
2122
from collections import deque
2223
from contextlib import AsyncExitStack, asynccontextmanager
2324
from dataclasses import dataclass
2425
from typing import List, Union
2526
from uuid import uuid4
2627

27-
import mcp.types as types
28+
from mcp import types
2829
from fastapi.security.utils import get_authorization_scheme_param
2930
from mcp.server.lowlevel import Server
3031
from mcp.server.streamable_http import (
@@ -37,11 +38,9 @@
3738
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
3839
from mcp.types import JSONRPCMessage
3940
from starlette.datastructures import Headers
40-
from starlette.middleware.base import BaseHTTPMiddleware
41-
from starlette.requests import Request
4241
from starlette.responses import JSONResponse
4342
from starlette.status import HTTP_401_UNAUTHORIZED
44-
from starlette.types import ASGIApp, Receive, Scope, Send
43+
from starlette.types import Receive, Scope, Send
4544

4645
from mcpgateway.config import settings
4746
from mcpgateway.db import SessionLocal
@@ -55,6 +54,8 @@
5554
tool_service = ToolService()
5655
mcp_app = Server("mcp-streamable-http-stateless")
5756

57+
server_id_var: contextvars.ContextVar[str] = contextvars.ContextVar("server_id", default=None)
58+
5859
## ------------------------------ Event store ------------------------------
5960

6061

@@ -191,13 +192,24 @@ async def list_tools() -> List[types.Tool]:
191192
A list of Tool objects containing metadata such as name, description, and input schema.
192193
Logs and returns an empty list on failure.
193194
"""
194-
try:
195-
async with get_db() as db:
196-
tools = await tool_service.list_tools(db)
197-
return [types.Tool(name=tool.name, description=tool.description, inputSchema=tool.input_schema) for tool in tools]
198-
except Exception as e:
199-
logger.exception("Error listing tools")
200-
return []
195+
server_id = server_id_var.get()
196+
197+
if server_id:
198+
try:
199+
async with get_db() as db:
200+
tools = await tool_service.list_server_tools(db, server_id)
201+
return [types.Tool(name=tool.name, description=tool.description, inputSchema=tool.input_schema) for tool in tools]
202+
except Exception as e:
203+
logger.exception(f"Error listing tools:{e}")
204+
return []
205+
else:
206+
try:
207+
async with get_db() as db:
208+
tools = await tool_service.list_tools(db)
209+
return [types.Tool(name=tool.name, description=tool.description, inputSchema=tool.input_schema) for tool in tools]
210+
except Exception as e:
211+
logger.exception(f"Error listing tools:{e}")
212+
return []
201213

202214

203215
class SessionManagerWrapper:
@@ -226,7 +238,7 @@ def __init__(self) -> None:
226238
)
227239
self.stack = AsyncExitStack()
228240

229-
async def start(self) -> None:
241+
async def initialize(self) -> None:
230242
"""
231243
Starts the Streamable HTTP session manager context.
232244
"""
@@ -250,80 +262,59 @@ async def handle_streamable_http(self, scope: Scope, receive: Receive, send: Sen
250262
send (Send): ASGI send callable.
251263
Logs any exceptions that occur during request handling.
252264
"""
265+
266+
path = scope["modified_path"]
267+
match = re.search(r"/servers/(?P<server_id>\d+)/mcp", path)
268+
269+
if match:
270+
server_id = match.group("server_id")
271+
server_id_var.set(server_id)
272+
253273
try:
254274
await self.session_manager.handle_request(scope, receive, send)
255275
except Exception as e:
256-
logger.exception("Error handling streamable HTTP request")
276+
logger.exception(f"Error handling streamable HTTP request: {e}")
257277
raise
258278

259279

260-
## ------------------------- FastAPI Middleware for Authentication ------------------------------
280+
## ------------------------- Authentication for /mcp routes ------------------------------
261281

262282

263-
class JWTAuthMiddlewareStreamableHttp(BaseHTTPMiddleware):
264-
"""
265-
Middleware for handling JWT authentication in an ASGI application.
266-
This middleware checks for JWT tokens in the authorization header or cookies
267-
and verifies the credentials before allowing access to protected routes.
283+
async def streamable_http_auth(scope, receive, send):
268284
"""
285+
Perform authentication check in middleware context (ASGI scope).
269286
270-
def __init__(self, app: ASGIApp):
271-
"""
272-
Initialize the middleware with the given ASGI application.
273-
274-
Args:
275-
app (ASGIApp): The ASGI application to wrap.
276-
"""
277-
super().__init__(app)
287+
If path does not end with "/mcp", just continue (return True).
278288
279-
async def dispatch(self, request: Request, call_next):
280-
"""
281-
Dispatch the request to the appropriate handler after performing JWT authentication.
282-
283-
Args:
284-
request (Request): The incoming request.
285-
call_next: The next middleware or route handler in the chain.
289+
Only check Authorization header for Bearer token.
290+
If no Bearer token provided, allow (return True).
286291
287-
Returns:
288-
JSONResponse: A response indicating authentication failure if the token is invalid or missing.
289-
Response: The response from the next middleware or route handler if authentication is successful.
290-
"""
291-
# Only apply auth to /mcp path
292-
if not request.url.path.startswith("/mcp"):
293-
return await call_next(request)
294-
295-
headers = Headers(scope=request.scope)
296-
authorization = headers.get("authorization")
297-
cookie_header = headers.get("cookie", "")
298-
299-
token = None
300-
if authorization:
301-
scheme, credentials = get_authorization_scheme_param(authorization)
302-
if scheme.lower() == "bearer" and credentials:
303-
token = credentials
304-
305-
if not token:
306-
for cookie in cookie_header.split(";"):
307-
if cookie.strip().startswith("jwt_token="):
308-
token = cookie.strip().split("=", 1)[1]
309-
break
292+
If auth_required is True and Bearer token provided, verify it.
293+
If verification fails, send 401 JSONResponse and return False.
294+
"""
310295

311-
try:
312-
if settings.auth_required and not token:
313-
return JSONResponse(
314-
{"detail": "Not authenticated"},
315-
status_code=HTTP_401_UNAUTHORIZED,
316-
headers={"WWW-Authenticate": "Bearer"},
317-
)
296+
path = scope.get("path", "")
297+
if not path.endswith("/mcp") and not path.endswith("/mcp/"):
298+
# No auth needed for other paths in this middleware usage
299+
return True
318300

319-
if token:
320-
await verify_credentials(token)
301+
headers = Headers(scope=scope)
302+
authorization = headers.get("authorization")
321303

322-
return await call_next(request)
304+
token = None
305+
if authorization:
306+
scheme, credentials = get_authorization_scheme_param(authorization)
307+
if scheme.lower() == "bearer" and credentials:
308+
token = credentials
309+
try:
310+
await verify_credentials(token)
311+
except Exception:
312+
response = JSONResponse(
313+
{"detail": "Authentication failed"},
314+
status_code=HTTP_401_UNAUTHORIZED,
315+
headers={"WWW-Authenticate": "Bearer"},
316+
)
317+
await response(scope, receive, send)
318+
return False
323319

324-
except Exception as e:
325-
return JSONResponse(
326-
{"detail": "Authentication failed"},
327-
status_code=HTTP_401_UNAUTHORIZED,
328-
headers={"WWW-Authenticate": "Bearer"},
329-
)
320+
return True

tests/unit/mcpgateway/test_admin.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -109,18 +109,15 @@ def mock_request():
109109
)
110110

111111
# Basic template rendering stub
112-
request.app = MagicMock() # ensure .app exists
113-
request.app.state = MagicMock() # ensure .app.state exists
112+
request.app = MagicMock() # ensure .app exists
113+
request.app.state = MagicMock() # ensure .app.state exists
114114
request.app.state.templates = MagicMock()
115-
request.app.state.templates.TemplateResponse.return_value = HTMLResponse(
116-
content="<html></html>"
117-
)
115+
request.app.state.templates.TemplateResponse.return_value = HTMLResponse(content="<html></html>")
118116

119117
request.query_params = {"include_inactive": "false"}
120118
return request
121119

122120

123-
124121
class TestAdminServerRoutes:
125122
"""Test admin routes for server management."""
126123

@@ -217,6 +214,7 @@ async def test_admin_delete_server(self, mock_delete_server, mock_request, mock_
217214
assert result.status_code == 303
218215
assert "/admin#catalog" in result.headers["location"]
219216

217+
220218
class TestAdminToolRoutes:
221219
"""Test admin routes for tool management."""
222220

@@ -602,6 +600,7 @@ async def test_admin_delete_gateway(self, mock_delete_gateway, mock_request, moc
602600
assert result.status_code == 303
603601
assert "/admin#gateways" in result.headers["location"]
604602

603+
605604
class TestAdminRootRoutes:
606605
"""Test admin routes for root management."""
607606

@@ -626,6 +625,7 @@ async def test_admin_delete_root(self, mock_remove_root, mock_request):
626625
assert result.status_code == 303
627626
assert "/admin#roots" in result.headers["location"]
628627

628+
629629
class TestAdminMetricsRoutes:
630630
"""Test admin routes for metrics management."""
631631

0 commit comments

Comments
 (0)