|
5 | 5 | import os |
6 | 6 | import sqlite3 |
7 | 7 | from logging.config import dictConfig |
8 | | -from typing import Any, Dict, List |
| 8 | +from typing import Any |
9 | 9 |
|
10 | 10 | import aio_pika |
11 | 11 | import shortuuid |
|
14 | 14 |
|
15 | 15 | # Configure logging |
16 | 16 | print("Configuring logging...") |
17 | | -_LOGGING_CONFIG: Dict[str, Any] = {} |
| 17 | +_LOGGING_CONFIG: dict[str, Any] = {} |
18 | 18 | with open("logging.config", "r", encoding="utf8") as stream: |
19 | 19 | try: |
20 | 20 | _LOGGING_CONFIG = json.loads(stream.read()) |
@@ -113,7 +113,30 @@ class EventStreamItem(BaseModel): |
113 | 113 | class EventStreamGetResponse(BaseModel): |
114 | 114 | """/event-stream/ POST response.""" |
115 | 115 |
|
116 | | - event_streams: List[EventStreamItem] |
| 116 | + event_streams: list[EventStreamItem] |
| 117 | + |
| 118 | + |
| 119 | +# Active websocket connections, indexed by Event Stream record ID |
| 120 | +_ACTIVE_CONNECTIONS: dict[int, list[WebSocket]] = {} |
| 121 | + |
| 122 | + |
| 123 | +def _add_active_connection(es_id: int, websocket: WebSocket) -> int: |
| 124 | + """Add an active connection to the list of active connections, |
| 125 | + returning the current number of connections for this ID.""" |
| 126 | + if es_id not in _ACTIVE_CONNECTIONS: |
| 127 | + _ACTIVE_CONNECTIONS[es_id] = [] |
| 128 | + _ACTIVE_CONNECTIONS[es_id].append(websocket) |
| 129 | + return len(_ACTIVE_CONNECTIONS[es_id]) |
| 130 | + |
| 131 | + |
| 132 | +def _get_active_connections(es_id: int) -> list[WebSocket]: |
| 133 | + """Get the list of active connections for a given Event Stream record ID.""" |
| 134 | + return _ACTIVE_CONNECTIONS.get(es_id, []) |
| 135 | + |
| 136 | + |
| 137 | +def _forget_active_connections(es_id: int) -> None: |
| 138 | + """Removes all active connections for a given Event Stream record ID.""" |
| 139 | + del _ACTIVE_CONNECTIONS[es_id] |
117 | 140 |
|
118 | 141 |
|
119 | 142 | # Endpoints for the 'public-facing' event-stream web-socket API ------------------------ |
@@ -152,7 +175,8 @@ async def event_stream(websocket: WebSocket, uuid: str): |
152 | 175 | routing_key, |
153 | 176 | ) |
154 | 177 | await websocket.accept() |
155 | | - _LOGGER.debug("Accepted connection for %s", es_id) |
| 178 | + count_for_this_id = _add_active_connection(es_id, websocket) |
| 179 | + _LOGGER.debug("Accepted connection for %s (active=%s)", es_id, count_for_this_id) |
156 | 180 |
|
157 | 181 | _LOGGER.debug("Creating reader for %s...", es_id) |
158 | 182 | message_reader = _get_from_queue(routing_key) |
@@ -247,7 +271,7 @@ def get_es() -> EventStreamGetResponse: |
247 | 271 | all_es = cursor.execute("SELECT * FROM es").fetchall() |
248 | 272 | db.close() |
249 | 273 |
|
250 | | - event_streams: List[EventStreamItem] = [] |
| 274 | + event_streams: list[EventStreamItem] = [] |
251 | 275 | for es in all_es: |
252 | 276 | location: str = _get_location(es[1]) |
253 | 277 | event_streams.append( |
@@ -281,10 +305,24 @@ def delete_es(es_id: int): |
281 | 305 | ) |
282 | 306 |
|
283 | 307 | # Delete the ES record... |
| 308 | + # This will prevent any further connections. |
284 | 309 | db = sqlite3.connect(_DATABASE_PATH) |
285 | 310 | cursor = db.cursor() |
286 | 311 | cursor.execute(f"DELETE FROM es WHERE id={es_id}") |
287 | 312 | db.commit() |
288 | 313 | db.close() |
289 | 314 |
|
| 315 | + # Now close (and erase) any existing connections... |
| 316 | + # See https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1 for status codes |
| 317 | + # The reason ius limited to 123 utf-8 bytes |
| 318 | + active_websockets = _get_active_connections(es_id) |
| 319 | + if active_websockets: |
| 320 | + _LOGGER.info( |
| 321 | + "Closing active connections for %s (%s)", es_id, len(active_websockets) |
| 322 | + ) |
| 323 | + for websocket in _get_active_connections(es_id): |
| 324 | + websocket.close(code=1000, reason="Event Stream deleted") |
| 325 | + _LOGGER.info("Closed active connections for %s", es_id) |
| 326 | + _forget_active_connections(es_id) |
| 327 | + |
290 | 328 | _LOGGER.info("Deleted %s", es_id) |
0 commit comments