Skip to content

Commit 2943631

Browse files
committed
Fix code complexity issue in sHTTP
1 parent 07a2821 commit 2943631

File tree

1 file changed

+90
-62
lines changed

1 file changed

+90
-62
lines changed

src/mcp/server/streamable_http.py

Lines changed: 90 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,93 @@ def _check_content_type(self, request: Request) -> bool:
308308

309309
return any(part == CONTENT_TYPE_JSON for part in content_type_parts)
310310

311+
def _is_async_operation_response(self, response_message: JSONRPCMessage) -> bool:
312+
"""Check if response is for an async operation that should keep stream open."""
313+
try:
314+
if not isinstance(response_message.root, JSONRPCResponse):
315+
return False
316+
317+
result = response_message.root.result
318+
if not result:
319+
return False
320+
321+
# Check if result has _operation with token
322+
if hasattr(result, "__getitem__") and "_operation" in result:
323+
operation = result["_operation"] # type: ignore
324+
if hasattr(operation, "__getitem__") and "token" in operation:
325+
return bool(operation["token"]) # type: ignore
326+
327+
return False
328+
except (TypeError, KeyError, AttributeError):
329+
return False
330+
331+
async def _handle_sse_mode(
332+
self,
333+
message: JSONRPCMessage,
334+
request: Request,
335+
writer: MemoryObjectSendStream[SessionMessage | Exception],
336+
request_id: str,
337+
request_stream_reader: MemoryObjectReceiveStream[EventMessage],
338+
scope: Scope,
339+
receive: Receive,
340+
send: Send,
341+
) -> None:
342+
"""Handle SSE response mode."""
343+
# Create SSE stream
344+
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0)
345+
346+
async def sse_writer():
347+
# Get the request ID from the incoming request message
348+
try:
349+
async with sse_stream_writer, request_stream_reader:
350+
# Process messages from the request-specific stream
351+
async for event_message in request_stream_reader:
352+
# Build the event data
353+
event_data = self._create_event_data(event_message)
354+
await sse_stream_writer.send(event_data)
355+
356+
# If response, remove from pending streams and close
357+
if isinstance(
358+
event_message.message.root,
359+
JSONRPCResponse | JSONRPCError,
360+
):
361+
break
362+
except Exception:
363+
logger.exception("Error in SSE writer")
364+
finally:
365+
logger.debug("Closing SSE writer")
366+
await self._clean_up_memory_streams(request_id)
367+
368+
# Create and start EventSourceResponse
369+
# SSE stream mode (original behavior)
370+
# Set up headers
371+
headers = {
372+
"Cache-Control": "no-cache, no-transform",
373+
"Connection": "keep-alive",
374+
"Content-Type": CONTENT_TYPE_SSE,
375+
**({MCP_SESSION_ID_HEADER: self.mcp_session_id} if self.mcp_session_id else {}),
376+
}
377+
response = EventSourceResponse(
378+
content=sse_stream_reader,
379+
data_sender_callable=sse_writer,
380+
headers=headers,
381+
)
382+
383+
# Start the SSE response (this will send headers immediately)
384+
try:
385+
# First send the response to establish the SSE connection
386+
async with anyio.create_task_group() as tg:
387+
tg.start_soon(response, scope, receive, send)
388+
# Then send the message to be processed by the server
389+
metadata = ServerMessageMetadata(request_context=request)
390+
session_message = SessionMessage(message, metadata=metadata)
391+
await writer.send(session_message)
392+
except Exception:
393+
logger.exception("SSE response error")
394+
await sse_stream_writer.aclose()
395+
await sse_stream_reader.aclose()
396+
await self._clean_up_memory_streams(request_id)
397+
311398
async def _handle_post_request(self, scope: Scope, request: Request, receive: Receive, send: Send) -> None:
312399
"""Handle POST requests containing JSON-RPC messages."""
313400
writer = self._read_stream_writer
@@ -420,15 +507,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
420507
# At this point we should have a response
421508
if response_message:
422509
# Check if this is an async operation response - keep stream open
423-
if (
424-
isinstance(response_message.root, JSONRPCResponse)
425-
and response_message.root.result
426-
and "_operation" in response_message.root.result
427-
and (
428-
("token" in response_message.root.result["_operation"])
429-
and response_message.root.result["_operation"]["token"]
430-
)
431-
):
510+
if self._is_async_operation_response(response_message):
432511
# This is an async operation - keep the stream open for elicitation/sampling
433512
should_pop_stream = False
434513

@@ -455,61 +534,10 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
455534
if should_pop_stream:
456535
await self._clean_up_memory_streams(request_id)
457536
else:
458-
# Create SSE stream
459-
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0)
460-
461-
async def sse_writer():
462-
# Get the request ID from the incoming request message
463-
try:
464-
async with sse_stream_writer, request_stream_reader:
465-
# Process messages from the request-specific stream
466-
async for event_message in request_stream_reader:
467-
# Build the event data
468-
event_data = self._create_event_data(event_message)
469-
await sse_stream_writer.send(event_data)
470-
471-
# If response, remove from pending streams and close
472-
if isinstance(
473-
event_message.message.root,
474-
JSONRPCResponse | JSONRPCError,
475-
):
476-
break
477-
except Exception:
478-
logger.exception("Error in SSE writer")
479-
finally:
480-
logger.debug("Closing SSE writer")
481-
await self._clean_up_memory_streams(request_id)
482-
483-
# Create and start EventSourceResponse
484-
# SSE stream mode (original behavior)
485-
# Set up headers
486-
headers = {
487-
"Cache-Control": "no-cache, no-transform",
488-
"Connection": "keep-alive",
489-
"Content-Type": CONTENT_TYPE_SSE,
490-
**({MCP_SESSION_ID_HEADER: self.mcp_session_id} if self.mcp_session_id else {}),
491-
}
492-
response = EventSourceResponse(
493-
content=sse_stream_reader,
494-
data_sender_callable=sse_writer,
495-
headers=headers,
537+
await self._handle_sse_mode(
538+
message, request, writer, request_id, request_stream_reader, scope, receive, send
496539
)
497540

498-
# Start the SSE response (this will send headers immediately)
499-
try:
500-
# First send the response to establish the SSE connection
501-
async with anyio.create_task_group() as tg:
502-
tg.start_soon(response, scope, receive, send)
503-
# Then send the message to be processed by the server
504-
metadata = ServerMessageMetadata(request_context=request)
505-
session_message = SessionMessage(message, metadata=metadata)
506-
await writer.send(session_message)
507-
except Exception:
508-
logger.exception("SSE response error")
509-
await sse_stream_writer.aclose()
510-
await sse_stream_reader.aclose()
511-
await self._clean_up_memory_streams(request_id)
512-
513541
except Exception as err:
514542
logger.exception("Error handling POST request")
515543
response = self._create_error_response(

0 commit comments

Comments
 (0)