Skip to content

Commit 6472635

Browse files
author
Alan Christie
committed
fix: Better close for websockets (close rather than exception)
1 parent afef082 commit 6472635

File tree

1 file changed

+22
-26
lines changed

1 file changed

+22
-26
lines changed

app/app.py

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,10 @@ async def event_stream(
180180
before sending the poison pill.
181181
"""
182182

183+
_LOGGER.info("Accepting connection (uuid=%s)...", uuid)
184+
await websocket.accept()
185+
_LOGGER.info("Accepted connection (uuid=%s)", uuid)
186+
183187
# Custom request headers.
184188
# The following are used to identify the first event in a stream: -
185189
#
@@ -237,13 +241,14 @@ async def event_stream(
237241
# Replace any error with a 'too many values provided' error if necessary
238242
if num_stream_from_specified > 1:
239243
header_value_error = True
240-
header_value_error_msg = "Cannot provide more than one X-StreamFrom value"
244+
header_value_error_msg = "Cannot provide more than one X-StreamFrom variable"
241245

242246
if header_value_error:
243-
raise HTTPException(
244-
status_code=status.HTTP_400_BAD_REQUEST,
245-
detail=header_value_error_msg,
247+
await websocket.close(
248+
code=status.WS_1002_PROTOCOL_ERROR,
249+
reason=header_value_error_msg,
246250
)
251+
return
247252

248253
# Get the DB record for this UUID...
249254
_LOGGER.debug("Connect attempt (uuid=%s)...", uuid)
@@ -255,28 +260,18 @@ async def event_stream(
255260
if not es:
256261
msg: str = f"Connect for unknown EventStream {uuid}"
257262
_LOGGER.warning(msg)
258-
raise HTTPException(
259-
status_code=status.HTTP_404_NOT_FOUND,
260-
detail=msg,
261-
)
263+
await websocket.close(code=status.WS_1000_NORMAL_CLOSURE, reason=msg)
264+
return
262265

263266
# Get the ID (for diagnostics)
264267
# and the routing key for the queue...
265268
es_id = es[0]
266269
routing_key: str = es[2]
267270

268-
_LOGGER.info(
269-
"Waiting for 'accept' on stream %s (uuid=%s routing_key='%s')...",
270-
es_id,
271-
uuid,
272-
routing_key,
273-
)
274-
await websocket.accept()
275-
_LOGGER.info("Accepted connection for %s", es_id)
276-
277271
_LOGGER.debug(
278-
"Creating Consumer for %s (%s:%s@%s/%s)...",
272+
"Creating Consumer for %s [%s] (%s:%s@%s/%s)...",
279273
es_id,
274+
routing_key,
280275
_AMPQ_USERNAME,
281276
_AMPQ_PASSWORD,
282277
_AMPQ_HOSTNAME,
@@ -289,14 +284,17 @@ async def event_stream(
289284
vhost=_AMPQ_VHOST,
290285
load_balancer_mode=True,
291286
)
287+
# Before continuing ... does the stream exist?
288+
# If we don't check it now we'll fail later anyway.
289+
# The AS is expected to create and delete the streams.
292290
if not await consumer.stream_exists(routing_key):
293291
msg: str = f"EventStream {uuid} cannot be found"
294292
_LOGGER.warning(msg)
295-
raise HTTPException(
296-
status_code=status.HTTP_404_NOT_FOUND,
297-
detail=msg,
298-
)
293+
await websocket.close(code=status.WS_1013_TRY_AGAIN_LATER, reason=msg)
294+
return
299295

296+
# Start consuming the stream.
297+
# We don't return from here until there's an error.
300298
_LOGGER.info("Consuming %s...", es_id)
301299
await _consume(
302300
consumer=consumer,
@@ -306,12 +304,11 @@ async def event_stream(
306304
offset_specification=offset_specification,
307305
)
308306

307+
# One our way out...
309308
await websocket.close(
310309
code=status.WS_1000_NORMAL_CLOSURE, reason="The stream has been deleted"
311310
)
312-
_LOGGER.info("Closed WebSocket for %s", es_id)
313-
314-
_LOGGER.info("Disconnected %s (uuid=%s)...", es_id, uuid)
311+
_LOGGER.info("Closed WebSocket for %s (uuid=%s)", es_id, uuid)
315312

316313

317314
async def generate_on_message_for_websocket(websocket: WebSocket, es_id: str):
@@ -347,7 +344,6 @@ async def on_message_for_websocket(
347344
)
348345

349346
shutdown: bool = False
350-
# decoded_msg: str = msg.decode(encoding="utf-8")
351347
if msg == b"POISON":
352348
_LOGGER.info("Taking POISON for %s (stopping)...", es_id)
353349
shutdown = True

0 commit comments

Comments
 (0)