55import logging
66import os
77import struct
8+ from typing import Awaitable , TypeVar
89
910import asyncssh
1011import semver
1516from serial import Serial
1617from tqdm .auto import tqdm
1718from tqdm .contrib .logging import logging_redirect_tqdm
18- from rx .subject import Subject , BehaviorSubject
19+ from rx .subject import Subject , BehaviorSubject , AsyncSubject
1920
2021from .ble .lwp3 .bytecodes import HubKind
2122from .ble .nus import NUS_RX_UUID , NUS_TX_UUID
3435
3536logger = logging .getLogger (__name__ )
3637
38+ T = TypeVar ("T" )
39+
3740
3841class EV3Connection :
3942 """ev3dev SSH connection for running pybricks-micropython scripts.
@@ -153,6 +156,7 @@ class PybricksHub:
153156 EOL = b"\r \n " # MicroPython EOL
154157
155158 def __init__ (self ):
159+ self .disconnect_observable = AsyncSubject ()
156160 self .status_observable = BehaviorSubject (StatusFlag (0 ))
157161 self .nus_observable = Subject ()
158162 self .stream_buf = bytearray ()
@@ -258,14 +262,18 @@ async def connect(self, device: BLEDevice):
258262 Information Service)
259263 RuntimeError: if Pybricks Protocol version is not supported
260264 """
261- logger .info (f"Connecting to { device .address } " )
262- self .client = BleakClient (device )
265+ logger .info (f"Connecting to { device .name } " )
263266
264- def disconnected_handler ( self , _ : BleakClient ):
267+ def handle_disconnect ( client : BleakClient ):
265268 logger .info ("Disconnected!" )
269+ self .disconnect_observable .on_next (True )
270+ self .disconnect_observable .on_completed ()
266271 self .connected = False
267272
268- await self .client .connect (disconnected_callback = disconnected_handler )
273+ self .client = BleakClient (device , disconnected_callback = handle_disconnect )
274+
275+ await self .client .connect ()
276+
269277 try :
270278 logger .info ("Connected successfully!" )
271279 protocol_version = await self .client .read_gatt_char (SW_REV_UUID )
@@ -298,6 +306,45 @@ async def disconnect(self):
298306 else :
299307 logger .debug ("already disconnected" )
300308
309+ async def race_disconnect (self , awaitable : Awaitable [T ]) -> T :
310+ """
311+ Races an awaitable against a disconnect event.
312+
313+ If a disconnect event occurs before the awaitable is complete, a
314+ ``RuntimeError`` is raised and the awaitable is canceled.
315+
316+ Otherwise, the result of the awaitable is returned. If the awaitable
317+ raises an exception, that exception will be raised.
318+
319+ Args:
320+ awaitable: Any awaitable such as a coroutine.
321+
322+ Returns:
323+ The result of the awaitable.
324+
325+ Raises:
326+ RuntimeError:
327+ Thrown if the hub is disconnected before the awaitable completed.
328+ """
329+ awaitable_task = asyncio .ensure_future (awaitable )
330+
331+ disconnect_event = asyncio .Event ()
332+ disconnect_task = asyncio .ensure_future (disconnect_event .wait ())
333+
334+ with self .disconnect_observable .subscribe (lambda _ : disconnect_event .set ()):
335+ done , pending = await asyncio .wait (
336+ {awaitable_task , disconnect_task },
337+ return_when = asyncio .FIRST_COMPLETED ,
338+ )
339+
340+ for t in pending :
341+ t .cancel ()
342+
343+ if awaitable_task not in done :
344+ raise RuntimeError ("disconnected during operation" )
345+
346+ return awaitable_task .result ()
347+
301348 async def write (self , data , with_response = False ):
302349 await self .client .write_gatt_char (NUS_RX_UUID , bytearray (data ), with_response )
303350
@@ -315,7 +362,7 @@ async def run(self, py_path, wait=True, print_output=True):
315362 try :
316363 self .loading = True
317364
318- queue = asyncio .Queue ()
365+ queue : asyncio . Queue [ bytes ] = asyncio .Queue ()
319366 subscription = self .nus_observable .subscribe (
320367 lambda data : queue .put_nowait (data )
321368 )
@@ -338,7 +385,9 @@ async def send_block(data: bytes) -> None:
338385 else :
339386 await self .client .write_gatt_char (NUS_RX_UUID , data , False )
340387
341- msg : bytes = await asyncio .wait_for (queue .get (), timeout = 0.5 )
388+ msg = await asyncio .wait_for (
389+ self .race_disconnect (queue .get ()), timeout = 0.5
390+ )
342391 actual_checksum = msg [0 ]
343392 expected_checksum = xor_bytes (data , 0 )
344393
@@ -363,23 +412,25 @@ async def send_block(data: bytes) -> None:
363412 self .loading = False
364413
365414 if wait :
366- user_program_running = asyncio .Queue ()
415+ user_program_running : asyncio . Queue [ bool ] = asyncio .Queue ()
367416
368417 with self .status_observable .pipe (
369- op .map (lambda s : s & StatusFlag .USER_PROGRAM_RUNNING ),
418+ op .map (lambda s : bool ( s & StatusFlag .USER_PROGRAM_RUNNING ) ),
370419 op .distinct_until_changed (),
371420 ).subscribe (lambda s : user_program_running .put_nowait (s )):
372421
373422 # The first item in the queue is the current status. The status
374423 # could change before or after the last checksum is received,
375- # so this could be truthy or falsy .
376- is_running = await user_program_running .get ()
424+ # so this could be either true or false .
425+ is_running = await self . race_disconnect ( user_program_running .get () )
377426
378427 if not is_running :
379428 # if the program has not already started, wait a short time
380429 # for it to start
381430 try :
382- await asyncio .wait_for (user_program_running .get (), 0.2 )
431+ await asyncio .wait_for (
432+ self .race_disconnect (user_program_running .get ()), 0.2
433+ )
383434 except asyncio .TimeoutError :
384435 # if it doesn't start, assume it was a very short lived
385436 # program and we just missed the status message
@@ -388,7 +439,7 @@ async def send_block(data: bytes) -> None:
388439 # At this point, we know the user program is running, so the
389440 # next item in the queue must indicate that the user program
390441 # has stopped.
391- is_running = await user_program_running .get ()
442+ is_running = await self . race_disconnect ( user_program_running .get () )
392443
393444 # maybe catch mistake if the code is changed
394445 assert not is_running
0 commit comments