Skip to content

Commit a151f1c

Browse files
committed
clean up comment and session state
1 parent e5938d4 commit a151f1c

File tree

2 files changed

+8
-11
lines changed

2 files changed

+8
-11
lines changed

src/mcp/server/message_queue/base.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ class InMemoryMessageDispatch:
6666

6767
def __init__(self) -> None:
6868
self._callbacks: dict[UUID, MessageCallback] = {}
69-
# _callbacks tracks active sessions, no need for separate _active_sessions set
7069

7170
async def publish_message(
7271
self, session_id: UUID, message: types.JSONRPCMessage | str

src/mcp/server/message_queue/redis.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ def __init__(
4242
self._pubsub = self._redis.pubsub(ignore_subscribe_messages=True) # type: ignore
4343
self._prefix = prefix
4444
self._active_sessions_key = f"{prefix}active_sessions"
45-
self._callbacks: dict[UUID, MessageCallback] = {}
46-
self._task_groups: dict[UUID, TaskGroup] = {}
45+
# Maps session IDs to the callback and task group for that SSE session.
46+
self._session_state: dict[UUID, tuple[MessageCallback, TaskGroup]] = {}
4747
# Ensures only one polling task runs at a time for message handling
4848
self._limiter = CapacityLimiter(1)
4949
logger.debug(f"Redis message dispatch initialized: {redis_url}")
@@ -56,22 +56,20 @@ def _session_channel(self, session_id: UUID) -> str:
5656
async def subscribe(self, session_id: UUID, callback: MessageCallback):
5757
"""Request-scoped context manager that subscribes to messages for a session."""
5858
await self._redis.sadd(self._active_sessions_key, session_id.hex)
59-
self._callbacks[session_id] = callback
6059
channel = self._session_channel(session_id)
6160
await self._pubsub.subscribe(channel) # type: ignore
6261

6362
logger.debug(f"Subscribing to Redis channel for session {session_id}")
6463
async with anyio.create_task_group() as tg:
65-
self._task_groups[session_id] = tg
64+
self._session_state[session_id] = (callback, tg)
6665
tg.start_soon(self._listen_for_messages)
6766
try:
6867
yield
6968
finally:
7069
tg.cancel_scope.cancel()
7170
await self._pubsub.unsubscribe(channel) # type: ignore
7271
await self._redis.srem(self._active_sessions_key, session_id.hex)
73-
del self._callbacks[session_id]
74-
del self._task_groups[session_id]
72+
del self._session_state[session_id]
7573
logger.debug(f"Unsubscribed from Redis channel: {session_id}")
7674

7775
def _extract_session_id(self, channel: str) -> UUID | None:
@@ -114,8 +112,8 @@ async def _listen_for_messages(self) -> None:
114112

115113
data: str = cast(str, message["data"])
116114
try:
117-
if session_id in self._task_groups:
118-
self._task_groups[session_id].start_soon(
115+
if session_state := self._session_state.get(session_id):
116+
session_state[1].start_soon(
119117
self._handle_message, session_id, data
120118
)
121119
else:
@@ -127,7 +125,7 @@ async def _listen_for_messages(self) -> None:
127125

128126
async def _handle_message(self, session_id: UUID, data: str) -> None:
129127
"""Process a message from Redis in the session's task group."""
130-
if session_id not in self._callbacks:
128+
if (session_state := self._session_state.get(session_id)) is None:
131129
logger.warning(f"Message dropped: callback removed for {session_id}")
132130
return
133131

@@ -139,7 +137,7 @@ async def _handle_message(self, session_id: UUID, data: str) -> None:
139137
except ValidationError as exc:
140138
msg_or_error = exc
141139

142-
await self._callbacks[session_id](msg_or_error)
140+
await session_state[0](msg_or_error)
143141
except Exception as e:
144142
logger.error(f"Error in message handler for {session_id}: {e}")
145143

0 commit comments

Comments
 (0)