88import os
99import struct
1010from typing import Awaitable , Callable , List , Optional , Tuple , TypeVar
11+ from uuid import UUID
1112
1213import reactivex .operators as op
1314import semver
1920from tqdm .auto import tqdm
2021from tqdm .contrib .logging import logging_redirect_tqdm
2122
23+ from usb .control import get_descriptor
24+ from usb .core import Device as USBDevice
25+ from usb .core import Endpoint , USBTimeoutError
26+ from usb .util import ENDPOINT_IN , ENDPOINT_OUT , endpoint_direction , find_descriptor
27+
2228from ..ble .lwp3 .bytecodes import HubKind
2329from ..ble .nus import NUS_RX_UUID , NUS_TX_UUID
2430from ..ble .pybricks import (
31+ DEVICE_NAME_UUID ,
2532 FW_REV_UUID ,
2633 PNP_ID_UUID ,
2734 PYBRICKS_COMMAND_EVENT_UUID ,
3845from ..compile import compile_file , compile_multi_file
3946from ..tools import chunk
4047from ..tools .checksum import xor_bytes
48+ from ..usb import LegoUsbPid
4149from . import ConnectionState
4250
4351logger = logging .getLogger (__name__ )
@@ -138,6 +146,140 @@ def handler(_, data):
138146 await self ._client .start_notify (NUS_TX_UUID , handler )
139147
140148
149+ class _USBTransport (_Transport ):
150+ _device : USBDevice
151+ _disconnected_callback : Callable
152+ _ep_in : Endpoint
153+ _ep_out : Endpoint
154+ _notify_callbacks = {}
155+ _monitor_task : asyncio .Task
156+
157+ def __init__ (self , device : USBDevice ):
158+ self ._device = device
159+
160+ async def connect (self , disconnected_callback : Callable ) -> None :
161+ self ._disconnected_callback = disconnected_callback
162+ self ._device .set_configuration ()
163+
164+ # Save input and output endpoints
165+ cfg = self ._device .get_active_configuration ()
166+ intf = cfg [(0 , 0 )]
167+ self ._ep_in = find_descriptor (
168+ intf ,
169+ custom_match = lambda e : endpoint_direction (e .bEndpointAddress )
170+ == ENDPOINT_IN ,
171+ )
172+ self ._ep_out = find_descriptor (
173+ intf ,
174+ custom_match = lambda e : endpoint_direction (e .bEndpointAddress )
175+ == ENDPOINT_OUT ,
176+ )
177+
178+ # Get length of BOS descriptor
179+ bos_descriptor = get_descriptor (self ._device , 5 , 0x0F , 0 )
180+ (ofst , bos_len ) = struct .unpack ("<BxHx" , bos_descriptor )
181+
182+ # Get full BOS descriptor
183+ bos_descriptor = get_descriptor (self ._device , bos_len , 0x0F , 0 )
184+
185+ while ofst < bos_len :
186+ (len , desc_type , cap_type ) = struct .unpack_from (
187+ "<BBB" , bos_descriptor , offset = ofst
188+ )
189+
190+ if desc_type != 0x10 :
191+ raise Exception ("Expected Device Capability descriptor" )
192+
193+ # Look for platform descriptors
194+ if cap_type == 0x05 :
195+ uuid_bytes = bos_descriptor [ofst + 4 : ofst + 4 + 16 ]
196+ uuid_str = str (UUID (bytes_le = bytes (uuid_bytes )))
197+
198+ if uuid_str == DEVICE_NAME_UUID :
199+ self ._device_name = bytes (
200+ bos_descriptor [ofst + 20 : ofst + len ]
201+ ).decode ()
202+ print ("Connected to hub '" + self ._device_name + "'" )
203+
204+ elif uuid_str == FW_REV_UUID :
205+ fw_version = bytes (bos_descriptor [ofst + 20 : ofst + len ])
206+ self ._fw_version = Version (fw_version .decode ())
207+
208+ elif uuid_str == SW_REV_UUID :
209+ protocol_version = bytes (bos_descriptor [ofst + 20 : ofst + len ])
210+ self ._protocol_version = semver .VersionInfo .parse (
211+ protocol_version .decode ()
212+ )
213+
214+ elif uuid_str == PYBRICKS_HUB_CAPABILITIES_UUID :
215+ caps = bytes (bos_descriptor [ofst + 20 : ofst + len ])
216+ (
217+ self ._max_write_size ,
218+ self ._capability_flags ,
219+ self ._max_user_program_size ,
220+ ) = unpack_hub_capabilities (caps )
221+
222+ ofst += len
223+
224+ self ._monitor_task = asyncio .create_task (self ._monitor_usb ())
225+
226+ async def disconnect (self ) -> None :
227+ # FIXME: Need to make sure this is called when the USB cable is unplugged
228+ self ._monitor_task .cancel ()
229+ self ._disconnected_callback ()
230+
231+ async def get_firmware_version (self ) -> Version :
232+ return self ._fw_version
233+
234+ async def get_protocol_version (self ) -> Version :
235+ return self ._protocol_version
236+
237+ async def get_hub_type (self ) -> Tuple [HubKind , int ]:
238+ hub_types = {
239+ LegoUsbPid .SPIKE_PRIME : (HubKind .TECHNIC_LARGE , 0 ),
240+ LegoUsbPid .ROBOT_INVENTOR : (HubKind .TECHNIC_LARGE , 1 ),
241+ LegoUsbPid .SPIKE_ESSENTIAL : (HubKind .TECHNIC_SMALL , 0 ),
242+ }
243+
244+ return hub_types [self ._device .idProduct ]
245+
246+ async def get_hub_capabilities (self ) -> Tuple [int , HubCapabilityFlag , int ]:
247+ return (
248+ self ._max_write_size ,
249+ self ._capability_flags ,
250+ self ._max_user_program_size ,
251+ )
252+
253+ async def send_command (self , command : bytes ) -> None :
254+ self ._ep_out .write (UUID (PYBRICKS_COMMAND_EVENT_UUID ).bytes_le + command )
255+
256+ async def set_service_handler (self , callback : Callable ) -> None :
257+ self ._notify_callbacks [PYBRICKS_COMMAND_EVENT_UUID ] = callback
258+
259+ async def _monitor_usb (self ):
260+ loop = asyncio .get_running_loop ()
261+
262+ while True :
263+ msg = await loop .run_in_executor (None , self ._read_usb )
264+
265+ if msg is None :
266+ continue
267+
268+ if len (msg ) > 16 :
269+ uuid = str (UUID (bytes_le = bytes (msg [:16 ])))
270+ if uuid in self ._notify_callbacks :
271+ callback = self ._notify_callbacks [uuid ]
272+ if callback :
273+ callback (bytes (msg [16 :]))
274+
275+ def _read_usb (self ):
276+ try :
277+ msg = self ._ep_in .read (self ._ep_in .wMaxPacketSize )
278+ return msg
279+ except USBTimeoutError :
280+ return None
281+
282+
141283class PybricksHub :
142284 EOL = b"\r \n " # MicroPython EOL
143285
@@ -326,11 +468,12 @@ def _pybricks_service_handler(self, data: bytes) -> None:
326468 if self ._enable_line_handler :
327469 self ._handle_line_data (payload )
328470
329- async def connect (self , device : BLEDevice ):
471+ async def connect (self , device ):
330472 """Connects to a device that was discovered with :meth:`pybricksdev.ble.find_device`
473+ or :meth:`usb.core.find`
331474
332475 Args:
333- device: The device to connect to.
476+ device: The device to connect to (`BLEDevice` or `USBDevice`) .
334477
335478 Raises:
336479 BleakError: if connecting failed (or old firmware without Device
@@ -350,7 +493,12 @@ async def connect(self, device: BLEDevice):
350493 self .connection_state_observable .on_next , ConnectionState .DISCONNECTED
351494 )
352495
353- self ._transport = _BLETransport (device )
496+ if isinstance (device , BLEDevice ):
497+ self ._transport = _BLETransport (device )
498+ elif isinstance (device , USBDevice ):
499+ self ._transport = _USBTransport (device )
500+ else :
501+ raise TypeError ("Unsupported device type" )
354502
355503 def handle_disconnect ():
356504 logger .info ("Disconnected!" )
0 commit comments