@@ -80,7 +80,7 @@ def _get_location(uuid: str) -> str:
8080 _DB_CONNECTION .close ()
8181 for _ES in _EVENT_STREAMS :
8282 _LOGGER .info (
83- "Existing EventStream: %s (id=%s routing_key=%s )" ,
83+ "Existing EventStream: %s (id=%s routing_key='%s' )" ,
8484 _get_location (_ES [1 ]),
8585 _ES [0 ],
8686 _ES [2 ],
@@ -116,37 +116,20 @@ class EventStreamGetResponse(BaseModel):
116116 event_streams : list [EventStreamItem ]
117117
118118
119- # Active websocket connections, indexed by Event Stream record ID
120- _ACTIVE_CONNECTIONS : dict [int , list [WebSocket ]] = {}
121-
122-
123- def _add_active_connection (es_id : int , websocket : WebSocket ) -> int :
124- """Add an active connection to the list of active connections,
125- returning the current number of connections for this ID."""
126- if es_id not in _ACTIVE_CONNECTIONS :
127- _ACTIVE_CONNECTIONS [es_id ] = []
128- _ACTIVE_CONNECTIONS [es_id ].append (websocket )
129- return len (_ACTIVE_CONNECTIONS [es_id ])
130-
131-
132- def _get_active_connections (es_id : int ) -> list [WebSocket ]:
133- """Get the list of active connections for a given Event Stream record ID."""
134- return _ACTIVE_CONNECTIONS .get (es_id , [])
135-
136-
137- def _forget_active_connections (es_id : int ) -> None :
138- """Removes all active connections for a given Event Stream record ID."""
139- del _ACTIVE_CONNECTIONS [es_id ]
140-
141-
142119# Endpoints for the 'public-facing' event-stream web-socket API ------------------------
143120
144121
145122@app_public .websocket ("/event-stream/{uuid}" )
146123async def event_stream (websocket : WebSocket , uuid : str ):
147124 """The websocket handler for the event-stream.
148- The actual location is returned to the AS when the web-socket is created
149- using a POST to /event-stream/."""
125+ The UUID is returned to the AS when the web-socket is created
126+ using a POST to /event-stream/.
127+
128+ The socket will close if a 'POISON' message is received.
129+ The AS will insert one of these into the stream after it is has been closed.
130+ i.e. the API will call us to close the connection (removing the record from our DB)
131+ before sending the poison pill.
132+ """
150133
151134 # Get the DB record for this UUID...
152135 _LOGGER .debug ("Connect attempt (uuid=%s)..." , uuid )
@@ -169,14 +152,13 @@ async def event_stream(websocket: WebSocket, uuid: str):
169152 routing_key : str = es [2 ]
170153
171154 _LOGGER .debug (
172- "Waiting for 'accept' on stream %s (uuid=%s routing_key=%s )..." ,
155+ "Waiting for 'accept' on stream %s (uuid=%s routing_key='%s' )..." ,
173156 es_id ,
174157 uuid ,
175158 routing_key ,
176159 )
177160 await websocket .accept ()
178- count_for_this_id = _add_active_connection (es_id , websocket )
179- _LOGGER .debug ("Accepted connection for %s (active=%s)" , es_id , count_for_this_id )
161+ _LOGGER .debug ("Accepted connection for %s" , es_id )
180162
181163 _LOGGER .debug ("Creating reader for %s..." , es_id )
182164 message_reader = _get_from_queue (routing_key )
@@ -190,7 +172,11 @@ async def event_stream(websocket: WebSocket, uuid: str):
190172 reader = anext (message_reader )
191173 message_body = await reader
192174 _LOGGER .debug ("Got message for %s (message_body=%s)" , es_id , message_body )
193- await websocket .send_text (str (message_body ))
175+ if message_body == b"POISON" :
176+ _LOGGER .debug ("Got poison pill for %s (%s) (closing)..." , es_id , uuid )
177+ _running = False
178+ else :
179+ await websocket .send_text (str (message_body ))
194180
195181 _LOGGER .debug ("Leaving %s (uid=%s)..." , es_id , uuid )
196182
@@ -240,7 +226,7 @@ def post_es(request_body: EventStreamPostRequestBody) -> EventStreamPostResponse
240226 # we just need to provide a UUID and the routing key.
241227 routing_key : str = request_body .routing_key
242228 _LOGGER .info (
243- "Creating new event stream %s (routing_key=%s )..." , uuid_str , routing_key
229+ "Creating new event stream %s (routing_key='%s' )..." , uuid_str , routing_key
244230 )
245231
246232 db = sqlite3 .connect (_DATABASE_PATH )
@@ -301,7 +287,7 @@ def delete_es(es_id: int):
301287 )
302288
303289 _LOGGER .info (
304- "Deleting event stream %s (uuid=%s routing_key=%s )" , es_id , es [1 ], es [2 ]
290+ "Deleting event stream %s (uuid=%s routing_key='%s' )" , es_id , es [1 ], es [2 ]
305291 )
306292
307293 # Delete the ES record...
@@ -312,17 +298,4 @@ def delete_es(es_id: int):
312298 db .commit ()
313299 db .close ()
314300
315- # Now close (and erase) any existing connections...
316- # See https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1 for status codes
317- # The reason ius limited to 123 utf-8 bytes
318- active_websockets = _get_active_connections (es_id )
319- if active_websockets :
320- _LOGGER .info (
321- "Closing active connections for %s (%s)" , es_id , len (active_websockets )
322- )
323- for websocket in _get_active_connections (es_id ):
324- websocket .close (code = 1000 , reason = "Event Stream deleted" )
325- _LOGGER .info ("Closed active connections for %s" , es_id )
326- _forget_active_connections (es_id )
327-
328301 _LOGGER .info ("Deleted %s" , es_id )
0 commit comments