Skip to content

Commit a4db08e

Browse files
committed
Move BLE logic into new class
Moves all logic specific to BLE connections into a new subclass of PybricksHub. Another subclass will be added later to handle USB connections. Code to retrieve firmware version, hub capabilities, etc. is moved into the connect step to better abstract this for any connection medium.
1 parent 2bb33ad commit a4db08e

File tree

2 files changed

+88
-65
lines changed

2 files changed

+88
-65
lines changed

pybricksdev/cli/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ async def run(self, args: argparse.Namespace):
174174
from ..ble import find_device
175175
from ..connections.ev3dev import EV3Connection
176176
from ..connections.lego import REPLHub
177-
from ..connections.pybricks import PybricksHub
177+
from ..connections.pybricks import PybricksHubBLE
178178

179179
# Pick the right connection
180180
if args.conntype == "ssh":
@@ -189,7 +189,7 @@ async def run(self, args: argparse.Namespace):
189189
# It is a Pybricks Hub with BLE. Device name or address is given.
190190
print(f"Searching for {args.name or 'any hub with Pybricks service'}...")
191191
device_or_address = await find_device(args.name)
192-
hub = PybricksHub(device_or_address)
192+
hub = PybricksHubBLE(device_or_address)
193193

194194
elif args.conntype == "usb":
195195
hub = REPLHub()

pybricksdev/connections/pybricks.py

Lines changed: 86 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from reactivex.subject import BehaviorSubject, Subject
1818
from tqdm.auto import tqdm
1919
from tqdm.contrib.logging import logging_redirect_tqdm
20+
from typing import Callable
2021

2122
from ..ble.lwp3.bytecodes import HubKind
2223
from ..ble.nus import NUS_RX_UUID, NUS_TX_UUID
@@ -78,7 +79,7 @@ class PybricksHub:
7879
has not been connected yet or the connected hub has Pybricks profile < v1.2.0.
7980
"""
8081

81-
def __init__(self, device: BLEDevice):
82+
def __init__(self):
8283
self.connection_state_observable = BehaviorSubject(ConnectionState.DISCONNECTED)
8384
self.status_observable = BehaviorSubject(StatusFlag(0))
8485
self._stdout_subject = Subject()
@@ -120,11 +121,6 @@ def __init__(self, device: BLEDevice):
120121
# File handle for logging
121122
self.log_file = None
122123

123-
def handle_disconnect(_: BleakClient):
124-
self._handle_disconnect()
125-
126-
self.client = BleakClient(device, disconnected_callback=handle_disconnect)
127-
128124
@property
129125
def stdout_observable(self) -> Observable[bytes]:
130126
"""
@@ -237,16 +233,6 @@ def _handle_disconnect(self):
237233
self.connection_state_observable.on_next(ConnectionState.DISCONNECTED)
238234

239235
async def connect(self):
240-
"""Connects to a device that was discovered with :meth:`pybricksdev.ble.find_device`
241-
242-
Raises:
243-
BleakError: if connecting failed (or old firmware without Device
244-
Information Service)
245-
RuntimeError: if Pybricks Protocol version is not supported
246-
"""
247-
# TODO: Fix this
248-
# logger.info(f"Connecting to {device.name}")
249-
250236
if self.connection_state_observable.value != ConnectionState.DISCONNECTED:
251237
raise RuntimeError(
252238
f"attempting to connect with invalid state: {self.connection_state_observable.value}"
@@ -259,50 +245,12 @@ async def connect(self):
259245
self.connection_state_observable.on_next, ConnectionState.DISCONNECTED
260246
)
261247

262-
await self.client.connect()
248+
await self._client_connect()
263249

264250
stack.push_async_callback(self.disconnect)
265251

