@@ -180,6 +180,10 @@ async def event_stream(
180180 before sending the poison pill.
181181 """
182182
183+ _LOGGER .info ("Accepting connection (uuid=%s)..." , uuid )
184+ await websocket .accept ()
185+ _LOGGER .info ("Accepted connection (uuid=%s)" , uuid )
186+
183187 # Custom request headers.
184188 # The following are used to identify the first event in a stream: -
185189 #
@@ -237,13 +241,14 @@ async def event_stream(
237241 # Replace any error with a 'too many values provided' error if necessary
238242 if num_stream_from_specified > 1 :
239243 header_value_error = True
240- header_value_error_msg = "Cannot provide more than one X-StreamFrom value "
244+ header_value_error_msg = "Cannot provide more than one X-StreamFrom variable "
241245
242246 if header_value_error :
243- raise HTTPException (
244- status_code = status .HTTP_400_BAD_REQUEST ,
245- detail = header_value_error_msg ,
247+ await websocket . close (
248+ code = status .WS_1002_PROTOCOL_ERROR ,
249+ reason = header_value_error_msg ,
246250 )
251+ return
247252
248253 # Get the DB record for this UUID...
249254 _LOGGER .debug ("Connect attempt (uuid=%s)..." , uuid )
@@ -255,28 +260,18 @@ async def event_stream(
255260 if not es :
256261 msg : str = f"Connect for unknown EventStream { uuid } "
257262 _LOGGER .warning (msg )
258- raise HTTPException (
259- status_code = status .HTTP_404_NOT_FOUND ,
260- detail = msg ,
261- )
263+ await websocket .close (code = status .WS_1000_NORMAL_CLOSURE , reason = msg )
264+ return
262265
263266 # Get the ID (for diagnostics)
264267 # and the routing key for the queue...
265268 es_id = es [0 ]
266269 routing_key : str = es [2 ]
267270
268- _LOGGER .info (
269- "Waiting for 'accept' on stream %s (uuid=%s routing_key='%s')..." ,
270- es_id ,
271- uuid ,
272- routing_key ,
273- )
274- await websocket .accept ()
275- _LOGGER .info ("Accepted connection for %s" , es_id )
276-
277271 _LOGGER .debug (
278- "Creating Consumer for %s (%s:%s@%s/%s)..." ,
272+ "Creating Consumer for %s [%s] (%s:%s@%s/%s)..." ,
279273 es_id ,
274+ routing_key ,
280275 _AMPQ_USERNAME ,
281276 _AMPQ_PASSWORD ,
282277 _AMPQ_HOSTNAME ,
@@ -289,14 +284,17 @@ async def event_stream(
289284 vhost = _AMPQ_VHOST ,
290285 load_balancer_mode = True ,
291286 )
287+ # Before continuing ... does the stream exist?
288+ # If we don't check it now we'll fail later anyway.
289+ # The AS is expected to create and delete the streams.
292290 if not await consumer .stream_exists (routing_key ):
293291 msg : str = f"EventStream { uuid } cannot be found"
294292 _LOGGER .warning (msg )
295- raise HTTPException (
296- status_code = status .HTTP_404_NOT_FOUND ,
297- detail = msg ,
298- )
293+ await websocket .close (code = status .WS_1013_TRY_AGAIN_LATER , reason = msg )
294+ return
299295
296+ # Start consuming the stream.
297+ # We don't return from here until there's an error.
300298 _LOGGER .info ("Consuming %s..." , es_id )
301299 await _consume (
302300 consumer = consumer ,
@@ -306,12 +304,11 @@ async def event_stream(
306304 offset_specification = offset_specification ,
307305 )
308306
307+ # One our way out...
309308 await websocket .close (
310309 code = status .WS_1000_NORMAL_CLOSURE , reason = "The stream has been deleted"
311310 )
312- _LOGGER .info ("Closed WebSocket for %s" , es_id )
313-
314- _LOGGER .info ("Disconnected %s (uuid=%s)..." , es_id , uuid )
311+ _LOGGER .info ("Closed WebSocket for %s (uuid=%s)" , es_id , uuid )
315312
316313
317314async def generate_on_message_for_websocket (websocket : WebSocket , es_id : str ):
@@ -347,7 +344,6 @@ async def on_message_for_websocket(
347344 )
348345
349346 shutdown : bool = False
350- # decoded_msg: str = msg.decode(encoding="utf-8")
351347 if msg == b"POISON" :
352348 _LOGGER .info ("Taking POISON for %s (stopping)..." , es_id )
353349 shutdown = True
0 commit comments