Skip to content

Commit c705276

Browse files
committed
Add support for USB connections
Adds a new subclass of PybricksHub that manages USB connections. Signed-off-by: Nate Karstens <[email protected]>
1 parent eb28049 commit c705276

File tree

3 files changed

+149
-4
lines changed

3 files changed

+149
-4
lines changed

pybricksdev/ble/pybricks.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,12 @@ def _standard_uuid(short: int) -> str:
328328
.. availability:: Since Pybricks protocol v1.0.0.
329329
"""
330330

331+
DEVICE_NAME_UUID = _standard_uuid(0x2A00)
332+
"""Standard Device Name UUID
333+
334+
.. availability:: Since Pybricks protocol v1.0.0.
335+
"""
336+
331337
FW_REV_UUID = _standard_uuid(0x2A26)
332338
"""Standard Firmware Revision String characteristic UUID
333339

pybricksdev/cli/__init__.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -171,10 +171,12 @@ def add_parser(self, subparsers: argparse._SubParsersAction):
171171
)
172172

173173
async def run(self, args: argparse.Namespace):
174-
from ..ble import find_device
174+
from usb.core import find as find_usb
175+
176+
from ..ble import find_device as find_ble
175177
from ..connections.ev3dev import EV3Connection
176178
from ..connections.lego import REPLHub
177-
from ..connections.pybricks import PybricksHubBLE
179+
from ..connections.pybricks import PybricksHubBLE, PybricksHubUSB
178180

179181
# Pick the right connection
180182
if args.conntype == "ssh":
@@ -185,14 +187,28 @@ async def run(self, args: argparse.Namespace):
185187

186188
device_or_address = socket.gethostbyname(args.name)
187189
hub = EV3Connection(device_or_address)
190+
188191
elif args.conntype == "ble":
189192
# It is a Pybricks Hub with BLE. Device name or address is given.
190193
print(f"Searching for {args.name or 'any hub with Pybricks service'}...")
191-
device_or_address = await find_device(args.name)
194+
device_or_address = await find_ble(args.name)
192195
hub = PybricksHubBLE(device_or_address)
193196

194197
elif args.conntype == "usb":
195-
hub = REPLHub()
198+
199+
def is_pybricks_usb(dev):
200+
return (
201+
(dev.idVendor == 0x0694)
202+
and ((dev.idProduct == 0x0009) or (dev.idProduct == 0x0011))
203+
and dev.product.endswith("Pybricks")
204+
)
205+
206+
device_or_address = find_usb(custom_match=is_pybricks_usb)
207+
208+
if device_or_address is not None:
209+
hub = PybricksHubUSB(device_or_address)
210+
else:
211+
hub = REPLHub()
196212
else:
197213
raise ValueError(f"Unknown connection type: {args.conntype}")
198214

pybricksdev/connections/pybricks.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import os
88
import struct
99
from typing import Awaitable, Callable, List, Optional, TypeVar
10+
from uuid import UUID
1011

1112
import reactivex.operators as op
1213
import semver
@@ -17,10 +18,15 @@
1718
from reactivex.subject import BehaviorSubject, Subject
1819
from tqdm.auto import tqdm
1920
from tqdm.contrib.logging import logging_redirect_tqdm
21+
from usb.control import get_descriptor
22+
from usb.core import Device as USBDevice
23+
from usb.core import Endpoint, USBTimeoutError
24+
from usb.util import ENDPOINT_IN, ENDPOINT_OUT, endpoint_direction, find_descriptor
2025

