88import os
99import struct
1010from typing import Awaitable , Callable , List , Optional , Tuple , TypeVar
11+ from uuid import UUID
1112
1213import reactivex .operators as op
1314import semver
1920from tqdm .auto import tqdm
2021from tqdm .contrib .logging import logging_redirect_tqdm
2122
23+ from usb .control import get_descriptor
24+ from usb .core import Device as USBDevice
25+ from usb .core import Endpoint , USBTimeoutError
26+ from usb .util import ENDPOINT_IN , ENDPOINT_OUT , endpoint_direction , find_descriptor
27+
2228from ..ble .lwp3 .bytecodes import HubKind
2329from ..ble .nus import NUS_RX_UUID , NUS_TX_UUID
2430from ..ble .pybricks import (
31+ DEVICE_NAME_UUID ,
2532 FW_REV_UUID ,
2633 PNP_ID_UUID ,
2734 PYBRICKS_COMMAND_EVENT_UUID ,
3845from ..compile import compile_file , compile_multi_file
3946from ..tools import chunk
4047from ..tools .checksum import xor_bytes
48+ from ..usb import LegoUsbPid
4149from . import ConnectionState
4250
4351logger = logging .getLogger (__name__ )
@@ -138,6 +146,158 @@ def handler(_, data):
138146 await self ._client .start_notify (NUS_TX_UUID , handler )
139147
140148
149+ class _USBTransport (_Transport ):
150+ _device : USBDevice
151+ _disconnected_callback : Callable
152+ _ep_in : Endpoint
153+ _ep_out : Endpoint
154+ _notify_callbacks = {}
155+ _monitor_task : asyncio .Task
156+ _response_event = asyncio .Event ()
157+ _response : int
158+
159+ _USB_PYBRICKS_MSG_COMMAND = b"\x00 "
160+ _USB_PYBRICKS_MSG_COMMAND_RESPONSE = b"\x01 "
161+ _USB_PYBRICKS_MSG_EVENT = b"\x02 "
162+
163+ def __init__ (self , device : USBDevice ):
164+ self ._device = device
165+ self ._notify_callbacks [
166+ self ._USB_PYBRICKS_MSG_COMMAND_RESPONSE [0 ]
167+ ] = self ._response_handler
168+
169+ async def connect (self , disconnected_callback : Callable ) -> None :
170+ self ._disconnected_callback = disconnected_callback
171+ self ._device .set_configuration ()
172+
173+ # Save input and output endpoints
174+ cfg = self ._device .get_active_configuration ()
175+ intf = cfg [(0 , 0 )]
176+ self ._ep_in = find_descriptor (
177+ intf ,
178+ custom_match = lambda e : endpoint_direction (e .bEndpointAddress )
179+ == ENDPOINT_IN ,
180+ )
181+ self ._ep_out = find_descriptor (
182+ intf ,
183+ custom_match = lambda e : endpoint_direction (e .bEndpointAddress )
184+ == ENDPOINT_OUT ,
185+ )
186+
187+ # Get length of BOS descriptor
188+ bos_descriptor = get_descriptor (self ._device , 5 , 0x0F , 0 )
189+ (ofst , bos_len ) = struct .unpack ("<BxHx" , bos_descriptor )
190+
191+ # Get full BOS descriptor
192+ bos_descriptor = get_descriptor (self ._device , bos_len , 0x0F , 0 )
193+
194+ while ofst < bos_len :
195+ (len , desc_type , cap_type ) = struct .unpack_from (
196+ "<BBB" , bos_descriptor , offset = ofst
197+ )
198+
199+ if desc_type != 0x10 :
200+ raise Exception ("Expected Device Capability descriptor" )
201+
202+ # Look for platform descriptors
203+ if cap_type == 0x05 :
204+ uuid_bytes = bos_descriptor [ofst + 4 : ofst + 4 + 16 ]
205+ uuid_str = str (UUID (bytes_le = bytes (uuid_bytes )))
206+
207+ if uuid_str == DEVICE_NAME_UUID :
208+ self ._device_name = bytes (
209+ bos_descriptor [ofst + 20 : ofst + len ]
210+ ).decode ()
211+ print ("Connected to hub '" + self ._device_name + "'" )
212+
213+ elif uuid_str == FW_REV_UUID :
214+ fw_version = bytes (bos_descriptor [ofst + 20 : ofst + len ])
215+ self ._fw_version = Version (fw_version .decode ())
216+
217+ elif uuid_str == SW_REV_UUID :
218+ protocol_version = bytes (bos_descriptor [ofst + 20 : ofst + len ])
219+ self ._protocol_version = semver .VersionInfo .parse (
220+ protocol_version .decode ()
221+ )
222+
223+ elif uuid_str == PYBRICKS_HUB_CAPABILITIES_UUID :
224+ caps = bytes (bos_descriptor [ofst + 20 : ofst + len ])
225+ (
226+ self ._max_write_size ,
227+ self ._capability_flags ,
228+ self ._max_user_program_size ,
229+ ) = unpack_hub_capabilities (caps )
230+
231+ ofst += len
232+
233+ self ._monitor_task = asyncio .create_task (self ._monitor_usb ())
234+
235+ async def disconnect (self ) -> None :
236+ # FIXME: Need to make sure this is called when the USB cable is unplugged
237+ self ._monitor_task .cancel ()
238+ self ._disconnected_callback ()
239+
240+ async def get_firmware_version (self ) -> Version :
241+ return self ._fw_version
242+
243+ async def get_protocol_version (self ) -> Version :
244+ return self ._protocol_version
245+
246+ async def get_hub_type (self ) -> Tuple [HubKind , int ]:
247+ hub_types = {
248+ LegoUsbPid .SPIKE_PRIME : (HubKind .TECHNIC_LARGE , 0 ),
249+ LegoUsbPid .ROBOT_INVENTOR : (HubKind .TECHNIC_LARGE , 1 ),
250+ LegoUsbPid .SPIKE_ESSENTIAL : (HubKind .TECHNIC_SMALL , 0 ),
251+ }
252+
253+ return hub_types [self ._device .idProduct ]
254+
255+ async def get_hub_capabilities (self ) -> Tuple [int , HubCapabilityFlag , int ]:
256+ return (
257+ self ._max_write_size ,
258+ self ._capability_flags ,
259+ self ._max_user_program_size ,
260+ )
261+
262+ async def send_command (self , command : bytes ) -> None :
263+ self ._response = None
264+ self ._response_event .clear ()
265+ self ._ep_out .write (self ._USB_PYBRICKS_MSG_COMMAND + command )
266+ try :
267+ await asyncio .wait_for (self ._response_event .wait (), 1 )
268+ if self ._response != 0 :
269+ print (f"Received error response for command: { self ._response } " )
270+ except asyncio .TimeoutError :
271+ print ("Timed out waiting for a response" )
272+
273+ async def set_service_handler (self , callback : Callable ) -> None :
274+ self ._notify_callbacks [self ._USB_PYBRICKS_MSG_EVENT [0 ]] = callback
275+
276+ async def _monitor_usb (self ):
277+ loop = asyncio .get_running_loop ()
278+
279+ while True :
280+ msg = await loop .run_in_executor (None , self ._read_usb )
281+
282+ if msg is None or len (msg ) == 0 :
283+ continue
284+
285+ if msg [0 ] in self ._notify_callbacks :
286+ callback = self ._notify_callbacks [msg [0 ]]
287+ callback (bytes (msg [1 :]))
288+
289+ def _read_usb (self ):
290+ try :
291+ msg = self ._ep_in .read (self ._ep_in .wMaxPacketSize )
292+ return msg
293+ except USBTimeoutError :
294+ return None
295+
296+ def _response_handler (self , data : bytes ) -> None :
297+ (self ._response ,) = struct .unpack ("<I" , data )
298+ self ._response_event .set ()
299+
300+
141301class PybricksHub :
142302 EOL = b"\r \n " # MicroPython EOL
143303
@@ -326,11 +486,12 @@ def _pybricks_service_handler(self, data: bytes) -> None:
326486 if self ._enable_line_handler :
327487 self ._handle_line_data (payload )
328488
329- async def connect (self , device : BLEDevice ):
489+ async def connect (self , device ):
330490 """Connects to a device that was discovered with :meth:`pybricksdev.ble.find_device`
491+ or :meth:`usb.core.find`
331492
332493 Args:
333- device: The device to connect to.
494+ device: The device to connect to (`BLEDevice` or `USBDevice`) .
334495
335496 Raises:
336497 BleakError: if connecting failed (or old firmware without Device
@@ -350,7 +511,12 @@ async def connect(self, device: BLEDevice):
350511 self .connection_state_observable .on_next , ConnectionState .DISCONNECTED
351512 )
352513
353- self ._transport = _BLETransport (device )
514+ if isinstance (device , BLEDevice ):
515+ self ._transport = _BLETransport (device )
516+ elif isinstance (device , USBDevice ):
517+ self ._transport = _USBTransport (device )
518+ else :
519+ raise TypeError ("Unsupported device type" )
354520
355521 def handle_disconnect ():
356522 logger .info ("Disconnected!" )
0 commit comments