Skip to content

Commit 540e6d3

Browse files
committed
Add support for USB connections
Adds a new transport to manage USB connections. Signed-off-by: Nate Karstens <[email protected]>
1 parent f8ff409 commit 540e6d3

File tree

2 files changed

+152
-2
lines changed

2 files changed

+152
-2
lines changed

pybricksdev/ble/pybricks.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,12 @@ def _standard_uuid(short: int) -> str:
328328
.. availability:: Since Pybricks protocol v1.0.0.
329329
"""
330330

331+
DEVICE_NAME_UUID = _standard_uuid(0x2A00)
332+
"""Standard Device Name UUID
333+
334+
.. availability:: Since Pybricks protocol v1.0.0.
335+
"""
336+
331337
FW_REV_UUID = _standard_uuid(0x2A26)
332338
"""Standard Firmware Revision String characteristic UUID
333339

pybricksdev/connections/pybricks.py

Lines changed: 146 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import os
99
import struct
1010
from typing import Awaitable, Callable, List, Optional, Tuple, TypeVar
11+
from uuid import UUID
1112

1213
import reactivex.operators as op
1314
import semver
@@ -19,9 +20,15 @@
1920
from tqdm.auto import tqdm
2021
from tqdm.contrib.logging import logging_redirect_tqdm
2122

