Skip to content

Commit 3d4a0d7

Browse files
committed
connections: add race_disconnect() method
This adds a new method for racing other async calls against a disconnect event. This is especially important in the case where we are using async queues. Without this, we would get a deadlock if the device disconnected when we were waiting on a queue because nothing would ever be pushed to the queue after the disconnect.
1 parent d408f33 commit 3d4a0d7

File tree

2 files changed

+68
-14
lines changed

2 files changed

+68
-14
lines changed

CHANGELOG.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
66

77
## [Unreleased]
88

9-
## Fixed
9+
### Added
10+
- Added ``PybricksHub.race_disconnect()`` method.
11+
12+
### Fixed
1013
- Fixed race condition with `pybricksdev run ble` not waiting for program to
1114
finish before disconnecting ([pybricksdev#28]).
1215

pybricksdev/connections.py

Lines changed: 64 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import logging
66
import os
77
import struct
8+
from typing import Awaitable, TypeVar
89

910
import asyncssh
1011
import semver
@@ -15,7 +16,7 @@
1516
from serial import Serial
1617
from tqdm.auto import tqdm
1718
from tqdm.contrib.logging import logging_redirect_tqdm
18-
from rx.subject import Subject, BehaviorSubject
19+
from rx.subject import Subject, BehaviorSubject, AsyncSubject
1920

2021
from .ble.lwp3.bytecodes import HubKind
2122
from .ble.nus import NUS_RX_UUID, NUS_TX_UUID
@@ -34,6 +35,8 @@
3435

3536
logger = logging.getLogger(__name__)
3637

38+
T = TypeVar("T")
39+
3740

3841
class EV3Connection:
3942
"""ev3dev SSH connection for running pybricks-micropython scripts.
@@ -153,6 +156,7 @@ class PybricksHub:
153156
EOL = b"\r\n" # MicroPython EOL
154157

155158
def __init__(self):
159+
self.disconnect_observable = AsyncSubject()
156160
self.status_observable = BehaviorSubject(StatusFlag(0))
157161
self.nus_observable = Subject()
158162
self.stream_buf = bytearray()
@@ -258,14 +262,18 @@ async def connect(self, device: BLEDevice):
258262
Information Service)
259263
RuntimeError: if Pybricks Protocol version is not supported
260264
"""
261-
logger.info(f"Connecting to {device.address}")
262-
self.client = BleakClient(device)
265+
logger.info(f"Connecting to {device.name}")
263266

264-
def disconnected_handler(self, _: BleakClient):
267+
def handle_disconnect(client: BleakClient):
265268
logger.info("Disconnected!")
269+
self.disconnect_observable.on_next(True)
270+
self.disconnect_observable.on_completed()
266271
self.connected = False
267272

268-
await self.client.connect(disconnected_callback=disconnected_handler)
273+
self.client = BleakClient(device, disconnected_callback=handle_disconnect)
274+
275+
await self.client.connect()
276+
269277
try:
270278
logger.info("Connected successfully!")
271279
protocol_version = await self.client.read_gatt_char(SW_REV_UUID)
@@ -298,6 +306,45 @@ async def disconnect(self):
298306
else:
299307
logger.debug("already disconnected")
300308

309+
async def race_disconnect(self, awaitable: Awaitable[T]) -> T:
310+
"""
311+
Races an awaitable against a disconnect event.
312+
313+
If a disconnect event occurs before the awaitable is complete, a
314+
``RuntimeError`` is raised and the awaitable is canceled.
315+
316+
Otherwise, the result of the awaitable is returned. If the awaitable
317+
raises an exception, that exception will be raised.
318+
319+
Args:
320+
awaitable: Any awaitable such as a coroutine.
321+
322+
Returns:
323+
The result of the awaitable.
324+
325+
Raises:
326+
RuntimeError:
327+
Thrown if the hub is disconnected before the awaitable completed.
328+
"""
329+
awaitable_task = asyncio.ensure_future(awaitable)
330+
331+
disconnect_event = asyncio.Event()
332+
disconnect_task = asyncio.ensure_future(disconnect_event.wait())
333+
334+
with self.disconnect_observable.subscribe(lambda _: disconnect_event.set()):
335+
done, pending = await asyncio.wait(
336+
{awaitable_task, disconnect_task},
337+
return_when=asyncio.FIRST_COMPLETED,
338+
)
339+
340+
for t in pending:
341+
t.cancel()
342+
343+
if awaitable_task not in done:
344+
raise RuntimeError("disconnected during operation")
345+
346+
return awaitable_task.result()
347+
301348
async def write(self, data, with_response=False):
302349
await self.client.write_gatt_char(NUS_RX_UUID, bytearray(data), with_response)
303350

@@ -315,7 +362,7 @@ async def run(self, py_path, wait=True, print_output=True):
315362
try:
316363
self.loading = True
317364

318-
queue = asyncio.Queue()
365+
queue: asyncio.Queue[bytes] = asyncio.Queue()
319366
subscription = self.nus_observable.subscribe(
320367
lambda data: queue.put_nowait(data)
321368
)
@@ -338,7 +385,9 @@ async def send_block(data: bytes) -> None:
338385
else:
339386
await self.client.write_gatt_char(NUS_RX_UUID, data, False)
340387

341-
msg: bytes = await asyncio.wait_for(queue.get(), timeout=0.5)
388+
msg = await asyncio.wait_for(
389+
self.race_disconnect(queue.get()), timeout=0.5
390+
)
342391
actual_checksum = msg[0]
343392
expected_checksum = xor_bytes(data, 0)
344393

@@ -363,23 +412,25 @@ async def send_block(data: bytes) -> None:
363412
self.loading = False
364413

365414
if wait:
366-
user_program_running = asyncio.Queue()
415+
user_program_running: asyncio.Queue[bool] = asyncio.Queue()
367416

368417
with self.status_observable.pipe(
369-
op.map(lambda s: s & StatusFlag.USER_PROGRAM_RUNNING),
418+
op.map(lambda s: bool(s & StatusFlag.USER_PROGRAM_RUNNING)),
370419
op.distinct_until_changed(),
371420
).subscribe(lambda s: user_program_running.put_nowait(s)):
372421

373422
# The first item in the queue is the current status. The status
374423
# could change before or after the last checksum is received,
375-
# so this could be truthy or falsy.
376-
is_running = await user_program_running.get()
424+
# so this could be either true or false.
425+
is_running = await self.race_disconnect(user_program_running.get())
377426

378427
if not is_running:
379428
# if the program has not already started, wait a short time
380429
# for it to start
381430
try:
382-
await asyncio.wait_for(user_program_running.get(), 0.2)
431+
await asyncio.wait_for(
432+
self.race_disconnect(user_program_running.get()), 0.2
433+
)
383434
except asyncio.TimeoutError:
384435
# if it doesn't start, assume it was a very short lived
385436
# program and we just missed the status message
@@ -388,7 +439,7 @@ async def send_block(data: bytes) -> None:
388439
# At this point, we know the user program is running, so the
389440
# next item in the queue must indicate that the user program
390441
# has stopped.
391-
is_running = await user_program_running.get()
442+
is_running = await self.race_disconnect(user_program_running.get())
392443

393444
# maybe catch mistake if the code is changed
394445
assert not is_running

0 commit comments

Comments
 (0)