Skip to content

Commit 465c70f

Browse files
committed
pybricksdev.connections.pybricks.PybricksHub: fix connection state management
PybricksHub.disconnect_observable was one-use only since it used an AsyncSubject. This caused the PybricksHub.race_disconnect method to think that the hub was still disconnected after a reconnection. To fix this, we replace PybricksHub.disconnect_observable and PybricksHub.connected with a new PybricksHub.connection_state_observable that uses a BehaviorSubject so that it can be updated across multiple connect/disconnect events. This also introduces a new pybricksdev.connections.ConnectionState enum that give a more nuanced state indication rather than using a bool. Fixes: pybricks/support#971
1 parent 898155c commit 465c70f

File tree

3 files changed

+81
-24
lines changed

3 files changed

+81
-24
lines changed

CHANGELOG.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
66

77
## [Unreleased]
88

9+
### Added
10+
- Added `pybricksdev.connections.ConnectionState` enum class.
11+
- Added `pybricksdev.connections.pybricks.PybricksHub.connection_state_observable` attribute.
12+
13+
### Fixed
14+
- Fixed `pybricksdev.connections.pybricks.PybricksHub` disconnect state not reset after reconnect ([support#971]).
15+
16+
### Removed
17+
- Removed `pybricksdev.connections.pybricks.PybricksHub.disconnect_observable` attribute.
18+
- Removed `pybricksdev.connections.pybricks.PybricksHub.connected` attribute.
19+
920
## [1.0.0-alpha.37] - 2023-02-27
1021

1122
### Added
Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,27 @@
11
# SPDX-License-Identifier: MIT
2-
# Copyright (c) 2022 The Pybricks Authors
2+
# Copyright (c) 2023 The Pybricks Authors
3+
4+
import enum
5+
6+
7+
class ConnectionState(enum.Enum):
8+
"""
9+
Indicates state of a connection.
10+
"""
11+
12+
CONNECTING = enum.auto()
13+
"""
14+
The device is in the process of connecting.
15+
"""
16+
CONNECTED = enum.auto()
17+
"""
18+
The device is connected.
19+
"""
20+
DISCONNECTING = enum.auto()
21+
"""
22+
The device is in the process of disconnecting.
23+
"""
24+
DISCONNECTED = enum.auto()
25+
"""
26+
The device is disconnected.
27+
"""

pybricksdev/connections/pybricks.py

Lines changed: 44 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
# SPDX-License-Identifier: MIT
2-
# Copyright (c) 2021-2022 The Pybricks Authors
2+
# Copyright (c) 2021-2023 The Pybricks Authors
33

44
import asyncio
5+
import contextlib
56
import logging
67
import os
78
import struct
@@ -12,7 +13,7 @@
1213
from bleak import BleakClient
1314
from bleak.backends.device import BLEDevice
1415
from packaging.version import Version
15-
from rx.subject import AsyncSubject, BehaviorSubject, Subject
16+
from rx.subject import BehaviorSubject, Subject
1617
from tqdm.auto import tqdm
1718
from tqdm.contrib.logging import logging_redirect_tqdm
1819

@@ -35,6 +36,7 @@
3536
from ..compile import compile_file, compile_multi_file
3637
from ..tools import chunk
3738
from ..tools.checksum import xor_bytes
39+
from . import ConnectionState
3840

3941
logger = logging.getLogger(__name__)
4042

@@ -76,7 +78,7 @@ class PybricksHub:
7678
"""
7779

7880
def __init__(self):
79-
self.disconnect_observable = AsyncSubject()
81+
self.connection_state_observable = BehaviorSubject(ConnectionState.DISCONNECTED)
8082
self.status_observable = BehaviorSubject(StatusFlag(0))
8183
self.nus_observable = Subject()
8284
self.stream_buf = bytearray()
@@ -87,9 +89,6 @@ def __init__(self):
8789
self._capability_flags = HubCapabilityFlag(0)
8890
self._max_user_program_size = 0
8991

90-
# indicates that the hub is currently connected via BLE
91-
self.connected = False
92-
9392
# indicates is we are currently downloading a program over NUS (legacy download)
9493
self._downloading_via_nus = False
9594

@@ -188,17 +187,28 @@ async def connect(self, device: BLEDevice):
188187
"""
189188
logger.info(f"Connecting to {device.name}")
190189

191-
def handle_disconnect(client: BleakClient):
192-
logger.info("Disconnected!")
193-
self.disconnect_observable.on_next(True)
194-
self.disconnect_observable.on_completed()
195-
self.connected = False
190+
if self.connection_state_observable.value != ConnectionState.DISCONNECTED:
191+
raise RuntimeError(
192+
f"attempting to connect with invalid state: {self.connection_state_observable.value}"
193+
)
196194

197-
self.client = BleakClient(device, disconnected_callback=handle_disconnect)
195+
async with contextlib.AsyncExitStack() as stack:
196+
self.connection_state_observable.on_next(ConnectionState.CONNECTING)
198197

199-
await self.client.connect()
198+
stack.callback(
199+
self.connection_state_observable.on_next, ConnectionState.DISCONNECTED
200+
)
201+
202+
def handle_disconnect(_: BleakClient):
203+
logger.info("Disconnected!")
204+
self.connection_state_observable.on_next(ConnectionState.DISCONNECTED)
205+
206+
self.client = BleakClient(device, disconnected_callback=handle_disconnect)
207+
208+
await self.client.connect()
209+
210+
stack.push_async_callback(self.disconnect)
200211

201-
try:
202212
logger.info("Connected successfully!")
203213

204214
fw_version = await self.client.read_gatt_char(FW_REV_UUID)
@@ -236,17 +246,24 @@ def handle_disconnect(client: BleakClient):
236246
await self.client.start_notify(
237247
PYBRICKS_COMMAND_EVENT_UUID, self.pybricks_service_handler
238248
)
239-
self.connected = True
240-
except: # noqa: E722
241-
self.disconnect()
242-
raise
249+
250+
self.connection_state_observable.on_next(ConnectionState.CONNECTED)
251+
252+
# don't unwind on success
253+
stack.pop_all()
243254

244255
async def disconnect(self):
245-
if self.connected:
246-
logger.info("Disconnecting...")
256+
logger.info("Disconnecting...")
257+
258+
if self.connection_state_observable.value == ConnectionState.CONNECTED:
259+
self.connection_state_observable.on_next(ConnectionState.DISCONNECTING)
247260
await self.client.disconnect()
261+
# ConnectionState.DISCONNECTED should be set by disconnect callback
262+
assert (
263+
self.connection_state_observable.value == ConnectionState.DISCONNECTED
264+
)
248265
else:
249-
logger.debug("already disconnected")
266+
logger.debug("skipping disconnect because not connected")
250267

251268
async def race_disconnect(self, awaitable: Awaitable[T]) -> T:
252269
"""
@@ -273,7 +290,11 @@ async def race_disconnect(self, awaitable: Awaitable[T]) -> T:
273290
disconnect_event = asyncio.Event()
274291
disconnect_task = asyncio.ensure_future(disconnect_event.wait())
275292

276-
with self.disconnect_observable.subscribe(lambda _: disconnect_event.set()):
293+
def handle_disconnect(state: ConnectionState):
294+
if state == ConnectionState.DISCONNECTED:
295+
disconnect_event.set()
296+
297+
with self.connection_state_observable.subscribe(handle_disconnect):
277298
done, pending = await asyncio.wait(
278299
{awaitable_task, disconnect_task},
279300
return_when=asyncio.FIRST_COMPLETED,
@@ -301,7 +322,7 @@ async def run(
301322
wait: If true, wait for the user program to stop before returning.
302323
print_output: If true, echo stdout of the hub to ``sys.stdout``.
303324
"""
304-
if not self.connected:
325+
if self.connection_state_observable.value != ConnectionState.CONNECTED:
305326
raise RuntimeError("not connected")
306327

307328
# Reset output buffer

0 commit comments

Comments
 (0)