@@ -42,8 +42,8 @@ def __init__(
4242 self ._pubsub = self ._redis .pubsub (ignore_subscribe_messages = True ) # type: ignore
4343 self ._prefix = prefix
4444 self ._active_sessions_key = f"{ prefix } active_sessions"
45- self . _callbacks : dict [ UUID , MessageCallback ] = {}
46- self ._task_groups : dict [UUID , TaskGroup ] = {}
45+ # Maps session IDs to the callback and task group for that SSE session.
46+ self ._session_state : dict [UUID , tuple [ MessageCallback , TaskGroup ] ] = {}
4747 # Ensures only one polling task runs at a time for message handling
4848 self ._limiter = CapacityLimiter (1 )
4949 logger .debug (f"Redis message dispatch initialized: { redis_url } " )
@@ -56,22 +56,20 @@ def _session_channel(self, session_id: UUID) -> str:
5656 async def subscribe (self , session_id : UUID , callback : MessageCallback ):
5757 """Request-scoped context manager that subscribes to messages for a session."""
5858 await self ._redis .sadd (self ._active_sessions_key , session_id .hex )
59- self ._callbacks [session_id ] = callback
6059 channel = self ._session_channel (session_id )
6160 await self ._pubsub .subscribe (channel ) # type: ignore
6261
6362 logger .debug (f"Subscribing to Redis channel for session { session_id } " )
6463 async with anyio .create_task_group () as tg :
65- self ._task_groups [session_id ] = tg
64+ self ._session_state [session_id ] = ( callback , tg )
6665 tg .start_soon (self ._listen_for_messages )
6766 try :
6867 yield
6968 finally :
7069 tg .cancel_scope .cancel ()
7170 await self ._pubsub .unsubscribe (channel ) # type: ignore
7271 await self ._redis .srem (self ._active_sessions_key , session_id .hex )
73- del self ._callbacks [session_id ]
74- del self ._task_groups [session_id ]
72+ del self ._session_state [session_id ]
7573 logger .debug (f"Unsubscribed from Redis channel: { session_id } " )
7674
7775 def _extract_session_id (self , channel : str ) -> UUID | None :
@@ -114,8 +112,8 @@ async def _listen_for_messages(self) -> None:
114112
115113 data : str = cast (str , message ["data" ])
116114 try :
117- if session_id in self ._task_groups :
118- self . _task_groups [ session_id ].start_soon (
115+ if session_state := self ._session_state . get ( session_id ) :
116+ session_state [ 1 ].start_soon (
119117 self ._handle_message , session_id , data
120118 )
121119 else :
@@ -127,7 +125,7 @@ async def _listen_for_messages(self) -> None:
127125
128126 async def _handle_message (self , session_id : UUID , data : str ) -> None :
129127 """Process a message from Redis in the session's task group."""
130- if session_id not in self ._callbacks :
128+ if ( session_state := self ._session_state . get ( session_id )) is None :
131129 logger .warning (f"Message dropped: callback removed for { session_id } " )
132130 return
133131
@@ -139,7 +137,7 @@ async def _handle_message(self, session_id: UUID, data: str) -> None:
139137 except ValidationError as exc :
140138 msg_or_error = exc
141139
142- await self . _callbacks [ session_id ](msg_or_error )
140+ await session_state [ 0 ](msg_or_error )
143141 except Exception as e :
144142 logger .error (f"Error in message handler for { session_id } : { e } " )
145143
0 commit comments