2126
from ..ble.lwp3.bytecodes import HubKind
2227
from ..ble.nus import NUS_RX_UUID, NUS_TX_UUID
2328
from ..ble.pybricks import (
29+
DEVICE_NAME_UUID,
2430
FW_REV_UUID,
2531
PNP_ID_UUID,
2632
PYBRICKS_COMMAND_EVENT_UUID,
@@ -705,3 +711,120 @@ async def write_gatt_char(self, uuid: str, data, response: bool) -> None:
705711

706712
async def start_notify(self, uuid: str, callback: Callable) -> None:
707713
return await self._client.start_notify(uuid, callback)
714+
715+
716+
class PybricksHubUSB(PybricksHub):
717+
_device: USBDevice
718+
_ep_in: Endpoint
719+
_ep_out: Endpoint
720+
_notify_callbacks = {}
721+
_monitor_task: asyncio.Task
722+
723+
def __init__(self, device: USBDevice):
724+
super().__init__()
725+
self._device = device
726+
727+
async def _client_connect(self) -> bool:
728+
self._device.set_configuration()
729+
730+
# Save input and output endpoints
731+
cfg = self._device.get_active_configuration()
732+
intf = cfg[(0, 0)]
733+
self._ep_in = find_descriptor(
734+
intf,
735+
custom_match=lambda e: endpoint_direction(e.bEndpointAddress)
736+
== ENDPOINT_IN,
737+
)
738+
self._ep_out = find_descriptor(
739+
intf,
740+
custom_match=lambda e: endpoint_direction(e.bEndpointAddress)
741+
== ENDPOINT_OUT,
742+
)
743+
744+
# Set write size to endpoint packet size minus length of UUID
745+
self._max_write_size = self._ep_out.wMaxPacketSize - 16
746+
747+
# Get length of BOS descriptor
748+
bos_descriptor = get_descriptor(self._device, 5, 0x0F, 0)
749+
(ofst, _, bos_len, _) = struct.unpack("<BBHB", bos_descriptor)
750+
751+
# Get full BOS descriptor
752+
bos_descriptor = get_descriptor(self._device, bos_len, 0x0F, 0)
753+
754+
while ofst < bos_len:
755+
(len, desc_type, cap_type) = struct.unpack_from(
756+
"<BBB", bos_descriptor, offset=ofst
757+
)
758+
759+
if desc_type != 0x10:
760+
logger.error("Expected Device Capability descriptor")
761+
exit(1)
762+
763+
# Look for platform descriptors
764+
if cap_type == 0x05:
765+
uuid_bytes = bos_descriptor[ofst + 4 : ofst + 4 + 16]
766+
uuid_str = str(UUID(bytes_le=bytes(uuid_bytes)))
767+
768+
if uuid_str == DEVICE_NAME_UUID:
769+
device_name = bytearray(bos_descriptor[ofst + 20 : ofst + len])
770+
print("Connected to hub '" + device_name.decode() + "'")
771+
772+
elif uuid_str == FW_REV_UUID:
773+
fw_version = bytearray(bos_descriptor[ofst + 20 : ofst + len])
774+
self.fw_version = Version(fw_version.decode())
775+
776+
elif uuid_str == SW_REV_UUID:
777+
self._protocol_version = bytearray(
778+
bos_descriptor[ofst + 20 : ofst + len]
779+
)
780+
781+
elif uuid_str == PYBRICKS_HUB_CAPABILITIES_UUID:
782+
caps = bytearray(bos_descriptor[ofst + 20 : ofst + len])
783+
(
784+
_,
785+
self._capability_flags,
786+
self._max_user_program_size,
787+
) = unpack_hub_capabilities(caps)
788+
789+
ofst += len
790+
791+
self._monitor_task = asyncio.create_task(self._monitor_usb())
792+
793+
return True
794+
795+
async def _client_disconnect(self) -> bool:
796+
self._monitor_task.cancel()
797+
self._handle_disconnect()
798+
799+
async def read_gatt_char(self, uuid: str) -> bytearray:
800+
return None
801+
802+
async def write_gatt_char(self, uuid: str, data, response: bool) -> None:
803+
self._ep_out.write(UUID(uuid).bytes_le + data)
804+
# TODO: Handle response
805+
806+
async def start_notify(self, uuid: str, callback: Callable) -> None:
807+
self._notify_callbacks[uuid] = callback
808+
809+
async def _monitor_usb(self):
810+
loop = asyncio.get_running_loop()
811+
812+
while True:
813+
msg = await loop.run_in_executor(None, self._read_usb)
814+
815+
if msg is None:
816+
continue
817+
818+
if len(msg) > 16:
819+
uuid = str(UUID(bytes_le=bytes(msg[0:16])))
820+
if uuid in self._notify_callbacks:
821+
callback = self._notify_callbacks[uuid]
822+
if callback:
823+
callback(None, bytes(msg[16:]))
824+
825+
def _read_usb(self):
826+
try:
827+
msg = self._ep_in.read(self._ep_in.wMaxPacketSize)
828+
return msg
829+
except USBTimeoutError:
830+
return None

0 commit comments

Comments
 (0)