66import logging
77import os
88import struct
9- from typing import Awaitable , List , Optional , TypeVar
9+ from typing import Awaitable , Callable , List , Optional , TypeVar
1010
1111import reactivex .operators as op
1212import semver
@@ -78,7 +78,7 @@ class PybricksHub:
7878 has not been connected yet or the connected hub has Pybricks profile < v1.2.0.
7979 """
8080
81- def __init__ (self , device : BLEDevice ):
81+ def __init__ (self ):
8282 self .connection_state_observable = BehaviorSubject (ConnectionState .DISCONNECTED )
8383 self .status_observable = BehaviorSubject (StatusFlag (0 ))
8484 self ._stdout_subject = Subject ()
@@ -120,11 +120,6 @@ def __init__(self, device: BLEDevice):
120120 # File handle for logging
121121 self .log_file = None
122122
123- def handle_disconnect (_ : BleakClient ):
124- self ._handle_disconnect ()
125-
126- self .client = BleakClient (device , disconnected_callback = handle_disconnect )
127-
128123 @property
129124 def stdout_observable (self ) -> Observable [bytes ]:
130125 """
@@ -237,16 +232,6 @@ def _handle_disconnect(self):
237232 self .connection_state_observable .on_next (ConnectionState .DISCONNECTED )
238233
239234 async def connect (self ):
240- """Connects to a device that was discovered with :meth:`pybricksdev.ble.find_device`
241-
242- Raises:
243- BleakError: if connecting failed (or old firmware without Device
244- Information Service)
245- RuntimeError: if Pybricks Protocol version is not supported
246- """
247- # TODO: Fix this
248- # logger.info(f"Connecting to {device.name}")
249-
250235 if self .connection_state_observable .value != ConnectionState .DISCONNECTED :
251236 raise RuntimeError (
252237 f"attempting to connect with invalid state: { self .connection_state_observable .value } "
@@ -259,48 +244,12 @@ async def connect(self):
259244 self .connection_state_observable .on_next , ConnectionState .DISCONNECTED
260245 )
261246
262- await self .client . connect ()
247+ await self ._client_connect ()
263248
264249 stack .push_async_callback (self .disconnect )
265250
266- logger .info ("Connected successfully!" )
267-
268- fw_version = await self .client .read_gatt_char (FW_REV_UUID )
269- self .fw_version = Version (fw_version .decode ())
270-
271- protocol_version = await self .client .read_gatt_char (SW_REV_UUID )
272- protocol_version = semver .VersionInfo .parse (protocol_version .decode ())
273-
274- if (
275- protocol_version < PYBRICKS_PROTOCOL_VERSION
276- or protocol_version >= PYBRICKS_PROTOCOL_VERSION .bump_major ()
277- ):
278- raise RuntimeError (
279- f"Unsupported Pybricks protocol version: { protocol_version } "
280- )
281-
282- pnp_id = await self .client .read_gatt_char (PNP_ID_UUID )
283- _ , _ , self .hub_kind , self .hub_variant = unpack_pnp_id (pnp_id )
284-
285- if protocol_version >= "1.2.0" :
286- caps = await self .client .read_gatt_char (PYBRICKS_HUB_CAPABILITIES_UUID )
287- (
288- self ._max_write_size ,
289- self ._capability_flags ,
290- self ._max_user_program_size ,
291- ) = unpack_hub_capabilities (caps )
292- else :
293- # HACK: prior to profile v1.2.0 isn't a proper way to get the
294- # MPY ABI version from hub so we use heuristics on the firmware version
295- self ._mpy_abi_version = (
296- 6 if self .fw_version >= Version ("3.2.0b2" ) else 5
297- )
298-
299- if protocol_version < "1.3.0" :
300- self ._legacy_stdio = True
301-
302- await self .client .start_notify (NUS_TX_UUID , self ._nus_handler )
303- await self .client .start_notify (
251+ await self .start_notify (NUS_TX_UUID , self ._nus_handler )
252+ await self .start_notify (
304253 PYBRICKS_COMMAND_EVENT_UUID , self ._pybricks_service_handler
305254 )
306255
@@ -314,7 +263,7 @@ async def disconnect(self):
314263
315264 if self .connection_state_observable .value == ConnectionState .CONNECTED :
316265 self .connection_state_observable .on_next (ConnectionState .DISCONNECTING )
317- await self .client . disconnect ()
266+ await self ._client_disconnect ()
318267 # ConnectionState.DISCONNECTED should be set by disconnect callback
319268 assert (
320269 self .connection_state_observable .value == ConnectionState .DISCONNECTED
@@ -453,7 +402,7 @@ async def download_user_program(self, program: bytes) -> None:
453402 )
454403
455404 # clear user program meta so hub doesn't try to run invalid program
456- await self .client . write_gatt_char (
405+ await self .write_gatt_char (
457406 PYBRICKS_COMMAND_EVENT_UUID ,
458407 struct .pack ("<BI" , Command .WRITE_USER_PROGRAM_META , 0 ),
459408 response = True ,
@@ -467,7 +416,7 @@ async def download_user_program(self, program: bytes) -> None:
467416 total = len (program ), unit = "B" , unit_scale = True
468417 ) as pbar :
469418 for i , c in enumerate (chunk (program , payload_size )):
470- await self .client . write_gatt_char (
419+ await self .write_gatt_char (
471420 PYBRICKS_COMMAND_EVENT_UUID ,
472421 struct .pack (
473422 f"<BI{ len (c )} s" ,
@@ -480,7 +429,7 @@ async def download_user_program(self, program: bytes) -> None:
480429 pbar .update (len (c ))
481430
482431 # set the metadata to notify that writing was successful
483- await self .client . write_gatt_char (
432+ await self .write_gatt_char (
484433 PYBRICKS_COMMAND_EVENT_UUID ,
485434 struct .pack ("<BI" , Command .WRITE_USER_PROGRAM_META , len (program )),
486435 response = True ,
@@ -492,7 +441,7 @@ async def start_user_program(self) -> None:
492441
493442 Requires hub with Pybricks Profile >= v1.2.0.
494443 """
495- await self .client . write_gatt_char (
444+ await self .write_gatt_char (
496445 PYBRICKS_COMMAND_EVENT_UUID ,
497446 struct .pack ("<B" , Command .START_USER_PROGRAM ),
498447 response = True ,
@@ -502,7 +451,7 @@ async def stop_user_program(self) -> None:
502451 """
503452 Stops the user program on the hub if it is running.
504453 """
505- await self .client . write_gatt_char (
454+ await self .write_gatt_char (
506455 PYBRICKS_COMMAND_EVENT_UUID ,
507456 struct .pack ("<B" , Command .STOP_USER_PROGRAM ),
508457 response = True ,
@@ -680,3 +629,79 @@ async def _wait_for_user_program_stop(self):
680629 # the user program running status flag
681630 # https://github.com/pybricks/support/issues/305
682631 await asyncio .sleep (0.3 )
632+
633+
634+ class PybricksHubBLE (PybricksHub ):
635+ _device : BLEDevice
636+ _client : BleakClient
637+
638+ def __init__ (self , device : BLEDevice ):
639+ super ().__init__ ()
640+
641+ self ._device = device
642+
643+ def handle_disconnect (_ : BleakClient ):
644+ self ._handle_disconnect ()
645+
646+ self ._client = BleakClient (
647+ self ._device , disconnected_callback = handle_disconnect
648+ )
649+
650+ async def _client_connect (self ) -> bool :
651+ """Connects to a device that was discovered with :meth:`pybricksdev.ble.find_device`
652+
653+ Raises:
654+ BleakError: if connecting failed (or old firmware without Device
655+ Information Service)
656+ RuntimeError: if Pybricks Protocol version is not supported
657+ """
658+
659+ logger .info (f"Connecting to { self ._device .name } " )
660+ await self ._client .connect ()
661+ logger .info ("Connected successfully!" )
662+
663+ fw_version = await self .read_gatt_char (FW_REV_UUID )
664+ self .fw_version = Version (fw_version .decode ())
665+
666+ protocol_version = await self .read_gatt_char (SW_REV_UUID )
667+ protocol_version = semver .VersionInfo .parse (protocol_version .decode ())
668+
669+ if (
670+ protocol_version < PYBRICKS_PROTOCOL_VERSION
671+ or protocol_version >= PYBRICKS_PROTOCOL_VERSION .bump_major ()
672+ ):
673+ raise RuntimeError (
674+ f"Unsupported Pybricks protocol version: { protocol_version } "
675+ )
676+
677+ pnp_id = await self .read_gatt_char (PNP_ID_UUID )
678+ _ , _ , self .hub_kind , self .hub_variant = unpack_pnp_id (pnp_id )
679+
680+ if protocol_version >= "1.2.0" :
681+ caps = await self .read_gatt_char (PYBRICKS_HUB_CAPABILITIES_UUID )
682+ (
683+ self ._max_write_size ,
684+ self ._capability_flags ,
685+ self ._max_user_program_size ,
686+ ) = unpack_hub_capabilities (caps )
687+ else :
688+ # HACK: prior to profile v1.2.0 isn't a proper way to get the
689+ # MPY ABI version from hub so we use heuristics on the firmware version
690+ self ._mpy_abi_version = 6 if self .fw_version >= Version ("3.2.0b2" ) else 5
691+
692+ if protocol_version < "1.3.0" :
693+ self ._legacy_stdio = True
694+
695+ return True
696+
697+ async def _client_disconnect (self ) -> bool :
698+ return await self ._client .disconnect ()
699+
700+ async def read_gatt_char (self , uuid : str ) -> bytearray :
701+ return await self ._client .read_gatt_char (uuid )
702+
703+ async def write_gatt_char (self , uuid : str , data , response : bool ) -> None :
704+ return await self ._client .write_gatt_char (uuid , data , response )
705+
706+ async def start_notify (self , uuid : str , callback : Callable ) -> None :
707+ return await self ._client .start_notify (uuid , callback )
0 commit comments