Skip to content

Commit 4fa26b3

Browse files
Merge pull request #345 from dvonthenen/fix-issue-344
Fix Starting KeepAlive Always, Switch for Exceptions
2 parents 2d0ddf1 + 163f46d commit 4fa26b3

File tree

12 files changed

+318
-199
lines changed

12 files changed

+318
-199
lines changed

deepgram/clients/live/v1/async_client.py

Lines changed: 132 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,9 @@ def __init__(self, config: DeepgramClientOptions):
4949
self.config = config
5050
self.endpoint = "v1/listen"
5151
self._socket = None
52+
self._exit_event = None
5253
self._event_handlers = {event: [] for event in LiveTranscriptionEvents}
5354
self.websocket_url = convert_to_websocket_url(self.config.url, self.endpoint)
54-
self.exit_event = None
5555

5656
# starts the WebSocket connection for live transcription
5757
async def start(
@@ -61,6 +61,9 @@ async def start(
6161
members: Optional[Dict] = None,
6262
**kwargs,
6363
) -> bool:
64+
"""
65+
Starts the WebSocket connection for live transcription.
66+
"""
6467
self.logger.debug("AsyncLiveClient.start ENTER")
6568
self.logger.info("options: %s", options)
6669
self.logger.info("addons: %s", addons)
@@ -102,13 +105,25 @@ async def start(
102105
self.logger.debug("combined_options: %s", combined_options)
103106

104107
url_with_params = append_query_params(self.websocket_url, combined_options)
105-
self.exit_event = asyncio.Event()
106108

107109
try:
108-
self._socket = await _socket_connect(url_with_params, self.config.headers)
110+
self._socket = await websockets.connect(
111+
url_with_params,
112+
extra_headers=self.config.headers,
113+
ping_interval=PING_INTERVAL,
114+
)
115+
self._exit_event = asyncio.Event()
109116

117+
# listen thread
110118
self._listen_thread = asyncio.create_task(self._listening())
111-
self._keep_alive_thread = asyncio.create_task(self._keep_alive())
119+
120+
# keepalive thread
121+
if self.config.options.get("keepalive") == "true":
122+
self.logger.notice("keepalive is disabled")
123+
self._keep_alive_thread = asyncio.create_task(self._keep_alive())
124+
else:
125+
self.logger.notice("keepalive is disabled")
126+
self._keep_alive_thread = None
112127

113128
# push open event
114129
await self._emit(
@@ -120,12 +135,30 @@ async def start(
120135
self.logger.debug("AsyncLiveClient.start LEAVE")
121136
return True
122137
except websockets.ConnectionClosed as e:
123-
await self._emit(LiveTranscriptionEvents.Close, e.code)
124-
self.logger.notice("exception: websockets.ConnectionClosed")
138+
self.logger.error("exception: websockets.ConnectionClosed")
139+
self.logger.debug("AsyncLiveClient.start LEAVE")
140+
if self.config.options.get("termination_exception_connect") == "true":
141+
raise
142+
return False
143+
except websockets.exceptions.WebSocketException as e:
144+
self.logger.error("WebSocketException in AsyncLiveClient.start: %s", e)
125145
self.logger.debug("AsyncLiveClient.start LEAVE")
146+
if self.config.options.get("termination_exception_connect") == "true":
147+
raise
148+
return False
149+
except Exception as e:
150+
self.logger.error("WebSocketException in AsyncLiveClient.start: %s", e)
151+
self.logger.debug("AsyncLiveClient.start LEAVE")
152+
if self.config.options.get("termination_exception_connect") == "true":
153+
raise
154+
return False
126155

127156
# registers event handlers for specific events
128157
def on(self, event: LiveTranscriptionEvents, handler) -> None:
158+
"""
159+
Registers event handlers for specific events.
160+
"""
161+
self.logger.info("event fired: %s", event)
129162
if event in LiveTranscriptionEvents and callable(handler):
130163
self._event_handlers[event].append(handler)
131164

@@ -140,7 +173,7 @@ async def _listening(self) -> None:
140173

141174
while True:
142175
try:
143-
if self.exit_event.is_set():
176+
if self._exit_event.is_set():
144177
self.logger.notice("_listening exiting gracefully")
145178
self.logger.debug("AsyncLiveClient._listening LEAVE")
146179
return
@@ -218,6 +251,11 @@ async def _listening(self) -> None:
218251
**dict(self.kwargs),
219252
)
220253
case _:
254+
self.logger.warning(
255+
"Unknown Message: response_type: %s, data: %s",
256+
response_type,
257+
data,
258+
)
221259
error = ErrorResponse(
222260
type="UnhandledMessage",
223261
description="Unknown message type",
@@ -231,6 +269,9 @@ async def _listening(self) -> None:
231269
return
232270

233271
except websockets.exceptions.WebSocketException as e:
272+
self.logger.error(
273+
"WebSocketException in AsyncLiveClient._listening: %s", e
274+
)
234275
error: ErrorResponse = {
235276
"type": "Exception",
236277
"description": "WebSocketException in AsyncLiveClient._listening",
@@ -242,15 +283,17 @@ async def _listening(self) -> None:
242283
)
243284
await self._emit(LiveTranscriptionEvents.Error, error)
244285

286+
# signal exit and close
287+
await self._signal_exit()
288+
245289
self.logger.debug("AsyncLiveClient._listening LEAVE")
246290

247-
if (
248-
"termination_exception" in self.options
249-
and self.options["termination_exception"] == "true"
250-
):
291+
if self.config.options.get("termination_exception") == "true":
251292
raise
293+
return
252294

253295
except Exception as e:
296+
self.logger.error("Exception in AsyncLiveClient._listening: %s", e)
254297
error: ErrorResponse = {
255298
"type": "Exception",
256299
"description": "Exception in AsyncLiveClient._listening",
@@ -260,13 +303,14 @@ async def _listening(self) -> None:
260303
self.logger.error("Exception in AsyncLiveClient._listening: %s", str(e))
261304
await self._emit(LiveTranscriptionEvents.Error, error)
262305

306+
# signal exit and close
307+
await self._signal_exit()
308+
263309
self.logger.debug("AsyncLiveClient._listening LEAVE")
264310

265-
if (
266-
"termination_exception" in self.options
267-
and self.options["termination_exception"] == "true"
268-
):
311+
if self.config.options.get("termination_exception") == "true":
269312
raise
313+
return
270314

271315
# keep the connection alive by sending keepalive messages
272316
async def _keep_alive(self) -> None:
@@ -278,21 +322,18 @@ async def _keep_alive(self) -> None:
278322
counter += 1
279323
await asyncio.sleep(ONE_SECOND)
280324

281-
if self.exit_event.is_set():
325+
if self._exit_event.is_set():
282326
self.logger.notice("_keep_alive exiting gracefully")
283327
self.logger.debug("AsyncLiveClient._keep_alive LEAVE")
284328
return
285329

286330
if self._socket is None:
287331
self.logger.notice("socket is None, exiting keep_alive")
288332
self.logger.debug("AsyncLiveClient._keep_alive LEAVE")
289-
break
333+
return
290334

291335
# deepgram keepalive
292-
if (
293-
counter % DEEPGRAM_INTERVAL == 0
294-
and self.config.options.get("keepalive") == "true"
295-
):
336+
if counter % DEEPGRAM_INTERVAL == 0:
296337
self.logger.verbose("Sending KeepAlive...")
297338
await self.send(json.dumps({"type": "KeepAlive"}))
298339

@@ -302,6 +343,9 @@ async def _keep_alive(self) -> None:
302343
return
303344

304345
except websockets.exceptions.WebSocketException as e:
346+
self.logger.error(
347+
"WebSocketException in AsyncLiveClient._keep_alive: %s", e
348+
)
305349
error: ErrorResponse = {
306350
"type": "Exception",
307351
"description": "WebSocketException in AsyncLiveClient._keep_alive",
@@ -313,16 +357,17 @@ async def _keep_alive(self) -> None:
313357
)
314358
await self._emit(LiveTranscriptionEvents.Error, error)
315359

360+
# signal exit and close
361+
await self._signal_exit()
362+
316363
self.logger.debug("AsyncLiveClient._keep_alive LEAVE")
317364

318-
if (
319-
"termination_exception" in self.options
320-
and self.options["termination_exception"] == "true"
321-
):
365+
if self.config.options.get("termination_exception") == "true":
322366
raise
323367
return
324368

325369
except Exception as e:
370+
self.logger.error("Exception in AsyncLiveClient._keep_alive: %s", e)
326371
error: ErrorResponse = {
327372
"type": "Exception",
328373
"description": "Exception in _keep_alive",
@@ -334,54 +379,99 @@ async def _keep_alive(self) -> None:
334379
)
335380
await self._emit(LiveTranscriptionEvents.Error, error)
336381

382+
# signal exit and close
383+
await self._signal_exit()
384+
337385
self.logger.debug("AsyncLiveClient._keep_alive LEAVE")
338386

339-
if (
340-
"termination_exception" in self.options
341-
and self.options["termination_exception"] == "true"
342-
):
387+
if self.config.options.get("termination_exception") == "true":
343388
raise
344389
return
345390

346-
self.logger.debug("AsyncLiveClient._keep_alive LEAVE")
347-
348391
# sends data over the WebSocket connection
349392
async def send(self, data: Union[str, bytes]) -> bool:
350393
"""
351394
Sends data over the WebSocket connection.
352395
"""
353396
self.logger.spam("AsyncLiveClient.send ENTER")
354397

398+
if self._exit_event.is_set():
399+
self.logger.notice("send exiting gracefully")
400+
self.logger.debug("AsyncLiveClient.send LEAVE")
401+
return False
402+
355403
if self._socket is not None:
356404
try:
357405
await self._socket.send(data)
406+
except websockets.exceptions.ConnectionClosedOK as e:
407+
self.logger.notice(f"send() exiting gracefully: {e.code}")
408+
self.logger.debug("AsyncLiveClient._keep_alive LEAVE")
409+
if self.config.options.get("termination_exception_send") == "true":
410+
raise
411+
return True
358412
except websockets.exceptions.WebSocketException as e:
359413
self.logger.error("send() failed - WebSocketException: %s", str(e))
360414
self.logger.spam("AsyncLiveClient.send LEAVE")
415+
if self.config.options.get("termination_exception_send") == "true":
416+
raise
361417
return False
362418
except Exception as e:
363419
self.logger.error("send() failed - Exception: %s", str(e))
364420
self.logger.spam("AsyncLiveClient.send LEAVE")
421+
if self.config.options.get("termination_exception_send") == "true":
422+
raise
365423
return False
366424

367425
self.logger.spam(f"send() succeeded")
368426
self.logger.spam("AsyncLiveClient.send LEAVE")
369427
return True
370428

371-
self.logger.error("send() failed. socket is None")
429+
self.logger.spam("send() failed. socket is None")
372430
self.logger.spam("AsyncLiveClient.send LEAVE")
373431
return False
374432

433+
# closes the WebSocket connection gracefully
375434
async def finish(self) -> bool:
376435
"""
377436
Closes the WebSocket connection gracefully.
378437
"""
379438
self.logger.debug("AsyncLiveClient.finish ENTER")
380439

381440
# signal exit
382-
self.exit_event.set()
441+
await self._signal_exit()
383442

384-
# close the stream
443+
# stop the threads
444+
self.logger.verbose("cancelling tasks...")
445+
try:
446+
# Before cancelling, check if the tasks were created
447+
tasks = []
448+
if self._keep_alive_thread is not None:
449+
self._keep_alive_thread.cancel()
450+
tasks.append(self._keep_alive_thread)
451+
if self._listen_thread is not None:
452+
self._listen_thread.cancel()
453+
tasks.append(self._listen_thread)
454+
455+
# Use asyncio.gather to wait for tasks to be cancelled
456+
await asyncio.gather(*filter(None, tasks), return_exceptions=True)
457+
self.logger.notice("threads joined")
458+
self._listen_thread = None
459+
self._keep_alive_thread = None
460+
461+
self._socket = None
462+
463+
self.logger.notice("finish succeeded")
464+
self.logger.spam("AsyncLiveClient.finish LEAVE")
465+
return True
466+
467+
except asyncio.CancelledError as e:
468+
self.logger.error("tasks cancelled error: %s", e)
469+
self.logger.debug("AsyncLiveClient.finish LEAVE")
470+
return False
471+
472+
# signals the WebSocket connection to exit
473+
async def _signal_exit(self) -> None:
474+
# send close event
385475
self.logger.verbose("closing socket...")
386476
if self._socket is not None:
387477
self.logger.verbose("send CloseStream...")
@@ -395,36 +485,15 @@ async def finish(self) -> bool:
395485
CloseResponse(type=LiveTranscriptionEvents.Close.value),
396486
)
397487

488+
# signal exit
489+
self._exit_event.set()
490+
491+
# closes the WebSocket connection gracefully
492+
self.logger.verbose("clean up socket...")
493+
if self._socket is not None:
398494
self.logger.verbose("socket.wait_closed...")
399495
try:
400-
await self._socket.wait_closed()
496+
await self._socket.close()
497+
self._socket = None
401498
except websockets.exceptions.WebSocketException as e:
402499
self.logger.error("socket.wait_closed failed: %s", e)
403-
self._socket = None
404-
405-
self.logger.verbose("cancelling tasks...")
406-
try:
407-
# Before cancelling, check if the tasks were created
408-
if self._listen_thread is not None:
409-
self._listen_thread.cancel()
410-
if self._keep_alive_thread is not None:
411-
self._keep_alive_thread.cancel()
412-
413-
# Use asyncio.gather to wait for tasks to be cancelled
414-
tasks = [self._listen_thread, self._keep_alive_thread]
415-
await asyncio.gather(*filter(None, tasks), return_exceptions=True)
416-
417-
except asyncio.CancelledError as e:
418-
self.logger.error("tasks cancelled error: %s", e)
419-
420-
self.logger.info("finish succeeded")
421-
self.logger.debug("AsyncLiveClient.finish LEAVE")
422-
return True
423-
424-
425-
async def _socket_connect(websocket_url, headers) -> websockets.WebSocketClientProtocol:
426-
destination = websocket_url
427-
updated_headers = headers
428-
return await websockets.connect(
429-
destination, extra_headers=updated_headers, ping_interval=PING_INTERVAL
430-
)

0 commit comments

Comments
 (0)