23+
from usb.control import get_descriptor
24+
from usb.core import Device as USBDevice
25+
from usb.core import Endpoint, USBTimeoutError
26+
from usb.util import ENDPOINT_IN, ENDPOINT_OUT, endpoint_direction, find_descriptor
27+
2228
from ..ble.lwp3.bytecodes import HubKind
2329
from ..ble.nus import NUS_RX_UUID, NUS_TX_UUID
2430
from ..ble.pybricks import (
31+
DEVICE_NAME_UUID,
2532
FW_REV_UUID,
2633
PNP_ID_UUID,
2734
PYBRICKS_COMMAND_EVENT_UUID,
@@ -142,6 +149,140 @@ def handler(_, data):
142149
return await self._client.start_notify(PYBRICKS_COMMAND_EVENT_UUID, handler)
143150

144151

152+
class _USBTransport(_Transport):
153+
_device: USBDevice
154+
_disconnected_callback: Callable
155+
_ep_in: Endpoint
156+
_ep_out: Endpoint
157+
_notify_callbacks = {}
158+
_monitor_task: asyncio.Task
159+
160+
def __init__(self, device: USBDevice):
161+
self._device = device
162+
163+
async def connect(self, disconnected_callback: Callable) -> None:
164+
self._disconnected_callback = disconnected_callback
165+
self._device.set_configuration()
166+
167+
# Save input and output endpoints
168+
cfg = self._device.get_active_configuration()
169+
intf = cfg[(0, 0)]
170+
self._ep_in = find_descriptor(
171+
intf,
172+
custom_match=lambda e: endpoint_direction(e.bEndpointAddress)
173+
== ENDPOINT_IN,
174+
)
175+
self._ep_out = find_descriptor(
176+
intf,
177+
custom_match=lambda e: endpoint_direction(e.bEndpointAddress)
178+
== ENDPOINT_OUT,
179+
)
180+
181+
# Set write size to endpoint packet size minus length of UUID
182+
self._max_write_size = self._ep_out.wMaxPacketSize - 16
183+
184+
# Get length of BOS descriptor
185+
bos_descriptor = get_descriptor(self._device, 5, 0x0F, 0)
186+
(ofst, _, bos_len, _) = struct.unpack("<BBHB", bos_descriptor)
187+
188+
# Get full BOS descriptor
189+
bos_descriptor = get_descriptor(self._device, bos_len, 0x0F, 0)
190+
191+
while ofst < bos_len:
192+
(len, desc_type, cap_type) = struct.unpack_from(
193+
"<BBB", bos_descriptor, offset=ofst
194+
)
195+
196+
if desc_type != 0x10:
197+
logger.error("Expected Device Capability descriptor")
198+
exit(1)
199+
200+
# Look for platform descriptors
201+
if cap_type == 0x05:
202+
uuid_bytes = bos_descriptor[ofst + 4 : ofst + 4 + 16]
203+
uuid_str = str(UUID(bytes_le=bytes(uuid_bytes)))
204+
205+
if uuid_str == DEVICE_NAME_UUID:
206+
self._device_name = bytearray(
207+
bos_descriptor[ofst + 20 : ofst + len]
208+
).decode()
209+
print("Connected to hub '" + self._device_name + "'")
210+
211+
elif uuid_str == FW_REV_UUID:
212+
fw_version = bytearray(bos_descriptor[ofst + 20 : ofst + len])
213+
self._fw_version = Version(fw_version.decode())
214+
215+
elif uuid_str == SW_REV_UUID:
216+
protocol_version = bytearray(bos_descriptor[ofst + 20 : ofst + len])
217+
self._protocol_version = semver.VersionInfo.parse(
218+
protocol_version.decode()
219+
)
220+
221+
elif uuid_str == PYBRICKS_HUB_CAPABILITIES_UUID:
222+
caps = bytearray(bos_descriptor[ofst + 20 : ofst + len])
223+
(
224+
_,
225+
self._capability_flags,
226+
self._max_user_program_size,
227+
) = unpack_hub_capabilities(caps)
228+
229+
ofst += len
230+
231+
self._monitor_task = asyncio.create_task(self._monitor_usb())
232+
233+
async def disconnect(self) -> None:
234+
self._monitor_task.cancel()
235+
self._disconnected_callback()
236+
237+
async def get_firmware_version(self) -> Version:
238+
return self._fw_version
239+
240+
async def get_protocol_version(self) -> Version:
241+
return self._protocol_version
242+
243+
async def get_hub_type(self) -> Tuple[HubKind, int]:
244+
return (HubKind.TECHNIC_LARGE, 0)
245+
246+
async def get_hub_capabilities(self) -> Tuple[int, HubCapabilityFlag, int]:
247+
return (
248+
self._max_write_size,
249+
self._capability_flags,
250+
self._max_user_program_size,
251+
)
252+
253+
async def send_command(self, command: bytes) -> None:
254+
self._ep_out.write(UUID(PYBRICKS_COMMAND_EVENT_UUID).bytes_le + command)
255+
256+
async def set_nus_handler(self, callback: Callable) -> None:
257+
pass
258+
259+
async def set_service_handler(self, callback: Callable) -> None:
260+
self._notify_callbacks[PYBRICKS_COMMAND_EVENT_UUID] = callback
261+
262+
async def _monitor_usb(self):
263+
loop = asyncio.get_running_loop()
264+
265+
while True:
266+
msg = await loop.run_in_executor(None, self._read_usb)
267+
268+
if msg is None:
269+
continue
270+
271+
if len(msg) > 16:
272+
uuid = str(UUID(bytes_le=bytes(msg[0:16])))
273+
if uuid in self._notify_callbacks:
274+
callback = self._notify_callbacks[uuid]
275+
if callback:
276+
callback(bytes(msg[16:]))
277+
278+
def _read_usb(self):
279+
try:
280+
msg = self._ep_in.read(self._ep_in.wMaxPacketSize)
281+
return msg
282+
except USBTimeoutError:
283+
return None
284+
285+
145286
class PybricksHub:
146287
EOL = b"\r\n" # MicroPython EOL
147288

@@ -178,7 +319,7 @@ class PybricksHub:
178319

179320
_transport: _Transport
180321

181-
def __init__(self, device: BLEDevice):
322+
def __init__(self, device):
182323
self.connection_state_observable = BehaviorSubject(ConnectionState.DISCONNECTED)
183324
self.status_observable = BehaviorSubject(StatusFlag(0))
184325
self._stdout_subject = Subject()
@@ -220,7 +361,10 @@ def __init__(self, device: BLEDevice):
220361
# File handle for logging
221362
self.log_file = None
222363

223-
self._transport = _BLETransport(device)
364+
if isinstance(device, BLEDevice):
365+
self._transport = _BLETransport(device)
366+
elif isinstance(device, USBDevice):
367+
self._transport = _USBTransport(device)
224368

225369
@property
226370
def stdout_observable(self) -> Observable[bytes]:

0 commit comments

Comments
 (0)