88import os
99import struct
1010from typing import Awaitable , Callable , List , Optional , Tuple , TypeVar
11+ from uuid import UUID
1112
1213import reactivex .operators as op
1314import semver
1920from tqdm .auto import tqdm
2021from tqdm .contrib .logging import logging_redirect_tqdm
2122
23+ from usb .control import get_descriptor
24+ from usb .core import Device as USBDevice
25+ from usb .core import Endpoint , USBTimeoutError
26+ from usb .util import ENDPOINT_IN , ENDPOINT_OUT , endpoint_direction , find_descriptor
27+
2228from ..ble .lwp3 .bytecodes import HubKind
2329from ..ble .nus import NUS_RX_UUID , NUS_TX_UUID
2430from ..ble .pybricks import (
31+ DEVICE_NAME_UUID ,
2532 FW_REV_UUID ,
2633 PNP_ID_UUID ,
2734 PYBRICKS_COMMAND_EVENT_UUID ,
@@ -141,6 +148,136 @@ def handler(_, data):
141148 await self ._client .start_notify (PYBRICKS_COMMAND_EVENT_UUID , handler )
142149
143150
151+ class _USBTransport (_Transport ):
152+ _device : USBDevice
153+ _disconnected_callback : Callable
154+ _ep_in : Endpoint
155+ _ep_out : Endpoint
156+ _notify_callbacks = {}
157+ _monitor_task : asyncio .Task
158+
159+ def __init__ (self , device : USBDevice ):
160+ self ._device = device
161+
162+ async def connect (self , disconnected_callback : Callable ) -> None :
163+ self ._disconnected_callback = disconnected_callback
164+ self ._device .set_configuration ()
165+
166+ # Save input and output endpoints
167+ cfg = self ._device .get_active_configuration ()
168+ intf = cfg [(0 , 0 )]
169+ self ._ep_in = find_descriptor (
170+ intf ,
171+ custom_match = lambda e : endpoint_direction (e .bEndpointAddress )
172+ == ENDPOINT_IN ,
173+ )
174+ self ._ep_out = find_descriptor (
175+ intf ,
176+ custom_match = lambda e : endpoint_direction (e .bEndpointAddress )
177+ == ENDPOINT_OUT ,
178+ )
179+
180+ # Get length of BOS descriptor
181+ bos_descriptor = get_descriptor (self ._device , 5 , 0x0F , 0 )
182+ (ofst , bos_len ) = struct .unpack ("<BxHx" , bos_descriptor )
183+
184+ # Get full BOS descriptor
185+ bos_descriptor = get_descriptor (self ._device , bos_len , 0x0F , 0 )
186+
187+ while ofst < bos_len :
188+ (len , desc_type , cap_type ) = struct .unpack_from (
189+ "<BBB" , bos_descriptor , offset = ofst
190+ )
191+
192+ if desc_type != 0x10 :
193+ raise Exception ("Expected Device Capability descriptor" )
194+
195+ # Look for platform descriptors
196+ if cap_type == 0x05 :
197+ uuid_bytes = bos_descriptor [ofst + 4 : ofst + 4 + 16 ]
198+ uuid_str = str (UUID (bytes_le = bytes (uuid_bytes )))
199+
200+ if uuid_str == DEVICE_NAME_UUID :
201+ self ._device_name = bytearray (
202+ bos_descriptor [ofst + 20 : ofst + len ]
203+ ).decode ()
204+ print ("Connected to hub '" + self ._device_name + "'" )
205+
206+ elif uuid_str == FW_REV_UUID :
207+ fw_version = bytearray (bos_descriptor [ofst + 20 : ofst + len ])
208+ self ._fw_version = Version (fw_version .decode ())
209+
210+ elif uuid_str == SW_REV_UUID :
211+ protocol_version = bytearray (bos_descriptor [ofst + 20 : ofst + len ])
212+ self ._protocol_version = semver .VersionInfo .parse (
213+ protocol_version .decode ()
214+ )
215+
216+ elif uuid_str == PYBRICKS_HUB_CAPABILITIES_UUID :
217+ caps = bytearray (bos_descriptor [ofst + 20 : ofst + len ])
218+ (
219+ self ._max_write_size ,
220+ self ._capability_flags ,
221+ self ._max_user_program_size ,
222+ ) = unpack_hub_capabilities (caps )
223+
224+ ofst += len
225+
226+ self ._monitor_task = asyncio .create_task (self ._monitor_usb ())
227+
228+ async def disconnect (self ) -> None :
229+ self ._monitor_task .cancel ()
230+ self ._disconnected_callback ()
231+
232+ async def get_firmware_version (self ) -> Version :
233+ return self ._fw_version
234+
235+ async def get_protocol_version (self ) -> Version :
236+ return self ._protocol_version
237+
238+ async def get_hub_type (self ) -> Tuple [HubKind , int ]:
239+ return (HubKind .TECHNIC_LARGE , 0 )
240+
241+ async def get_hub_capabilities (self ) -> Tuple [int , HubCapabilityFlag , int ]:
242+ return (
243+ self ._max_write_size ,
244+ self ._capability_flags ,
245+ self ._max_user_program_size ,
246+ )
247+
248+ async def send_command (self , command : bytes ) -> None :
249+ self ._ep_out .write (UUID (PYBRICKS_COMMAND_EVENT_UUID ).bytes_le + command )
250+
251+ async def set_nus_handler (self , callback : Callable ) -> None :
252+ pass
253+
254+ async def set_service_handler (self , callback : Callable ) -> None :
255+ self ._notify_callbacks [PYBRICKS_COMMAND_EVENT_UUID ] = callback
256+
257+ async def _monitor_usb (self ):
258+ loop = asyncio .get_running_loop ()
259+
260+ while True :
261+ msg = await loop .run_in_executor (None , self ._read_usb )
262+
263+ if msg is None :
264+ continue
265+
266+ if len (msg ) > 16 :
267+ uuid = str (UUID (bytes_le = bytes (msg [:16 ])))
268+ if uuid in self ._notify_callbacks :
269+ callback = self ._notify_callbacks [uuid ]
270+ if callback :
271+ callback (bytes (msg [16 :]))
272+
273+ def _read_usb (self ):
274+ try :
275+ msg = self ._ep_in .read (self ._ep_in .wMaxPacketSize )
276+ return msg
277+ except USBTimeoutError :
278+ return None
279+
280+
144281class PybricksHub :
145282 EOL = b"\r \n " # MicroPython EOL
146283
@@ -329,7 +466,7 @@ def _pybricks_service_handler(self, data: bytes) -> None:
329466 if self ._enable_line_handler :
330467 self ._handle_line_data (payload )
331468
332- async def connect (self , device : BLEDevice ):
469+ async def connect (self , device ):
333470 if self .connection_state_observable .value != ConnectionState .DISCONNECTED :
334471 raise RuntimeError (
335472 f"attempting to connect with invalid state: { self .connection_state_observable .value } "
@@ -342,7 +479,12 @@ async def connect(self, device: BLEDevice):
342479 self .connection_state_observable .on_next , ConnectionState .DISCONNECTED
343480 )
344481
345- self ._transport = _BLETransport (device )
482+ if isinstance (device , BLEDevice ):
483+ self ._transport = _BLETransport (device )
484+ elif isinstance (device , USBDevice ):
485+ self ._transport = _USBTransport (device )
486+ else :
487+ raise TypeError ("Unsupported device type" )
346488
347489 def handle_disconnect ():
348490 logger .info ("Disconnected!" )
0 commit comments