 import logging
 from contextlib import asynccontextmanager
 from typing import Any, cast
-from uuid import UUID
+from uuid import UUID, uuid4

 import anyio
-from anyio import CapacityLimiter, lowlevel
-from pydantic import ValidationError
+from anyio import CancelScope, CapacityLimiter, Event, lowlevel
+from anyio.abc import TaskGroup

 import mcp.types as types
-from mcp.server.message_queue.base import MessageCallback
+from mcp.server.message_queue.base import MessageCallback, MessageWrapper

 try:
     import redis.asyncio as redis
@@ -42,98 +42,212 @@ def __init__(
         self._prefix = prefix
         self._active_sessions_key = f"{prefix}active_sessions"
         self._callbacks: dict[UUID, MessageCallback] = {}
-        # Ensures only one polling task runs at a time for message handling
         self._limiter = CapacityLimiter(1)
+        self._ack_events: dict[str, Event] = {}
+
         logger.debug(f"Redis message dispatch initialized: {redis_url}")

     def _session_channel(self, session_id: UUID) -> str:
         """Get the Redis channel for a session."""
         return f"{self._prefix}session:{session_id.hex}"

+    def _ack_channel(self, session_id: UUID) -> str:
+        """Get the acknowledgment channel for a session."""
+        return f"{self._prefix}ack:{session_id.hex}"
+
     @asynccontextmanager
     async def subscribe(self, session_id: UUID, callback: MessageCallback):
         """Request-scoped context manager that subscribes to messages for a session."""
         await self._redis.sadd(self._active_sessions_key, session_id.hex)
         self._callbacks[session_id] = callback
-        channel = self._session_channel(session_id)
-        await self._pubsub.subscribe(channel)  # type: ignore
-
-        logger.debug(f"Subscribing to Redis channel for session {session_id}")
-        async with anyio.create_task_group() as tg:
-            tg.start_soon(self._listen_for_messages)
-            try:
-                yield
-            finally:
-                tg.cancel_scope.cancel()
-                await self._pubsub.unsubscribe(channel)  # type: ignore
-                await self._redis.srem(self._active_sessions_key, session_id.hex)
-                del self._callbacks[session_id]
-                logger.debug(f"Unsubscribed from Redis channel: {session_id}")
-
-    async def _listen_for_messages(self) -> None:
+
+        session_channel = self._session_channel(session_id)
+        ack_channel = self._ack_channel(session_id)
+
+        await self._pubsub.subscribe(session_channel)  # type: ignore
+        await self._pubsub.subscribe(ack_channel)  # type: ignore
+
+        logger.debug(f"Subscribing to Redis channels for session {session_id}")
+
+        # Two nested task groups ensure proper cleanup: the inner one cancels the
+        # listener, while the outer one allows any handlers to complete before exiting.
+        async with anyio.create_task_group() as tg_handler:
+            async with anyio.create_task_group() as tg:
+                tg.start_soon(self._listen_for_messages, tg_handler)
+                try:
+                    yield
+                finally:
+                    tg.cancel_scope.cancel()
+                    await self._pubsub.unsubscribe(session_channel)  # type: ignore
+                    await self._pubsub.unsubscribe(ack_channel)  # type: ignore
+                    await self._redis.srem(self._active_sessions_key, session_id.hex)
+                    del self._callbacks[session_id]
+                    logger.debug(
+                        f"Unsubscribed from Redis channels for session {session_id}"
+                    )
+
+    async def _listen_for_messages(self, tg_handler: TaskGroup) -> None:
         """Background task that listens for messages on subscribed channels."""
         async with self._limiter:
             while True:
                 await lowlevel.checkpoint()
-                message: None | dict[str, Any] = await self._pubsub.get_message(  # type: ignore
-                    ignore_subscribe_messages=True,
-                    timeout=None,  # type: ignore
-                )
-                if message is None:
-                    continue
-
-                channel: str = cast(str, message["channel"])
-                expected_prefix = f"{self._prefix}session:"
+                # Shield message retrieval from cancellation to ensure no messages are
+                # lost when a session disconnects during processing.
+                with CancelScope(shield=True):
+                    redis_message: (  # type: ignore
+                        None | dict[str, Any]
+                    ) = await self._pubsub.get_message(  # type: ignore
+                        ignore_subscribe_messages=True,
+                        timeout=0.1,  # type: ignore
+                    )
+                    if redis_message is None:
+                        continue

-                if not channel.startswith(expected_prefix):
-                    logger.debug(f"Ignoring message from non-MCP channel: {channel}")
-                    continue
+                    channel: str = cast(str, redis_message["channel"])
+                    data: str = cast(str, redis_message["data"])

-                session_hex = channel[len(expected_prefix) :]
-                try:
-                    session_id = UUID(hex=session_hex)
-                    expected_channel = self._session_channel(session_id)
-                    if channel != expected_channel:
-                        logger.error(f"Channel format mismatch: {channel}")
+                    # Handle acknowledgment messages
+                    if channel.startswith(f"{self._prefix}ack:"):
+                        tg_handler.start_soon(self._handle_ack_message, channel, data)
                         continue
-                except ValueError:
-                    logger.error(f"Invalid UUID in channel: {channel}")
-                    continue

-                data: str = cast(str, message["data"])
-                try:
-                    if session_id not in self._callbacks:
-                        logger.warning(f"Message dropped: no callback for {session_id}")
+                    # Handle session messages
+                    elif channel.startswith(f"{self._prefix}session:"):
+                        tg_handler.start_soon(
+                            self._handle_session_message, channel, data
+                        )
                         continue

-                    # Try to parse as valid message or recreate original ValidationError
-                    try:
-                        msg = types.JSONRPCMessage.model_validate_json(data)
-                        await self._callbacks[session_id](msg)
-                    except ValidationError as exc:
-                        # Pass the identical validation error that would have occurred
-                        await self._callbacks[session_id](exc)
-                except Exception as e:
-                    logger.error(f"Error processing message for {session_id}: {e}")
+                    # Ignore other channels
+                    else:
+                        logger.debug(
+                            f"Ignoring message from non-MCP channel: {channel}"
+                        )
+
+    async def _handle_ack_message(self, channel: str, data: str) -> None:
+        """Handle acknowledgment messages received on ack channels."""
+        ack_prefix = f"{self._prefix}ack:"
+        if not channel.startswith(ack_prefix):
+            return
+
+        # Validate channel format exactly matches our expected format
+        session_hex = channel[len(ack_prefix) :]
+        try:
+            # Validate this is a valid UUID hex and channel has correct format
+            session_id = UUID(hex=session_hex)
+            expected_channel = self._ack_channel(session_id)
+            if channel != expected_channel:
+                logger.error(
+                    f"Channel mismatch: got {channel}, expected {expected_channel}"
+                )
+                return
+        except ValueError:
+            logger.error(f"Invalid UUID hex in ack channel: {channel}")
+            return
+
+        # Extract message ID from data
+        message_id = data.strip()
+        if message_id in self._ack_events:
+            logger.debug(f"Received acknowledgment for message: {message_id}")
+            self._ack_events[message_id].set()
+
+    async def _handle_session_message(self, channel: str, data: str) -> None:
+        """Handle regular messages received on session channels."""
+        session_prefix = f"{self._prefix}session:"
+        if not channel.startswith(session_prefix):
+            return
+
+        session_hex = channel[len(session_prefix) :]
+        try:
+            session_id = UUID(hex=session_hex)
+            expected_channel = self._session_channel(session_id)
+            if channel != expected_channel:
+                logger.error(
+                    f"Channel mismatch: got {channel}, expected {expected_channel}"
+                )
+                return
+        except ValueError:
+            logger.error(f"Invalid UUID hex in session channel: {channel}")
+            return
+
+        if session_id not in self._callbacks:
+            logger.warning(f"Message dropped: no callback for {session_id}")
+            return
+
+        try:
+            wrapper = MessageWrapper.model_validate_json(data)
+            result = wrapper.get_json_rpc_message()
+            await self._callbacks[session_id](result)
+            await self._send_acknowledgment(session_id, wrapper.message_id)
+
+        except Exception as e:
+            logger.error(f"Error processing message for {session_id}: {e}")
+
+    async def _send_acknowledgment(self, session_id: UUID, message_id: str) -> None:
+        """Send an acknowledgment for a message that was successfully processed."""
+        ack_channel = self._ack_channel(session_id)
+        await self._redis.publish(ack_channel, message_id)  # type: ignore
+        logger.debug(
+            f"Sent acknowledgment for message {message_id} to session {session_id}"
+        )

     async def publish_message(
-        self, session_id: UUID, message: types.JSONRPCMessage | str
-    ) -> bool:
+        self,
+        session_id: UUID,
+        message: types.JSONRPCMessage | str,
+        message_id: str | None = None,
+    ) -> str | None:
         """Publish a message for the specified session."""
         if not await self.session_exists(session_id):
             logger.warning(f"Message dropped: unknown session {session_id}")
-            return False
+            return None

         # Pass raw JSON strings directly, preserving validation errors
+        message_id = message_id or str(uuid4())
         if isinstance(message, str):
-            data = message
+            wrapper = MessageWrapper(message_id=message_id, payload=message)
         else:
-            data = message.model_dump_json()
+            wrapper = MessageWrapper(
+                message_id=message_id, payload=message.model_dump_json()
+            )

         channel = self._session_channel(session_id)
-        await self._redis.publish(channel, data)  # type: ignore[attr-defined]
-        logger.debug(f"Message published to Redis channel for session {session_id}")
-        return True
+        await self._redis.publish(channel, wrapper.model_dump_json())  # type: ignore
+        logger.debug(
+            f"Message {message_id} published to Redis channel for session {session_id}"
+        )
+        return message_id
+
+    async def publish_message_sync(
+        self,
+        session_id: UUID,
+        message: types.JSONRPCMessage | str,
+        timeout: float = 120.0,
+    ) -> bool:
+        """Publish a message and wait for acknowledgment of processing."""
+        message_id = str(uuid4())
+        ack_event = Event()
+        self._ack_events[message_id] = ack_event
+
+        try:
+            published_id = await self.publish_message(session_id, message, message_id)
+            if published_id is None:
+                return False
+
+            with anyio.fail_after(timeout):
+                await ack_event.wait()
+            logger.debug(f"Received acknowledgment for message {message_id}")
+            return True
+
+        except TimeoutError:
+            logger.warning(
+                f"Timed out waiting for acknowledgment of message {message_id}"
+            )
+            return False
+
+        finally:
+            if message_id in self._ack_events:
+                del self._ack_events[message_id]

     async def session_exists(self, session_id: UUID) -> bool:
         """Check if a session exists."""
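
Note for readers of this hunk: MessageWrapper is imported from mcp.server.message_queue.base and is not shown here. Judging only from how it is used above (message_id, payload, model_validate_json, get_json_rpc_message) and from the dropped pydantic ValidationError import in this file, it is presumably a small pydantic envelope along the following lines. This is a sketch of the assumed shape, not the actual definition:

# Hypothetical sketch only: the real MessageWrapper lives in
# mcp/server/message_queue/base.py (not part of this hunk) and may differ.
from pydantic import BaseModel, ValidationError

import mcp.types as types


class MessageWrapper(BaseModel):
    """Envelope pairing a message ID with the raw JSON-RPC payload."""

    message_id: str
    payload: str  # raw JSON string, preserved verbatim

    def get_json_rpc_message(self) -> types.JSONRPCMessage | ValidationError:
        # Returning (rather than raising) the ValidationError would mirror the
        # old behaviour of handing parse failures straight to the callback.
        try:
            return types.JSONRPCMessage.model_validate_json(self.payload)
        except ValidationError as exc:
            return exc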
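Taken together, the new pieces form a request/acknowledge round trip: publish_message wraps the payload with a message ID, the subscriber's listener dispatches it to the session callback and then publishes that ID back on the ack channel, and publish_message_sync blocks until the ack arrives or the timeout expires. A rough usage sketch follows; the class name, constructor argument, and import path are assumptions for illustration and are not taken from this hunk:

# Hypothetical wiring to illustrate the round trip.
# from mcp.server.message_queue.redis import RedisMessageDispatch  # assumed path
import anyio
from uuid import uuid4


async def main() -> None:
    dispatch = RedisMessageDispatch(redis_url="redis://localhost:6379/0")
    session_id = uuid4()

    async def on_message(message) -> None:
        # Runs inside the subscriber's handler task group; once it returns,
        # _handle_session_message publishes the ack that unblocks the sender.
        print("received:", message)

    async with dispatch.subscribe(session_id, on_message):
        # Fire-and-forget: returns the generated message ID, or None if the
        # session is unknown.
        await dispatch.publish_message(
            session_id, '{"jsonrpc": "2.0", "method": "ping", "id": 1}'
        )
        # Blocking variant: True only after the callback ran and the ack came
        # back on the ack channel (False after the 120s default timeout).
        ok = await dispatch.publish_message_sync(
            session_id, '{"jsonrpc": "2.0", "method": "ping", "id": 1}'
        )
        print("acknowledged:", ok)


anyio.run(main)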