Skip to content

Commit 0bfd800

Browse files
committed
Rework to fix POST blocking issue
1 parent b484284 commit 0bfd800

File tree

3 files changed

+234
-67
lines changed

3 files changed

+234
-67
lines changed

src/mcp/server/message_queue/base.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import Protocol, runtime_checkable
55
from uuid import UUID
66

7-
from pydantic import ValidationError
7+
from pydantic import BaseModel, ValidationError
88

99
import mcp.types as types
1010

@@ -13,6 +13,18 @@
1313
MessageCallback = Callable[[types.JSONRPCMessage | Exception], Awaitable[None]]
1414

1515

16+
class MessageWrapper(BaseModel):
    # Envelope for a dispatched message: ``message_id`` identifies the message
    # (the Redis dispatcher echoes it back as the acknowledgment), and
    # ``payload`` carries the JSON-RPC text as an uninterpreted string.
    message_id: str
    payload: str

    def get_json_rpc_message(self) -> types.JSONRPCMessage | ValidationError:
        """Decode ``payload`` as a JSON-RPC message.

        Returns:
            The validated ``JSONRPCMessage``. When validation fails, the
            ``ValidationError`` instance is handed back as a value instead of
            being raised, so the caller can forward it unchanged.
        """
        try:
            parsed = types.JSONRPCMessage.model_validate_json(self.payload)
        except ValidationError as validation_error:
            return validation_error
        return parsed
26+
27+
1628
@runtime_checkable
1729
class MessageDispatch(Protocol):
1830
"""Abstract interface for SSE message dispatching.
@@ -35,6 +47,25 @@ async def publish_message(
3547
"""
3648
...
3749

50+
async def publish_message_sync(
    self, session_id: UUID, message: types.JSONRPCMessage | str, timeout: float = 30.0
) -> bool:
    """Publish a message for a session and report whether it was consumed.

    Implementations are expected to block until the subscriber has fully
    consumed the message or until ``timeout`` elapses.

    This default body does not wait and never applies ``timeout``: it
    delegates to ``publish_message`` and returns that result as-is.

    Args:
        session_id: The UUID of the session this message is for
        message: The message to publish (JSONRPCMessage or str for invalid JSON)
        timeout: Maximum time to wait for consumption in seconds

    Returns:
        bool: True if message was published and consumed, False otherwise
    """
    # Fallback: plain publish, no consumption confirmation.
    delivered = await self.publish_message(session_id, message)
    return delivered
