|
28 | 28 | TransportSecurityMiddleware, |
29 | 29 | TransportSecuritySettings, |
30 | 30 | ) |
| 31 | +from mcp.shared.context import CloseSSEStreamCallback |
31 | 32 | from mcp.shared.message import ServerMessageMetadata, SessionMessage |
32 | 33 | from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS |
33 | 34 | from mcp.types import ( |
@@ -283,6 +284,26 @@ async def _create_priming_event(self, stream_id: str) -> dict[str, str | int] | |
283 | 284 |
|
284 | 285 | return event_data |
285 | 286 |
|
| 287 | + def _create_close_sse_stream_callback(self, request_id: RequestId) -> CloseSSEStreamCallback | None: |
| 288 | + """Create a bound callback for closing SSE streams. |
| 289 | +
|
| 290 | + Args: |
| 291 | + request_id: The request ID to bind to the callback |
| 292 | +
|
| 293 | + Returns: |
| 294 | + A callback that closes the SSE stream for this request, |
| 295 | + or None if no event store is configured (events would be lost). |
| 296 | + """ |
| 297 | + # Only provide callback if event store is configured |
| 298 | + # Without an event store, closing the stream would lose events |
| 299 | + if self._event_store is None: |
| 300 | + return None |
| 301 | + |
| 302 | + async def callback(retry_interval: int | None = None) -> bool: |
| 303 | + return await self.close_sse_stream(request_id, retry_interval) |
| 304 | + |
| 305 | + return callback |
| 306 | + |
286 | 307 | async def _clean_up_memory_streams(self, request_id: RequestId) -> None: # pragma: no cover |
287 | 308 | """Clean up memory streams for a given request ID.""" |
288 | 309 | if request_id in self._request_streams: |
@@ -544,7 +565,12 @@ async def sse_writer(): |
544 | 565 | async with anyio.create_task_group() as tg: |
545 | 566 | tg.start_soon(response, scope, receive, send) |
546 | 567 | # Then send the message to be processed by the server |
547 | | - metadata = ServerMessageMetadata(request_context=request) |
| 568 | + # Create callback for closing SSE stream (only if event store configured) |
| 569 | + close_callback = self._create_close_sse_stream_callback(request_id) |
| 570 | + metadata = ServerMessageMetadata( |
| 571 | + request_context=request, |
| 572 | + close_sse_stream=close_callback, |
| 573 | + ) |
548 | 574 | session_message = SessionMessage(message, metadata=metadata) |
549 | 575 | await writer.send(session_message) |
550 | 576 | except Exception: |
@@ -716,26 +742,35 @@ async def terminate(self) -> None: |
716 | 742 | # During cleanup, we catch all exceptions since streams might be in various states |
717 | 743 | logger.debug(f"Error closing streams: {e}") |
718 | 744 |
|
719 | | - async def close_sse_stream(self, request_id: RequestId) -> None: |
| 745 | + async def close_sse_stream(self, request_id: RequestId, retry_interval: int | None = None) -> bool: |
720 | 746 | """Close an SSE stream for a specific request, triggering client reconnection. |
721 | 747 |
|
722 | 748 | Use this to implement polling behavior during long-running operations - |
723 | 749 | client will reconnect after the retry interval specified in the priming event. |
724 | 750 |
|
725 | 751 | Args: |
726 | 752 | request_id: The request ID (or stream key) of the stream to close |
| 753 | + retry_interval: Optional retry interval in ms to send before closing. |
| 754 | + If provided, overrides the transport's default retry interval. |
| 755 | +
|
| 756 | + Returns: |
| 757 | + True if the stream was found and closed, False otherwise. |
727 | 758 | """ |
728 | 759 | request_id_str = str(request_id) |
729 | | - if request_id_str in self._request_streams: |
730 | | - try: |
731 | | - sender, receiver = self._request_streams[request_id_str] |
732 | | - await sender.aclose() |
733 | | - await receiver.aclose() |
734 | | - except Exception: # pragma: no cover |
735 | | - # Stream might already be closed |
736 | | - logger.debug(f"Error closing SSE stream {request_id_str} - may already be closed") |
737 | | - finally: |
738 | | - self._request_streams.pop(request_id_str, None) |
| 760 | + if request_id_str not in self._request_streams: |
| 761 | + return False |
| 762 | + |
| 763 | + try: |
| 764 | + sender, receiver = self._request_streams[request_id_str] |
| 765 | + await sender.aclose() |
| 766 | + await receiver.aclose() |
| 767 | + return True |
| 768 | + except Exception: # pragma: no cover |
| 769 | + # Stream might already be closed |
| 770 | + logger.debug(f"Error closing SSE stream {request_id_str} - may already be closed") |
| 771 | + return False |
| 772 | + finally: |
| 773 | + self._request_streams.pop(request_id_str, None) |
739 | 774 |
|
740 | 775 | async def _handle_unsupported_request(self, request: Request, send: Send) -> None: # pragma: no cover |
741 | 776 | """Handle unsupported HTTP methods.""" |
|
0 commit comments