266-
logger.info("Connected successfully!")
267-
268-
fw_version = await self.client.read_gatt_char(FW_REV_UUID)
269-
self.fw_version = Version(fw_version.decode())
270-
271-
protocol_version = await self.client.read_gatt_char(SW_REV_UUID)
272-
protocol_version = semver.VersionInfo.parse(protocol_version.decode())
273-
274-
if (
275-
protocol_version < PYBRICKS_PROTOCOL_VERSION
276-
or protocol_version >= PYBRICKS_PROTOCOL_VERSION.bump_major()
277-
):
278-
raise RuntimeError(
279-
f"Unsupported Pybricks protocol version: {protocol_version}"
280-
)
281-
282-
pnp_id = await self.client.read_gatt_char(PNP_ID_UUID)
283-
_, _, self.hub_kind, self.hub_variant = unpack_pnp_id(pnp_id)
284-
285-
if protocol_version >= "1.2.0":
286-
caps = await self.client.read_gatt_char(PYBRICKS_HUB_CAPABILITIES_UUID)
287-
(
288-
self._max_write_size,
289-
self._capability_flags,
290-
self._max_user_program_size,
291-
) = unpack_hub_capabilities(caps)
292-
else:
293-
# HACK: prior to profile v1.2.0 isn't a proper way to get the
294-
# MPY ABI version from hub so we use heuristics on the firmware version
295-
self._mpy_abi_version = (
296-
6 if self.fw_version >= Version("3.2.0b2") else 5
297-
)
298-
299-
if protocol_version < "1.3.0":
300-
self._legacy_stdio = True
301-
302-
await self.client.start_notify(NUS_TX_UUID, self._nus_handler)
303-
await self.client.start_notify(
304-
PYBRICKS_COMMAND_EVENT_UUID, self._pybricks_service_handler
305-
)
252+
await self.start_notify(NUS_TX_UUID, self._nus_handler)
253+
await self.start_notify(PYBRICKS_COMMAND_EVENT_UUID, self._pybricks_service_handler)
306254

307255
self.connection_state_observable.on_next(ConnectionState.CONNECTED)
308256

@@ -314,7 +262,7 @@ async def disconnect(self):
314262

315263
if self.connection_state_observable.value == ConnectionState.CONNECTED:
316264
self.connection_state_observable.on_next(ConnectionState.DISCONNECTING)
317-
await self.client.disconnect()
265+
await self._client_disconnect()
318266
# ConnectionState.DISCONNECTED should be set by disconnect callback
319267
assert (
320268
self.connection_state_observable.value == ConnectionState.DISCONNECTED
@@ -453,7 +401,7 @@ async def download_user_program(self, program: bytes) -> None:
453401
)
454402