68+
3869
@asynccontextmanager
3970
async def subscribe(self, session_id: UUID, callback: MessageCallback):
4071
"""Request-scoped context manager that subscribes to messages for a session.
@@ -90,6 +121,18 @@ async def publish_message(
90121

91122
logger.debug(f"Message dispatched to session {session_id}")
92123
return True
124+
125+
async def publish_message_sync(
    self, session_id: UUID, message: types.JSONRPCMessage | str, timeout: float = 30.0
) -> bool:
    """Publish a message for a session; equivalent to ``publish_message`` here.

    The in-memory dispatcher needs no separate confirmation step: this method
    forwards to ``publish_message`` and returns its result. ``timeout`` is
    accepted for interface compatibility and is not used.

    NOTE(review): this relies on ``publish_message`` invoking the session
    callback before it returns - confirm against that method's body.
    """
    # No acknowledgment round-trip is needed for the in-process path.
    outcome = await self.publish_message(session_id, message)
    return outcome
93136

94137
@asynccontextmanager
95138
async def subscribe(self, session_id: UUID, callback: MessageCallback):

src/mcp/server/message_queue/redis.py

Lines changed: 177 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
import logging
22
from contextlib import asynccontextmanager
33
from typing import Any, cast
4-
from uuid import UUID
4+
from uuid import UUID, uuid4
55

66
import anyio
7-
from anyio import CapacityLimiter, lowlevel
8-
from pydantic import ValidationError
7+
from anyio import CancelScope, CapacityLimiter, Event, lowlevel
8+
from anyio.abc import TaskGroup
99

1010
import mcp.types as types
11-
from mcp.server.message_queue.base import MessageCallback
11+
from mcp.server.message_queue.base import MessageCallback, MessageWrapper
1212

1313
try:
1414
import redis.asyncio as redis
@@ -42,98 +42,212 @@ def __init__(
4242
self._prefix = prefix
4343
self._active_sessions_key = f"{prefix}active_sessions"
4444
self._callbacks: dict[UUID, MessageCallback] = {}
45-
# Ensures only one polling task runs at a time for message handling
4645
self._limiter = CapacityLimiter(1)
46+
self._ack_events: dict[str, Event] = {}
47+
4748
logger.debug(f"Redis message dispatch initialized: {redis_url}")
4849

4950
def _session_channel(self, session_id: UUID) -> str:
5051
"""Get the Redis channel for a session."""
5152
return f"{self._prefix}session:{session_id.hex}"
5253

54+
def _ack_channel(self, session_id: UUID) -> str:
55+
"""Get the acknowledgment channel for a session."""
56+
return f"{self._prefix}ack:{session_id.hex}"
57+
5358
@asynccontextmanager
5459
async def subscribe(self, session_id: UUID, callback: MessageCallback):
5560
"""Request-scoped context manager that subscribes to messages for a session."""
5661
await self._redis.sadd(self._active_sessions_key, session_id.hex)
5762
self._callbacks[session_id] = callback
58-
channel = self._session_channel(session_id)
59-
await self._pubsub.subscribe(channel) # type: ignore
60-
61-
logger.debug(f"Subscribing to Redis channel for session {session_id}")
62-
async with anyio.create_task_group() as tg:
63-
tg.start_soon(self._listen_for_messages)
64-
try:
65-
yield
66-
finally:
67-
tg.cancel_scope.cancel()
68-
await self._pubsub.unsubscribe(channel) # type: ignore
69-
await self._redis.srem(self._active_sessions_key, session_id.hex)
70-
del self._callbacks[session_id]
71-
logger.debug(f"Unsubscribed from Redis channel: {session_id}")
72-
73-
async def _listen_for_messages(self) -> None:
63+
64+
session_channel = self._session_channel(session_id)
65+
ack_channel = self._ack_channel(session_id)
66+
67+
await self._pubsub.subscribe(session_channel) # type: ignore
68+
await self._pubsub.subscribe(ack_channel) # type: ignore
69+
70+
logger.debug(f"Subscribing to Redis channels for session {session_id}")
71+
72+
# Two nested task groups ensure proper cleanup: the inner one cancels the
73+
# listener, while the outer one allows any handlers to complete before exiting.
74+
async with anyio.create_task_group() as tg_handler:
75+
async with anyio.create_task_group() as tg:
76+
tg.start_soon(self._listen_for_messages, tg_handler)
77+
try:
78+
yield
79+
finally:
80+
tg.cancel_scope.cancel()
81+
await self._pubsub.unsubscribe(session_channel) # type: ignore
82+
await self._pubsub.unsubscribe(ack_channel) # type: ignore
83+
await self._redis.srem(self._active_sessions_key, session_id.hex)
84+
del self._callbacks[session_id]
85+
logger.debug(
86+
f"Unsubscribed from Redis channels for session {session_id}"
87+
)
88+
89+
async def _listen_for_messages(self, tg_handler: TaskGroup) -> None:
7490
"""Background task that listens for messages on subscribed channels."""
7591
async with self._limiter:
7692
while True:
7793
await lowlevel.checkpoint()
78-
message: None | dict[str, Any] = await self._pubsub.get_message( # type: ignore
79-
ignore_subscribe_messages=True,
80-
timeout=None, # type: ignore
81-
)
82-
if message is None:
83-
continue
84-
85-
channel: str = cast(str, message["channel"])
86-
expected_prefix = f"{self._prefix}session:"
94+
# Shield message retrieval from cancellation to ensure no messages are
95+
# lost when a session disconnects during processing.
96+
with CancelScope(shield=True):
97+
redis_message: ( # type: ignore
98+
None | dict[str, Any]
99+
) = await self._pubsub.get_message( # type: ignore
100+
ignore_subscribe_messages=True,
101+
timeout=0.1, # type: ignore
102+
)
103+
if redis_message is None:
104+
continue
87105

88-
if not channel.startswith(expected_prefix):
89-
logger.debug(f"Ignoring message from non-MCP channel: {channel}")
90-
continue
106+
channel: str = cast(str, redis_message["channel"])
107+
data: str = cast(str, redis_message["data"])
91108

92-
session_hex = channel[len(expected_prefix) :]
93-
try:
94-
session_id = UUID(hex=session_hex)
95-
expected_channel = self._session_channel(session_id)
96-
if channel != expected_channel:
97-
logger.error(f"Channel format mismatch: {channel}")
109+
# Handle acknowledgment messages
110+
if channel.startswith(f"{self._prefix}ack:"):
111+
tg_handler.start_soon(self._handle_ack_message, channel, data)
98112
continue
99-
except ValueError:
100-
logger.error(f"Invalid UUID in channel: {channel}")
101-
continue
102113

103-
data: str = cast(str, message["data"])
104-
try:
105-
if session_id not in self._callbacks:
106-
logger.warning(f"Message dropped: no callback for {session_id}")
114+
# Handle session messages
115+
elif channel.startswith(f"{self._prefix}session:"):
116+
tg_handler.start_soon(
117+
self._handle_session_message, channel, data
118+
)
107119
continue
108120

109-
# Try to parse as valid message or recreate original ValidationError
110-
try:
111-
msg = types.JSONRPCMessage.model_validate_json(data)
112-
await self._callbacks[session_id](msg)
113-
except ValidationError as exc:
114-
# Pass the identical validation error that would have occurred
115-
await self._callbacks[session_id](exc)
116-
except Exception as e:
117-
logger.error(f"Error processing message for {session_id}: {e}")
121+
# Ignore other channels
122+
else:
123+
logger.debug(
124+
f"Ignoring message from non-MCP channel: {channel}"
125+
)
126+
127+
async def _handle_ack_message(self, channel: str, data: str) -> None:
    """Process one message received on an acknowledgment channel.

    Args:
        channel: Name of the Redis channel the message arrived on.
        data: Message body; after stripping whitespace it is treated as the
            id of the message being acknowledged.

    The channel must be exactly the ack channel of a valid session UUID;
    anything else is ignored (with an error log when the UUID part is
    malformed or non-canonical). If an event is registered under the
    acknowledged id it is set; unknown ids are silently ignored.
    """
    prefix = f"{self._prefix}ack:"
    if not channel.startswith(prefix):
        return

    hex_part = channel[len(prefix) :]
    try:
        sid = UUID(hex=hex_part)
    except ValueError:
        logger.error(f"Invalid UUID hex in ack channel: {channel}")
        return

    # UUID() also accepts non-canonical spellings (uppercase, dashes, braces);
    # rebuilding the channel name rejects those.
    canonical = self._ack_channel(sid)
    if channel != canonical:
        logger.error(f"Channel mismatch: got {channel}, expected {canonical}")
        return

    acked_id = data.strip()
    pending = self._ack_events.get(acked_id)
    if pending is None:
        return

    logger.debug(f"Received acknowledgment for message: {acked_id}")
    pending.set()
153+
154+
async def _handle_session_message(self, channel: str, data: str) -> None:
    """Deliver one message received on a session channel, then acknowledge it.

    Args:
        channel: Name of the Redis channel the message arrived on.
        data: JSON text of a ``MessageWrapper`` envelope.

    The channel must be exactly the session channel of a valid UUID and a
    callback must be registered for that session; otherwise the message is
    dropped (and logged). The envelope's payload is parsed and the result -
    a ``JSONRPCMessage`` or the ``ValidationError`` from parsing - is passed
    to the session callback. The acknowledgment is published only after the
    callback has returned.

    Any exception raised while decoding the envelope, running the callback,
    or publishing the acknowledgment is logged with its traceback and
    swallowed, so no acknowledgment is sent in that case.
    """
    session_prefix = f"{self._prefix}session:"
    if not channel.startswith(session_prefix):
        return

    session_hex = channel[len(session_prefix) :]
    try:
        session_id = UUID(hex=session_hex)
    except ValueError:
        logger.error(f"Invalid UUID hex in session channel: {channel}")
        return

    # UUID() tolerates non-canonical spellings; require the exact channel name.
    expected_channel = self._session_channel(session_id)
    if channel != expected_channel:
        logger.error(f"Channel mismatch: got {channel}, expected {expected_channel}")
        return

    # Single lookup: the callback found here is the one that gets invoked.
    callback = self._callbacks.get(session_id)
    if callback is None:
        logger.warning(f"Message dropped: no callback for {session_id}")
        return

    try:
        wrapper = MessageWrapper.model_validate_json(data)
        await callback(wrapper.get_json_rpc_message())
        await self._send_acknowledgment(session_id, wrapper.message_id)
    except Exception as e:
        # Best-effort handler: keep the failure from escaping, but record the
        # traceback so the cause can be diagnosed.
        logger.exception(f"Error processing message for {session_id}: {e}")
185+
186+
async def _send_acknowledgment(self, session_id: UUID, message_id: str) -> None:
    """Publish ``message_id`` on the session's acknowledgment channel.

    Args:
        session_id: Session whose ack channel receives the notification.
        message_id: Id of the message being acknowledged; sent as the body.
    """
    target = self._ack_channel(session_id)
    await self._redis.publish(target, message_id)  # type: ignore
    logger.debug(
        f"Sent acknowledgment for message {message_id} to session {session_id}"
    )
118193

119194
async def publish_message(
120-
self, session_id: UUID, message: types.JSONRPCMessage | str
121-
) -> bool:
195+
self,
196+
session_id: UUID,
197+
message: types.JSONRPCMessage | str,
198+
message_id: str | None = None,
199+
) -> str | None:
122200
"""Publish a message for the specified session."""
123201
if not await self.session_exists(session_id):
124202
logger.warning(f"Message dropped: unknown session {session_id}")
125-
return False
203+
return None
126204

127205
# Pass raw JSON strings directly, preserving validation errors
206+
message_id = message_id or str(uuid4())
128207
if isinstance(message, str):
129-
data = message
208+
wrapper = MessageWrapper(message_id=message_id, payload=message)
130209
else:
131-
data = message.model_dump_json()
210+
wrapper = MessageWrapper(
211+
message_id=message_id, payload=message.model_dump_json()
212+
)
132213

133214
channel = self._session_channel(session_id)
134-
await self._redis.publish(channel, data) # type: ignore[attr-defined]
135-
logger.debug(f"Message published to Redis channel for session {session_id}")
136-
return True
215+
await self._redis.publish(channel, wrapper.model_dump_json()) # type: ignore
216+
logger.debug(
217+
f"Message {message_id} published to Redis channel for session {session_id}"
218+
)
219+
return message_id
220+
221+
async def publish_message_sync(
    self,
    session_id: UUID,
    message: types.JSONRPCMessage | str,
    timeout: float = 120.0,
) -> bool:
    """Publish a message and wait until its acknowledgment is observed.

    Args:
        session_id: Session the message is addressed to.
        message: The message to publish (JSONRPCMessage or raw str).
        timeout: Seconds to wait for the acknowledgment (defaults to 120 here).

    Returns:
        True once the acknowledgment event for this message has been set;
        False if ``publish_message`` returned ``None`` or a ``TimeoutError``
        was raised while publishing or waiting.
    """
    message_id = str(uuid4())
    acknowledged = Event()
    # Register before publishing: the ack handler only sets events it can
    # find in this mapping.
    self._ack_events[message_id] = acknowledged

    try:
        if await self.publish_message(session_id, message, message_id) is None:
            return False

        with anyio.fail_after(timeout):
            await acknowledged.wait()
    except TimeoutError:
        logger.warning(
            f"Timed out waiting for acknowledgment of message {message_id}"
        )
        return False
    else:
        logger.debug(f"Received acknowledgment for message {message_id}")
        return True
    finally:
        # Always forget the event, whichever way the wait ended.
        self._ack_events.pop(message_id, None)
137251

138252
async def session_exists(self, session_id: UUID) -> bool:
139253
"""Check if a session exists."""

src/mcp/server/sse.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -174,14 +174,24 @@ async def handle_post_message(
174174
logger.debug(f"Validated client message: {message}")
175175
except ValidationError as err:
176176
logger.error(f"Failed to parse message: {err}")
177+
# Still publish the invalid message, but using synchronized version
177178
response = Response("Could not parse message", status_code=400)
178179
await response(scope, receive, send)
179180
# Pass raw JSON string; receiver will recreate identical ValidationError
180181
# when parsing the same invalid JSON
181-
await self._message_dispatch.publish_message(session_id, body.decode())
182+
await self._message_dispatch.publish_message_sync(session_id, body.decode())
182183
return
183184

184185
logger.debug(f"Publishing message for session {session_id}: {message}")
185-
response = Response("Accepted", status_code=202)
186+
187+
# Use sync publish to block until the message is processed
188+
result = await self._message_dispatch.publish_message_sync(session_id, message)
189+
190+
if result:
191+
# Message was successfully processed
192+
response = Response("OK", status_code=200)
193+
else:
194+
# Message timed out or failed to be processed
195+
response = Response("Processing Timeout", status_code=504)
196+
186197
await response(scope, receive, send)
187-
await self._message_dispatch.publish_message(session_id, message)

0 commit comments

Comments
 (0)