diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py index b2079bb27..d36686720 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -185,20 +185,22 @@ async def handle_streamable_http(scope, receive, send): ) server_instances[http_transport.mcp_session_id] = http_transport logger.info(f"Created new transport with session ID: {new_session_id}") - async with http_transport.connect() as streams: - read_stream, write_stream = streams - async def run_server(): - await app.run( - read_stream, - write_stream, - app.create_initialization_options(), - ) + async def run_server(task_status=None): + async with http_transport.connect() as streams: + read_stream, write_stream = streams + if task_status: + task_status.started() + await app.run( + read_stream, + write_stream, + app.create_initialization_options(), + ) if not task_group: raise RuntimeError("Task group is not initialized") - task_group.start_soon(run_server) + await task_group.start(run_server) # Handle the HTTP request and return the response await http_transport.handle_request(scope, receive, send) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 1cacd23b5..4b97b33da 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -480,7 +480,7 @@ async def run( # but also make tracing exceptions much easier during testing and when using # in-process servers. raise_exceptions: bool = False, - # When True, the server as stateless deployments where + # When True, the server is stateless and # clients can perform initialization with any node. The client must still follow # the initialization lifecycle, but can do so with any available node # rather than requiring initialization for each connection. 
diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 53fff0d36..ace74b33b 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -129,6 +129,8 @@ class StreamableHTTPServerTransport: _read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] | None = ( None ) + _read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] | None = None + _write_stream: MemoryObjectSendStream[SessionMessage] | None = None _write_stream_reader: MemoryObjectReceiveStream[SessionMessage] | None = None def __init__( @@ -163,7 +165,11 @@ def __init__( self.is_json_response_enabled = is_json_response_enabled self._event_store = event_store self._request_streams: dict[ - RequestId, MemoryObjectSendStream[EventMessage] + RequestId, + tuple[ + MemoryObjectSendStream[EventMessage], + MemoryObjectReceiveStream[EventMessage], + ], ] = {} self._terminated = False @@ -239,6 +245,19 @@ def _create_event_data(self, event_message: EventMessage) -> dict[str, str]: return event_data + async def _clean_up_memory_streams(self, request_id: RequestId) -> None: + """Clean up memory streams for a given request ID.""" + if request_id in self._request_streams: + try: + # Close the request stream + await self._request_streams[request_id][0].aclose() + await self._request_streams[request_id][1].aclose() + except Exception as e: + logger.debug(f"Error closing memory streams: {e}") + finally: + # Remove the request stream from the mapping + self._request_streams.pop(request_id, None) + async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None: """Application entry point that handles all HTTP requests""" request = Request(scope, receive) @@ -386,13 +405,11 @@ async def _handle_post_request( # Extract the request ID outside the try block for proper scope request_id = str(message.root.id) - # Create promise stream for getting response - request_stream_writer, request_stream_reader = ( - 
anyio.create_memory_object_stream[EventMessage](0) - ) - # Register this stream for the request ID - self._request_streams[request_id] = request_stream_writer + self._request_streams[request_id] = anyio.create_memory_object_stream[ + EventMessage + ](0) + request_stream_reader = self._request_streams[request_id][1] if self.is_json_response_enabled: # Process the message @@ -441,11 +458,7 @@ async def _handle_post_request( ) await response(scope, receive, send) finally: - # Clean up the request stream - if request_id in self._request_streams: - self._request_streams.pop(request_id, None) - await request_stream_reader.aclose() - await request_stream_writer.aclose() + await self._clean_up_memory_streams(request_id) else: # Create SSE stream sse_stream_writer, sse_stream_reader = ( @@ -467,16 +480,12 @@ async def sse_writer(): event_message.message.root, JSONRPCResponse | JSONRPCError, ): - if request_id: - self._request_streams.pop(request_id, None) break except Exception as e: logger.exception(f"Error in SSE writer: {e}") finally: logger.debug("Closing SSE writer") - # Clean up the request-specific streams - if request_id and request_id in self._request_streams: - self._request_streams.pop(request_id, None) + await self._clean_up_memory_streams(request_id) # Create and start EventSourceResponse # SSE stream mode (original behavior) @@ -507,9 +516,9 @@ async def sse_writer(): await writer.send(session_message) except Exception: logger.exception("SSE response error") - # Clean up the request stream if something goes wrong - if request_id and request_id in self._request_streams: - self._request_streams.pop(request_id, None) + await sse_stream_writer.aclose() + await sse_stream_reader.aclose() + await self._clean_up_memory_streams(request_id) except Exception as err: logger.exception("Error handling POST request") @@ -581,12 +590,11 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: async def standalone_sse_writer(): try: # Create a standalone 
message stream for server-initiated messages - standalone_stream_writer, standalone_stream_reader = ( + + self._request_streams[GET_STREAM_KEY] = ( anyio.create_memory_object_stream[EventMessage](0) ) - - # Register this stream using the special key - self._request_streams[GET_STREAM_KEY] = standalone_stream_writer + standalone_stream_reader = self._request_streams[GET_STREAM_KEY][1] async with sse_stream_writer, standalone_stream_reader: # Process messages from the standalone stream @@ -603,8 +611,7 @@ async def standalone_sse_writer(): logger.exception(f"Error in standalone SSE writer: {e}") finally: logger.debug("Closing standalone SSE writer") - # Remove the stream from request_streams - self._request_streams.pop(GET_STREAM_KEY, None) + await self._clean_up_memory_streams(GET_STREAM_KEY) # Create and start EventSourceResponse response = EventSourceResponse( @@ -618,8 +625,9 @@ async def standalone_sse_writer(): await response(request.scope, request.receive, send) except Exception as e: logger.exception(f"Error in standalone SSE response: {e}") - # Clean up the request stream - self._request_streams.pop(GET_STREAM_KEY, None) + await sse_stream_writer.aclose() + await sse_stream_reader.aclose() + await self._clean_up_memory_streams(GET_STREAM_KEY) async def _handle_delete_request(self, request: Request, send: Send) -> None: """Handle DELETE requests for explicit session termination.""" @@ -636,7 +644,7 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None: if not await self._validate_session(request, send): return - self._terminate_session() + await self._terminate_session() response = self._create_json_response( None, @@ -644,7 +652,7 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None: ) await response(request.scope, request.receive, send) - def _terminate_session(self) -> None: + async def _terminate_session(self) -> None: """Terminate the current session, closing all streams. 
Once terminated, all requests with this session ID will receive 404 Not Found. @@ -656,19 +664,26 @@ def _terminate_session(self) -> None: # We need a copy of the keys to avoid modification during iteration request_stream_keys = list(self._request_streams.keys()) - # Close all request streams (synchronously) + # Close all request streams asynchronously for key in request_stream_keys: try: - # Get the stream - stream = self._request_streams.get(key) - if stream: - # We must use close() here, not aclose() since this is a sync method - stream.close() + await self._clean_up_memory_streams(key) except Exception as e: logger.debug(f"Error closing stream {key} during termination: {e}") # Clear the request streams dictionary immediately self._request_streams.clear() + try: + if self._read_stream_writer is not None: + await self._read_stream_writer.aclose() + if self._read_stream is not None: + await self._read_stream.aclose() + if self._write_stream_reader is not None: + await self._write_stream_reader.aclose() + if self._write_stream is not None: + await self._write_stream.aclose() + except Exception as e: + logger.debug(f"Error closing streams: {e}") async def _handle_unsupported_request(self, request: Request, send: Send) -> None: """Handle unsupported HTTP methods.""" @@ -756,10 +771,10 @@ async def send_event(event_message: EventMessage) -> None: # If stream ID not in mapping, create it if stream_id and stream_id not in self._request_streams: - msg_writer, msg_reader = anyio.create_memory_object_stream[ - EventMessage - ](0) - self._request_streams[stream_id] = msg_writer + self._request_streams[stream_id] = ( + anyio.create_memory_object_stream[EventMessage](0) + ) + msg_reader = self._request_streams[stream_id][1] # Forward messages to SSE async with msg_reader: @@ -781,6 +796,9 @@ async def send_event(event_message: EventMessage) -> None: await response(request.scope, request.receive, send) except Exception as e: logger.exception(f"Error in replay response: {e}") + 
finally: + await sse_stream_writer.aclose() + await sse_stream_reader.aclose() except Exception as e: logger.exception(f"Error replaying events: {e}") @@ -818,7 +836,9 @@ async def connect( # Store the streams self._read_stream_writer = read_stream_writer + self._read_stream = read_stream self._write_stream_reader = write_stream_reader + self._write_stream = write_stream # Start a task group for message routing async with anyio.create_task_group() as tg: @@ -863,7 +883,7 @@ async def message_router(): if request_stream_id in self._request_streams: try: # Send both the message and the event ID - await self._request_streams[request_stream_id].send( + await self._request_streams[request_stream_id][0].send( EventMessage(message, event_id) ) except ( @@ -872,6 +892,12 @@ ): # Stream might be closed, remove from registry self._request_streams.pop(request_stream_id, None) + else: + logger.debug( + f"""Request stream {request_stream_id} not found + for message. Still processing message as the client + might reconnect and replay.""" + ) except Exception as e: logger.exception(f"Error in message router: {e}") @@ -882,9 +908,19 @@ # Yield the streams for the caller to use yield read_stream, write_stream finally: - for stream in list(self._request_streams.values()): + for stream_id in list(self._request_streams.keys()): try: - await stream.aclose() - except Exception: + await self._clean_up_memory_streams(stream_id) + except Exception as e: + logger.debug(f"Error closing request stream: {e}") pass self._request_streams.clear() + + # Clean up the read and write streams + try: + await read_stream_writer.aclose() + await read_stream.aclose() + await write_stream_reader.aclose() + await write_stream.aclose() + except Exception as e: + logger.debug(f"Error closing streams: {e}") diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index f64360229..b1dc7ea33 100644 --- 
a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -234,29 +234,30 @@ async def handle_streamable_http(scope, receive, send): event_store=event_store, ) - async with http_transport.connect() as streams: - read_stream, write_stream = streams - - async def run_server(): + async def run_server(task_status=None): + async with http_transport.connect() as streams: + read_stream, write_stream = streams + if task_status: + task_status.started() await server.run( read_stream, write_stream, server.create_initialization_options(), ) - if task_group is None: - response = Response( - "Internal Server Error: Task group is not initialized", - status_code=HTTPStatus.INTERNAL_SERVER_ERROR, - ) - await response(scope, receive, send) - return + if task_group is None: + response = Response( + "Internal Server Error: Task group is not initialized", + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + ) + await response(scope, receive, send) + return - # Store the instance before starting the task to prevent races - server_instances[http_transport.mcp_session_id] = http_transport - task_group.start_soon(run_server) + # Store the instance before starting the task to prevent races + server_instances[http_transport.mcp_session_id] = http_transport + await task_group.start(run_server) - await http_transport.handle_request(scope, receive, send) + await http_transport.handle_request(scope, receive, send) else: response = Response( "Bad Request: No valid session ID provided",