Skip to content

Commit 6771bd5

Browse files
author
Alan Christie
committed
feat: Attempt at initial (simple) socket close on delete
1 parent c87da6d commit 6771bd5

File tree

1 file changed

+43
-5
lines changed

1 file changed

+43
-5
lines changed

app/app.py

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import os
66
import sqlite3
77
from logging.config import dictConfig
8-
from typing import Any, Dict, List
8+
from typing import Any
99

1010
import aio_pika
1111
import shortuuid
@@ -14,7 +14,7 @@
1414

1515
# Configure logging
1616
print("Configuring logging...")
17-
_LOGGING_CONFIG: Dict[str, Any] = {}
17+
_LOGGING_CONFIG: dict[str, Any] = {}
1818
with open("logging.config", "r", encoding="utf8") as stream:
1919
try:
2020
_LOGGING_CONFIG = json.loads(stream.read())
@@ -113,7 +113,30 @@ class EventStreamItem(BaseModel):
113113
class EventStreamGetResponse(BaseModel):
114114
"""/event-stream/ POST response."""
115115

116-
event_streams: List[EventStreamItem]
116+
event_streams: list[EventStreamItem]
117+
118+
119+
# Active websocket connections, indexed by Event Stream record ID
120+
_ACTIVE_CONNECTIONS: dict[int, list[WebSocket]] = {}
121+
122+
123+
def _add_active_connection(es_id: int, websocket: WebSocket) -> int:
124+
"""Add an active connection to the list of active connections,
125+
returning the current number of connections for this ID."""
126+
if es_id not in _ACTIVE_CONNECTIONS:
127+
_ACTIVE_CONNECTIONS[es_id] = []
128+
_ACTIVE_CONNECTIONS[es_id].append(websocket)
129+
return len(_ACTIVE_CONNECTIONS[es_id])
130+
131+
132+
def _get_active_connections(es_id: int) -> list[WebSocket]:
133+
"""Get the list of active connections for a given Event Stream record ID."""
134+
return _ACTIVE_CONNECTIONS.get(es_id, [])
135+
136+
137+
def _forget_active_connections(es_id: int) -> None:
138+
"""Removes all active connections for a given Event Stream record ID."""
139+
del _ACTIVE_CONNECTIONS[es_id]
117140

118141

119142
# Endpoints for the 'public-facing' event-stream web-socket API ------------------------
@@ -152,7 +175,8 @@ async def event_stream(websocket: WebSocket, uuid: str):
152175
routing_key,
153176
)
154177
await websocket.accept()
155-
_LOGGER.debug("Accepted connection for %s", es_id)
178+
count_for_this_id = _add_active_connection(es_id, websocket)
179+
_LOGGER.debug("Accepted connection for %s (active=%s)", es_id, count_for_this_id)
156180

157181
_LOGGER.debug("Creating reader for %s...", es_id)
158182
message_reader = _get_from_queue(routing_key)
@@ -247,7 +271,7 @@ def get_es() -> EventStreamGetResponse:
247271
all_es = cursor.execute("SELECT * FROM es").fetchall()
248272
db.close()
249273

250-
event_streams: List[EventStreamItem] = []
274+
event_streams: list[EventStreamItem] = []
251275
for es in all_es:
252276
location: str = _get_location(es[1])
253277
event_streams.append(
@@ -281,10 +305,24 @@ def delete_es(es_id: int):
281305
)
282306

283307
# Delete the ES record...
308+
# This will prevent any further connections.
284309
db = sqlite3.connect(_DATABASE_PATH)
285310
cursor = db.cursor()
286311
cursor.execute(f"DELETE FROM es WHERE id={es_id}")
287312
db.commit()
288313
db.close()
289314

315+
# Now close (and erase) any existing connections...
316+
# See https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1 for status codes
317+
# The reason ius limited to 123 utf-8 bytes
318+
active_websockets = _get_active_connections(es_id)
319+
if active_websockets:
320+
_LOGGER.info(
321+
"Closing active connections for %s (%s)", es_id, len(active_websockets)
322+
)
323+
for websocket in _get_active_connections(es_id):
324+
websocket.close(code=1000, reason="Event Stream deleted")
325+
_LOGGER.info("Closed active connections for %s", es_id)
326+
_forget_active_connections(es_id)
327+
290328
_LOGGER.info("Deleted %s", es_id)

0 commit comments

Comments
 (0)