11# SPDX-License-Identifier: MIT
2- # Copyright (c) 2021-2022 The Pybricks Authors
2+ # Copyright (c) 2021-2023 The Pybricks Authors
33
44import asyncio
5+ import contextlib
56import logging
67import os
78import struct
1213from bleak import BleakClient
1314from bleak .backends .device import BLEDevice
1415from packaging .version import Version
15- from rx .subject import AsyncSubject , BehaviorSubject , Subject
16+ from rx .subject import BehaviorSubject , Subject
1617from tqdm .auto import tqdm
1718from tqdm .contrib .logging import logging_redirect_tqdm
1819
3536from ..compile import compile_file , compile_multi_file
3637from ..tools import chunk
3738from ..tools .checksum import xor_bytes
39+ from . import ConnectionState
3840
3941logger = logging .getLogger (__name__ )
4042
@@ -76,7 +78,7 @@ class PybricksHub:
7678 """
7779
7880 def __init__ (self ):
79- self .disconnect_observable = AsyncSubject ( )
81+ self .connection_state_observable = BehaviorSubject ( ConnectionState . DISCONNECTED )
8082 self .status_observable = BehaviorSubject (StatusFlag (0 ))
8183 self .nus_observable = Subject ()
8284 self .stream_buf = bytearray ()
@@ -87,9 +89,6 @@ def __init__(self):
8789 self ._capability_flags = HubCapabilityFlag (0 )
8890 self ._max_user_program_size = 0
8991
90- # indicates that the hub is currently connected via BLE
91- self .connected = False
92-
9392 # indicates is we are currently downloading a program over NUS (legacy download)
9493 self ._downloading_via_nus = False
9594
@@ -188,17 +187,28 @@ async def connect(self, device: BLEDevice):
188187 """
189188 logger .info (f"Connecting to { device .name } " )
190189
191- def handle_disconnect (client : BleakClient ):
192- logger .info ("Disconnected!" )
193- self .disconnect_observable .on_next (True )
194- self .disconnect_observable .on_completed ()
195- self .connected = False
190+ if self .connection_state_observable .value != ConnectionState .DISCONNECTED :
191+ raise RuntimeError (
192+ f"attempting to connect with invalid state: { self .connection_state_observable .value } "
193+ )
196194
197- self .client = BleakClient (device , disconnected_callback = handle_disconnect )
195+ async with contextlib .AsyncExitStack () as stack :
196+ self .connection_state_observable .on_next (ConnectionState .CONNECTING )
198197
199- await self .client .connect ()
198+ stack .callback (
199+ self .connection_state_observable .on_next , ConnectionState .DISCONNECTED
200+ )
201+
202+ def handle_disconnect (_ : BleakClient ):
203+ logger .info ("Disconnected!" )
204+ self .connection_state_observable .on_next (ConnectionState .DISCONNECTED )
205+
206+ self .client = BleakClient (device , disconnected_callback = handle_disconnect )
207+
208+ await self .client .connect ()
209+
210+ stack .push_async_callback (self .disconnect )
200211
201- try :
202212 logger .info ("Connected successfully!" )
203213
204214 fw_version = await self .client .read_gatt_char (FW_REV_UUID )
@@ -236,17 +246,24 @@ def handle_disconnect(client: BleakClient):
236246 await self .client .start_notify (
237247 PYBRICKS_COMMAND_EVENT_UUID , self .pybricks_service_handler
238248 )
239- self .connected = True
240- except : # noqa: E722
241- self .disconnect ()
242- raise
249+
250+ self .connection_state_observable .on_next (ConnectionState .CONNECTED )
251+
252+ # don't unwind on success
253+ stack .pop_all ()
243254
244255 async def disconnect (self ):
245- if self .connected :
246- logger .info ("Disconnecting..." )
256+ logger .info ("Disconnecting..." )
257+
258+ if self .connection_state_observable .value == ConnectionState .CONNECTED :
259+ self .connection_state_observable .on_next (ConnectionState .DISCONNECTING )
247260 await self .client .disconnect ()
261+ # ConnectionState.DISCONNECTED should be set by disconnect callback
262+ assert (
263+ self .connection_state_observable .value == ConnectionState .DISCONNECTED
264+ )
248265 else :
249- logger .debug ("already disconnected " )
266+ logger .debug ("skipping disconnect because not connected " )
250267
251268 async def race_disconnect (self , awaitable : Awaitable [T ]) -> T :
252269 """
@@ -273,7 +290,11 @@ async def race_disconnect(self, awaitable: Awaitable[T]) -> T:
273290 disconnect_event = asyncio .Event ()
274291 disconnect_task = asyncio .ensure_future (disconnect_event .wait ())
275292
276- with self .disconnect_observable .subscribe (lambda _ : disconnect_event .set ()):
293+ def handle_disconnect (state : ConnectionState ):
294+ if state == ConnectionState .DISCONNECTED :
295+ disconnect_event .set ()
296+
297+ with self .connection_state_observable .subscribe (handle_disconnect ):
277298 done , pending = await asyncio .wait (
278299 {awaitable_task , disconnect_task },
279300 return_when = asyncio .FIRST_COMPLETED ,
@@ -301,7 +322,7 @@ async def run(
301322 wait: If true, wait for the user program to stop before returning.
302323 print_output: If true, echo stdout of the hub to ``sys.stdout``.
303324 """
304- if not self .connected :
325+ if self .connection_state_observable . value != ConnectionState . CONNECTED :
305326 raise RuntimeError ("not connected" )
306327
307328 # Reset output buffer
0 commit comments