455403
# clear user program meta so hub doesn't try to run invalid program
456-
await self.client.write_gatt_char(
404+
await self.write_gatt_char(
457405
PYBRICKS_COMMAND_EVENT_UUID,
458406
struct.pack("<BI", Command.WRITE_USER_PROGRAM_META, 0),
459407
response=True,
@@ -467,7 +415,7 @@ async def download_user_program(self, program: bytes) -> None:
467415
total=len(program), unit="B", unit_scale=True
468416
) as pbar:
469417
for i, c in enumerate(chunk(program, payload_size)):
470-
await self.client.write_gatt_char(
418+
await self.write_gatt_char(
471419
PYBRICKS_COMMAND_EVENT_UUID,
472420
struct.pack(
473421
f"<BI{len(c)}s",
@@ -480,7 +428,7 @@ async def download_user_program(self, program: bytes) -> None:
480428
pbar.update(len(c))
481429

482430
# set the metadata to notify that writing was successful
483-
await self.client.write_gatt_char(
431+
await self.write_gatt_char(
484432
PYBRICKS_COMMAND_EVENT_UUID,
485433
struct.pack("<BI", Command.WRITE_USER_PROGRAM_META, len(program)),
486434
response=True,
@@ -492,7 +440,7 @@ async def start_user_program(self) -> None:
492440
493441
Requires hub with Pybricks Profile >= v1.2.0.
494442
"""
495-
await self.client.write_gatt_char(
443+
await self.write_gatt_char(
496444
PYBRICKS_COMMAND_EVENT_UUID,
497445
struct.pack("<B", Command.START_USER_PROGRAM),
498446
response=True,
@@ -502,7 +450,7 @@ async def stop_user_program(self) -> None:
502450
"""
503451
Stops the user program on the hub if it is running.
504452
"""
505-
await self.client.write_gatt_char(
453+
await self.write_gatt_char(
506454
PYBRICKS_COMMAND_EVENT_UUID,
507455
struct.pack("<B", Command.STOP_USER_PROGRAM),
508456
response=True,
@@ -680,3 +628,78 @@ async def _wait_for_user_program_stop(self):
680628
# the user program running status flag
681629
# https://github.com/pybricks/support/issues/305
682630
await asyncio.sleep(0.3)
631+
632+
class PybricksHubBLE(PybricksHub):
633+
_device: BLEDevice
634+
_client: BleakClient
635+
636+
def __init__(self, device: BLEDevice):
637+
super().__init__()
638+
639+
self._device = device
640+
641+
def handle_disconnect(_: BleakClient):
642+
self._handle_disconnect()
643+
644+
self._client = BleakClient(self._device, disconnected_callback=handle_disconnect)
645+
646+
async def _client_connect(self) -> bool:
647+
"""Connects to a device that was discovered with :meth:`pybricksdev.ble.find_device`
648+
649+
Raises:
650+
BleakError: if connecting failed (or old firmware without Device
651+
Information Service)
652+
RuntimeError: if Pybricks Protocol version is not supported
653+
"""
654+
655+
logger.info(f"Connecting to {self._device.name}")
656+
await self._client.connect()
657+
logger.info("Connected successfully!")
658+
659+
fw_version = await self.read_gatt_char(FW_REV_UUID)
660+
self.fw_version = Version(fw_version.decode())
661+
662+
protocol_version = await self.read_gatt_char(SW_REV_UUID)
663+
protocol_version = semver.VersionInfo.parse(protocol_version.decode())
664+
665+
if (
666+
protocol_version < PYBRICKS_PROTOCOL_VERSION
667+
or protocol_version >= PYBRICKS_PROTOCOL_VERSION.bump_major()
668+
):
669+
raise RuntimeError(
670+
f"Unsupported Pybricks protocol version: {protocol_version}"
671+
)
672+
673+
pnp_id = await self.read_gatt_char(PNP_ID_UUID)
674+
_, _, self.hub_kind, self.hub_variant = unpack_pnp_id(pnp_id)
675+
676+
if protocol_version >= "1.2.0":
677+
caps = await self.read_gatt_char(PYBRICKS_HUB_CAPABILITIES_UUID)
678+
(
679+
self._max_write_size,
680+
self._capability_flags,
681+
self._max_user_program_size,
682+
) = unpack_hub_capabilities(caps)
683+
else:
684+
# HACK: prior to profile v1.2.0 isn't a proper way to get the
685+
# MPY ABI version from hub so we use heuristics on the firmware version
686+
self._mpy_abi_version = (
687+
6 if self.fw_version >= Version("3.2.0b2") else 5
688+
)
689+
690+
if protocol_version < "1.3.0":
691+
self._legacy_stdio = True
692+
693+
return True
694+
695+
async def _client_disconnect(self) -> bool:
696+
return await self._client.disconnect()
697+
698+
async def read_gatt_char(self, uuid: str) -> bytearray:
699+
return await self._client.read_gatt_char(uuid)
700+
701+
async def write_gatt_char(self, uuid: str, data, response: bool) -> None:
702+
return await self._client.write_gatt_char(uuid, data, response)
703+
704+
async def start_notify(self, uuid: str, callback: Callable) -> None:
705+
return await self._client.start_notify(uuid, callback)

0 commit comments

Comments
 (0)