11
11
- SessionManagerWrapper: Manages the lifecycle of streamable HTTP sessions
12
12
- JWTAuthMiddlewareStreamableHttp: Middleware for JWT authentication
13
13
- Configuration options for:
14
- 1. stateful/stateless operation
14
+ 1. stateful/stateless operation
15
15
2. JSON response mode or SSE streams
16
16
- InMemoryEventStore: A simple in-memory event storage system for maintaining session state
17
17
18
18
"""
19
19
20
20
import logging
21
+ from collections import deque
21
22
from contextlib import AsyncExitStack , asynccontextmanager
23
+ from dataclasses import dataclass
22
24
from typing import List , Union
23
-
24
- from starlette .types import Receive , Scope , Send
25
- from starlette .middleware .base import BaseHTTPMiddleware
26
- from starlette .requests import Request
27
- from starlette .responses import JSONResponse
28
- from starlette .datastructures import Headers
29
- from fastapi .security .utils import get_authorization_scheme_param
30
- from starlette .status import HTTP_401_UNAUTHORIZED
31
- from starlette .types import ASGIApp
25
+ from uuid import uuid4
32
26
33
27
import mcp .types as types
28
+ from fastapi .security .utils import get_authorization_scheme_param
34
29
from mcp .server .lowlevel import Server
35
- from mcp .server .streamable_http_manager import StreamableHTTPSessionManager
36
-
37
- from mcpgateway .services .tool_service import ToolService
38
- from mcpgateway .db import SessionLocal
39
- from mcpgateway .config import settings
40
- from mcpgateway .utils .verify_credentials import verify_credentials
41
-
42
-
43
- from collections import deque
44
- from dataclasses import dataclass
45
- from uuid import uuid4
46
-
47
30
from mcp .server .streamable_http import (
48
31
EventCallback ,
49
32
EventId ,
50
33
EventMessage ,
51
34
EventStore ,
52
35
StreamId ,
53
36
)
37
+ from mcp .server .streamable_http_manager import StreamableHTTPSessionManager
54
38
from mcp .types import JSONRPCMessage
39
+ from starlette .datastructures import Headers
40
+ from starlette .middleware .base import BaseHTTPMiddleware
41
+ from starlette .requests import Request
42
+ from starlette .responses import JSONResponse
43
+ from starlette .status import HTTP_401_UNAUTHORIZED
44
+ from starlette .types import ASGIApp , Receive , Scope , Send
55
45
46
+ from mcpgateway .config import settings
47
+ from mcpgateway .db import SessionLocal
48
+ from mcpgateway .services .tool_service import ToolService
49
+ from mcpgateway .utils .verify_credentials import verify_credentials
56
50
57
51
logger = logging .getLogger (__name__ )
58
52
logging .basicConfig (level = logging .INFO )
63
57
64
58
## ------------------------------ Event store ------------------------------
65
59
60
+
66
61
@dataclass
67
62
class EventEntry :
68
63
"""
@@ -95,14 +90,10 @@ def __init__(self, max_events_per_stream: int = 100):
95
90
# event_id -> EventEntry for quick lookup
96
91
self .event_index : dict [EventId , EventEntry ] = {}
97
92
98
- async def store_event (
99
- self , stream_id : StreamId , message : JSONRPCMessage
100
- ) -> EventId :
93
+ async def store_event (self , stream_id : StreamId , message : JSONRPCMessage ) -> EventId :
101
94
"""Stores an event with a generated event ID."""
102
95
event_id = str (uuid4 ())
103
- event_entry = EventEntry (
104
- event_id = event_id , stream_id = stream_id , message = message
105
- )
96
+ event_entry = EventEntry (event_id = event_id , stream_id = stream_id , message = message )
106
97
107
98
# Get or create deque for this stream
108
99
if stream_id not in self .streams :
@@ -148,6 +139,7 @@ async def replay_events_after(
148
139
149
140
## ------------------------------ Streamable HTTP Transport ------------------------------
150
141
142
+
151
143
@asynccontextmanager
152
144
async def get_db ():
153
145
"""
@@ -184,12 +176,7 @@ async def call_tool(name: str, arguments: dict) -> List[Union[types.TextContent,
184
176
logger .warning (f"No content returned by tool: { name } " )
185
177
return []
186
178
187
- return [
188
- types .TextContent (
189
- type = result .content [0 ].type ,
190
- text = result .content [0 ].text
191
- )
192
- ]
179
+ return [types .TextContent (type = result .content [0 ].type , text = result .content [0 ].text )]
193
180
except Exception as e :
194
181
logger .exception (f"Error calling tool '{ name } ': { e } " )
195
182
return []
@@ -207,19 +194,12 @@ async def list_tools() -> List[types.Tool]:
207
194
try :
208
195
async with get_db () as db :
209
196
tools = await tool_service .list_tools (db )
210
- return [
211
- types .Tool (
212
- name = tool .name ,
213
- description = tool .description ,
214
- inputSchema = tool .input_schema
215
- ) for tool in tools
216
- ]
197
+ return [types .Tool (name = tool .name , description = tool .description , inputSchema = tool .input_schema ) for tool in tools ]
217
198
except Exception as e :
218
199
logger .exception ("Error listing tools" )
219
200
return []
220
201
221
202
222
-
223
203
class SessionManagerWrapper :
224
204
"""
225
205
Wrapper class for managing the lifecycle of a StreamableHTTPSessionManager instance.
@@ -230,13 +210,13 @@ def __init__(self) -> None:
230
210
"""
231
211
Initializes the session manager and the exit stack used for managing its lifecycle.
232
212
"""
233
-
213
+
234
214
if settings .use_stateful_sessions :
235
215
event_store = InMemoryEventStore ()
236
- stateless = False
216
+ stateless = False
237
217
else :
238
- event_store = None
239
- stateless = True
218
+ event_store = None
219
+ stateless = True
240
220
241
221
self .session_manager = StreamableHTTPSessionManager (
242
222
app = mcp_app ,
@@ -252,7 +232,7 @@ async def start(self) -> None:
252
232
"""
253
233
logger .info ("Initializing Streamable HTTP service" )
254
234
await self .stack .enter_async_context (self .session_manager .run ())
255
-
235
+
256
236
async def shutdown (self ) -> None :
257
237
"""
258
238
Gracefully shuts down the Streamable HTTP session manager.
@@ -276,14 +256,17 @@ async def handle_streamable_http(self, scope: Scope, receive: Receive, send: Sen
276
256
logger .exception ("Error handling streamable HTTP request" )
277
257
raise
278
258
259
+
279
260
## ------------------------- FastAPI Middleware for Authentication ------------------------------
280
261
262
+
281
263
class JWTAuthMiddlewareStreamableHttp (BaseHTTPMiddleware ):
282
264
"""
283
265
Middleware for handling JWT authentication in an ASGI application.
284
266
This middleware checks for JWT tokens in the authorization header or cookies
285
267
and verifies the credentials before allowing access to protected routes.
286
268
"""
269
+
287
270
def __init__ (self , app : ASGIApp ):
288
271
"""
289
272
Initialize the middleware with the given ASGI application.
@@ -344,5 +327,3 @@ async def dispatch(self, request: Request, call_next):
344
327
status_code = HTTP_401_UNAUTHORIZED ,
345
328
headers = {"WWW-Authenticate" : "Bearer" },
346
329
)
347
-
348
-
0 commit comments