Skip to content

Commit 48d7204

Browse files
author
Alan Christie
committed
refactor: Connections now handle b'POISON' message (which closes any open connection)
1 parent b08e670 commit 48d7204

File tree

1 file changed

+18
-45
lines changed

1 file changed

+18
-45
lines changed

app/app.py

Lines changed: 18 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def _get_location(uuid: str) -> str:
8080
_DB_CONNECTION.close()
8181
for _ES in _EVENT_STREAMS:
8282
_LOGGER.info(
83-
"Existing EventStream: %s (id=%s routing_key=%s)",
83+
"Existing EventStream: %s (id=%s routing_key='%s')",
8484
_get_location(_ES[1]),
8585
_ES[0],
8686
_ES[2],
@@ -116,37 +116,20 @@ class EventStreamGetResponse(BaseModel):
116116
event_streams: list[EventStreamItem]
117117

118118

119-
# Active websocket connections, indexed by Event Stream record ID
120-
_ACTIVE_CONNECTIONS: dict[int, list[WebSocket]] = {}
121-
122-
123-
def _add_active_connection(es_id: int, websocket: WebSocket) -> int:
124-
"""Add an active connection to the list of active connections,
125-
returning the current number of connections for this ID."""
126-
if es_id not in _ACTIVE_CONNECTIONS:
127-
_ACTIVE_CONNECTIONS[es_id] = []
128-
_ACTIVE_CONNECTIONS[es_id].append(websocket)
129-
return len(_ACTIVE_CONNECTIONS[es_id])
130-
131-
132-
def _get_active_connections(es_id: int) -> list[WebSocket]:
133-
"""Get the list of active connections for a given Event Stream record ID."""
134-
return _ACTIVE_CONNECTIONS.get(es_id, [])
135-
136-
137-
def _forget_active_connections(es_id: int) -> None:
138-
"""Removes all active connections for a given Event Stream record ID."""
139-
del _ACTIVE_CONNECTIONS[es_id]
140-
141-
142119
# Endpoints for the 'public-facing' event-stream web-socket API ------------------------
143120

144121

145122
@app_public.websocket("/event-stream/{uuid}")
146123
async def event_stream(websocket: WebSocket, uuid: str):
147124
"""The websocket handler for the event-stream.
148-
The actual location is returned to the AS when the web-socket is created
149-
using a POST to /event-stream/."""
125+
The UUID is returned to the AS when the web-socket is created
126+
using a POST to /event-stream/.
127+
128+
The socket will close if a 'POISON' message is received.
129+
The AS will insert one of these into the stream after it is has been closed.
130+
i.e. the API will call us to close the connection (removing the record from our DB)
131+
before sending the poison pill.
132+
"""
150133

151134
# Get the DB record for this UUID...
152135
_LOGGER.debug("Connect attempt (uuid=%s)...", uuid)
@@ -169,14 +152,13 @@ async def event_stream(websocket: WebSocket, uuid: str):
169152
routing_key: str = es[2]
170153

171154
_LOGGER.debug(
172-
"Waiting for 'accept' on stream %s (uuid=%s routing_key=%s)...",
155+
"Waiting for 'accept' on stream %s (uuid=%s routing_key='%s')...",
173156
es_id,
174157
uuid,
175158
routing_key,
176159
)
177160
await websocket.accept()
178-
count_for_this_id = _add_active_connection(es_id, websocket)
179-
_LOGGER.debug("Accepted connection for %s (active=%s)", es_id, count_for_this_id)
161+
_LOGGER.debug("Accepted connection for %s", es_id)
180162

181163
_LOGGER.debug("Creating reader for %s...", es_id)
182164
message_reader = _get_from_queue(routing_key)
@@ -190,7 +172,11 @@ async def event_stream(websocket: WebSocket, uuid: str):
190172
reader = anext(message_reader)
191173
message_body = await reader
192174
_LOGGER.debug("Got message for %s (message_body=%s)", es_id, message_body)
193-
await websocket.send_text(str(message_body))
175+
if message_body == b"POISON":
176+
_LOGGER.debug("Got poison pill for %s (%s) (closing)...", es_id, uuid)
177+
_running = False
178+
else:
179+
await websocket.send_text(str(message_body))
194180

195181
_LOGGER.debug("Leaving %s (uid=%s)...", es_id, uuid)
196182

@@ -240,7 +226,7 @@ def post_es(request_body: EventStreamPostRequestBody) -> EventStreamPostResponse
240226
# we just need to provide a UUID and the routing key.
241227
routing_key: str = request_body.routing_key
242228
_LOGGER.info(
243-
"Creating new event stream %s (routing_key=%s)...", uuid_str, routing_key
229+
"Creating new event stream %s (routing_key='%s')...", uuid_str, routing_key
244230
)
245231

246232
db = sqlite3.connect(_DATABASE_PATH)
@@ -301,7 +287,7 @@ def delete_es(es_id: int):
301287
)
302288

303289
_LOGGER.info(
304-
"Deleting event stream %s (uuid=%s routing_key=%s)", es_id, es[1], es[2]
290+
"Deleting event stream %s (uuid=%s routing_key='%s')", es_id, es[1], es[2]
305291
)
306292

307293
# Delete the ES record...
@@ -312,17 +298,4 @@ def delete_es(es_id: int):
312298
db.commit()
313299
db.close()
314300

315-
# Now close (and erase) any existing connections...
316-
# See https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1 for status codes
317-
# The reason ius limited to 123 utf-8 bytes
318-
active_websockets = _get_active_connections(es_id)
319-
if active_websockets:
320-
_LOGGER.info(
321-
"Closing active connections for %s (%s)", es_id, len(active_websockets)
322-
)
323-
for websocket in _get_active_connections(es_id):
324-
websocket.close(code=1000, reason="Event Stream deleted")
325-
_LOGGER.info("Closed active connections for %s", es_id)
326-
_forget_active_connections(es_id)
327-
328301
_LOGGER.info("Deleted %s", es_id)

0 commit comments

Comments
 (0)