44import logging
55import os
66import sqlite3
7+ import threading
78from logging .config import dictConfig
89from typing import Any
910
@@ -86,6 +87,14 @@ def _get_location(uuid: str) -> str:
8687 _ES [2 ],
8788 )
8889
90+ # The active websocket connections (and a thread lock).
91+ # A set of WebSocket objects appended to when a connection is made
92+ # and where objects are removed when each the connection is closed.
93+ # This is used by the "shutdown" event to gracefully close all
94+ # active connections.
95+ _ACTIVE_CONNECTIONS : set [WebSocket ] = set ()
96+ _ACTIVE_CONNECTIONS_LOCK = threading .Lock ()
97+
8998
9099# We use pydantic to declare the model (request payloads) for the internal REST API.
91100# The public API is a WebSocket API and does not require a model.
@@ -116,6 +125,26 @@ class EventStreamGetResponse(BaseModel):
116125 event_streams : list [EventStreamItem ]
117126
118127
128+ def _add_connection (websocket : WebSocket ):
129+ """Safely add a connection to the active connections set."""
130+ with _ACTIVE_CONNECTIONS_LOCK :
131+ _ACTIVE_CONNECTIONS .add (websocket )
132+
133+
134+ def _remove_connection (websocket : WebSocket ):
135+ """Safely remove a connection from the active connections set."""
136+ with _ACTIVE_CONNECTIONS_LOCK :
137+ _ACTIVE_CONNECTIONS .remove (websocket )
138+
139+
140+ @app_public .on_event ("shutdown" )
141+ async def shutdown ():
142+ """The application is shutting down.
143+ Gracefully close all active connections."""
144+ for websocket in _ACTIVE_CONNECTIONS :
145+ await websocket .close (code = status .WS_1001_GOING_AWAY , reason = "Server shutdown" )
146+
147+
119148# Endpoints for the 'public-facing' event-stream web-socket API ------------------------
120149
121150
@@ -158,7 +187,13 @@ async def event_stream(websocket: WebSocket, uuid: str):
158187 routing_key ,
159188 )
160189 await websocket .accept ()
161- _LOGGER .info ("Accepted connection for %s" , es_id )
190+ # Add us to the set of active connections
191+ _add_connection (websocket )
192+ _LOGGER .info (
193+ "Accepted connection for %s (%s active connections)" ,
194+ es_id ,
195+ len (_ACTIVE_CONNECTIONS ),
196+ )
162197
163198 _LOGGER .debug ("Creating reader for %s..." , es_id )
164199 message_reader = _get_from_queue (routing_key )
@@ -178,6 +213,9 @@ async def event_stream(websocket: WebSocket, uuid: str):
178213 else :
179214 await websocket .send_text (str (message_body ))
180215
216+ # Remove us from the set of active connections
217+ _remove_connection (websocket )
218+
181219 _LOGGER .info ("Closing %s (uuid=%s)..." , es_id , uuid )
182220 await websocket .close (
183221 code = status .WS_1000_NORMAL_CLOSURE , reason = "The stream has been deleted"
0 commit comments