Skip to content

Commit f8147e8

Browse files
author
Alan Christie
committed
feat: Adds shutdown logic to ws app service
1 parent 2646342 commit f8147e8

File tree

1 file changed

+39
-1
lines changed

1 file changed

+39
-1
lines changed

app/app.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import logging
55
import os
66
import sqlite3
7+
import threading
78
from logging.config import dictConfig
89
from typing import Any
910

@@ -86,6 +87,14 @@ def _get_location(uuid: str) -> str:
8687
_ES[2],
8788
)
8889

90+
# The active websocket connections (and a thread lock).
91+
# A set of WebSocket objects appended to when a connection is made
92+
# and where objects are removed when each the connection is closed.
93+
# This is used by the "shutdown" event to gracefully close all
94+
# active connections.
95+
_ACTIVE_CONNECTIONS: set[WebSocket] = set()
96+
_ACTIVE_CONNECTIONS_LOCK = threading.Lock()
97+
8998

9099
# We use pydantic to declare the model (request payloads) for the internal REST API.
91100
# The public API is a WebSocket API and does not require a model.
@@ -116,6 +125,26 @@ class EventStreamGetResponse(BaseModel):
116125
event_streams: list[EventStreamItem]
117126

118127

128+
def _add_connection(websocket: WebSocket):
129+
"""Safely add a connection to the active connections set."""
130+
with _ACTIVE_CONNECTIONS_LOCK:
131+
_ACTIVE_CONNECTIONS.add(websocket)
132+
133+
134+
def _remove_connection(websocket: WebSocket):
135+
"""Safely remove a connection from the active connections set."""
136+
with _ACTIVE_CONNECTIONS_LOCK:
137+
_ACTIVE_CONNECTIONS.remove(websocket)
138+
139+
140+
@app_public.on_event("shutdown")
141+
async def shutdown():
142+
"""The application is shutting down.
143+
Gracefully close all active connections."""
144+
for websocket in _ACTIVE_CONNECTIONS:
145+
await websocket.close(code=status.WS_1001_GOING_AWAY, reason="Server shutdown")
146+
147+
119148
# Endpoints for the 'public-facing' event-stream web-socket API ------------------------
120149

121150

@@ -158,7 +187,13 @@ async def event_stream(websocket: WebSocket, uuid: str):
158187
routing_key,
159188
)
160189
await websocket.accept()
161-
_LOGGER.info("Accepted connection for %s", es_id)
190+
# Add us to the set of active connections
191+
_add_connection(websocket)
192+
_LOGGER.info(
193+
"Accepted connection for %s (%s active connections)",
194+
es_id,
195+
len(_ACTIVE_CONNECTIONS),
196+
)
162197

163198
_LOGGER.debug("Creating reader for %s...", es_id)
164199
message_reader = _get_from_queue(routing_key)
@@ -178,6 +213,9 @@ async def event_stream(websocket: WebSocket, uuid: str):
178213
else:
179214
await websocket.send_text(str(message_body))
180215

216+
# Remove us from the set of active connections
217+
_remove_connection(websocket)
218+
181219
_LOGGER.info("Closing %s (uuid=%s)...", es_id, uuid)
182220
await websocket.close(
183221
code=status.WS_1000_NORMAL_CLOSURE, reason="The stream has been deleted"

0 commit comments

Comments
 (0)