|
8 | 8 | import os |
9 | 9 | import struct |
10 | 10 | from typing import Awaitable, Callable, List, Optional, Tuple, TypeVar |
| 11 | +from uuid import UUID |
11 | 12 |
|
12 | 13 | import reactivex.operators as op |
13 | 14 | import semver |
|
19 | 20 | from tqdm.auto import tqdm |
20 | 21 | from tqdm.contrib.logging import logging_redirect_tqdm |
21 | 22 |
|
| 23 | +from usb.control import get_descriptor |
| 24 | +from usb.core import Device as USBDevice |
| 25 | +from usb.core import Endpoint, USBTimeoutError |
| 26 | +from usb.util import ENDPOINT_IN, ENDPOINT_OUT, endpoint_direction, find_descriptor |
| 27 | + |
22 | 28 | from ..ble.lwp3.bytecodes import HubKind |
23 | 29 | from ..ble.nus import NUS_RX_UUID, NUS_TX_UUID |
24 | 30 | from ..ble.pybricks import ( |
| 31 | + DEVICE_NAME_UUID, |
25 | 32 | FW_REV_UUID, |
26 | 33 | PNP_ID_UUID, |
27 | 34 | PYBRICKS_COMMAND_EVENT_UUID, |
@@ -142,6 +149,140 @@ def handler(_, data): |
142 | 149 | return await self._client.start_notify(PYBRICKS_COMMAND_EVENT_UUID, handler) |
143 | 150 |
|
144 | 151 |
|
| 152 | +class _USBTransport(_Transport): |
| 153 | + _device: USBDevice |
| 154 | + _disconnected_callback: Callable |
| 155 | + _ep_in: Endpoint |
| 156 | + _ep_out: Endpoint |
| 157 | + _notify_callbacks = {} |
| 158 | + _monitor_task: asyncio.Task |
| 159 | + |
| 160 | + def __init__(self, device: USBDevice): |
| 161 | + self._device = device |
| 162 | + |
| 163 | + async def connect(self, disconnected_callback: Callable) -> None: |
| 164 | + self._disconnected_callback = disconnected_callback |
| 165 | + self._device.set_configuration() |
| 166 | + |
| 167 | + # Save input and output endpoints |
| 168 | + cfg = self._device.get_active_configuration() |
| 169 | + intf = cfg[(0, 0)] |
| 170 | + self._ep_in = find_descriptor( |
| 171 | + intf, |
| 172 | + custom_match=lambda e: endpoint_direction(e.bEndpointAddress) |
| 173 | + == ENDPOINT_IN, |
| 174 | + ) |
| 175 | + self._ep_out = find_descriptor( |
| 176 | + intf, |
| 177 | + custom_match=lambda e: endpoint_direction(e.bEndpointAddress) |
| 178 | + == ENDPOINT_OUT, |
| 179 | + ) |
| 180 | + |
| 181 | + # Set write size to endpoint packet size minus length of UUID |
| 182 | + self._max_write_size = self._ep_out.wMaxPacketSize - 16 |
| 183 | + |
| 184 | + # Get length of BOS descriptor |
| 185 | + bos_descriptor = get_descriptor(self._device, 5, 0x0F, 0) |
| 186 | + (ofst, _, bos_len, _) = struct.unpack("<BBHB", bos_descriptor) |
| 187 | + |
| 188 | + # Get full BOS descriptor |
| 189 | + bos_descriptor = get_descriptor(self._device, bos_len, 0x0F, 0) |
| 190 | + |
| 191 | + while ofst < bos_len: |
| 192 | + (len, desc_type, cap_type) = struct.unpack_from( |
| 193 | + "<BBB", bos_descriptor, offset=ofst |
| 194 | + ) |
| 195 | + |
| 196 | + if desc_type != 0x10: |
| 197 | + logger.error("Expected Device Capability descriptor") |
| 198 | + exit(1) |
| 199 | + |
| 200 | + # Look for platform descriptors |
| 201 | + if cap_type == 0x05: |
| 202 | + uuid_bytes = bos_descriptor[ofst + 4 : ofst + 4 + 16] |
| 203 | + uuid_str = str(UUID(bytes_le=bytes(uuid_bytes))) |
| 204 | + |
| 205 | + if uuid_str == DEVICE_NAME_UUID: |
| 206 | + self._device_name = bytearray( |
| 207 | + bos_descriptor[ofst + 20 : ofst + len] |
| 208 | + ).decode() |
| 209 | + print("Connected to hub '" + self._device_name + "'") |
| 210 | + |
| 211 | + elif uuid_str == FW_REV_UUID: |
| 212 | + fw_version = bytearray(bos_descriptor[ofst + 20 : ofst + len]) |
| 213 | + self._fw_version = Version(fw_version.decode()) |
| 214 | + |
| 215 | + elif uuid_str == SW_REV_UUID: |
| 216 | + protocol_version = bytearray(bos_descriptor[ofst + 20 : ofst + len]) |
| 217 | + self._protocol_version = semver.VersionInfo.parse( |
| 218 | + protocol_version.decode() |
| 219 | + ) |
| 220 | + |
| 221 | + elif uuid_str == PYBRICKS_HUB_CAPABILITIES_UUID: |
| 222 | + caps = bytearray(bos_descriptor[ofst + 20 : ofst + len]) |
| 223 | + ( |
| 224 | + _, |
| 225 | + self._capability_flags, |
| 226 | + self._max_user_program_size, |
| 227 | + ) = unpack_hub_capabilities(caps) |
| 228 | + |
| 229 | + ofst += len |
| 230 | + |
| 231 | + self._monitor_task = asyncio.create_task(self._monitor_usb()) |
| 232 | + |
| 233 | + async def disconnect(self) -> None: |
| 234 | + self._monitor_task.cancel() |
| 235 | + self._disconnected_callback() |
| 236 | + |
| 237 | + async def get_firmware_version(self) -> Version: |
| 238 | + return self._fw_version |
| 239 | + |
| 240 | + async def get_protocol_version(self) -> Version: |
| 241 | + return self._protocol_version |
| 242 | + |
| 243 | + async def get_hub_type(self) -> Tuple[HubKind, int]: |
| 244 | + return (HubKind.TECHNIC_LARGE, 0) |
| 245 | + |
| 246 | + async def get_hub_capabilities(self) -> Tuple[int, HubCapabilityFlag, int]: |
| 247 | + return ( |
| 248 | + self._max_write_size, |
| 249 | + self._capability_flags, |
| 250 | + self._max_user_program_size, |
| 251 | + ) |
| 252 | + |
| 253 | + async def send_command(self, command: bytes) -> None: |
| 254 | + self._ep_out.write(UUID(PYBRICKS_COMMAND_EVENT_UUID).bytes_le + command) |
| 255 | + |
| 256 | + async def set_nus_handler(self, callback: Callable) -> None: |
| 257 | + pass |
| 258 | + |
| 259 | + async def set_service_handler(self, callback: Callable) -> None: |
| 260 | + self._notify_callbacks[PYBRICKS_COMMAND_EVENT_UUID] = callback |
| 261 | + |
| 262 | + async def _monitor_usb(self): |
| 263 | + loop = asyncio.get_running_loop() |
| 264 | + |
| 265 | + while True: |
| 266 | + msg = await loop.run_in_executor(None, self._read_usb) |
| 267 | + |
| 268 | + if msg is None: |
| 269 | + continue |
| 270 | + |
| 271 | + if len(msg) > 16: |
| 272 | + uuid = str(UUID(bytes_le=bytes(msg[0:16]))) |
| 273 | + if uuid in self._notify_callbacks: |
| 274 | + callback = self._notify_callbacks[uuid] |
| 275 | + if callback: |
| 276 | + callback(bytes(msg[16:])) |
| 277 | + |
| 278 | + def _read_usb(self): |
| 279 | + try: |
| 280 | + msg = self._ep_in.read(self._ep_in.wMaxPacketSize) |
| 281 | + return msg |
| 282 | + except USBTimeoutError: |
| 283 | + return None |
| 284 | + |
| 285 | + |
145 | 286 | class PybricksHub: |
146 | 287 | EOL = b"\r\n" # MicroPython EOL |
147 | 288 |
|
@@ -178,7 +319,7 @@ class PybricksHub: |
178 | 319 |
|
179 | 320 | _transport: _Transport |
180 | 321 |
|
181 | | - def __init__(self, device: BLEDevice): |
| 322 | + def __init__(self, device): |
182 | 323 | self.connection_state_observable = BehaviorSubject(ConnectionState.DISCONNECTED) |
183 | 324 | self.status_observable = BehaviorSubject(StatusFlag(0)) |
184 | 325 | self._stdout_subject = Subject() |
@@ -220,7 +361,10 @@ def __init__(self, device: BLEDevice): |
220 | 361 | # File handle for logging |
221 | 362 | self.log_file = None |
222 | 363 |
|
223 | | - self._transport = _BLETransport(device) |
| 364 | + if isinstance(device, BLEDevice): |
| 365 | + self._transport = _BLETransport(device) |
| 366 | + elif isinstance(device, USBDevice): |
| 367 | + self._transport = _USBTransport(device) |
224 | 368 |
|
225 | 369 | @property |
226 | 370 | def stdout_observable(self) -> Observable[bytes]: |
|
0 commit comments