77import logging
88import os
99import struct
10- from typing import Awaitable , Callable , List , Optional , Tuple , TypeVar
10+ from typing import Awaitable , Callable , List , Optional , Tuple , TypeVar , Union
11+ from uuid import UUID
1112
1213import reactivex .operators as op
1314import semver
1920from tqdm .auto import tqdm
2021from tqdm .contrib .logging import logging_redirect_tqdm
2122
23+ from usb .control import get_descriptor
24+ from usb .core import Device as USBDevice
25+ from usb .core import Endpoint , USBTimeoutError
26+ from usb .util import ENDPOINT_IN , ENDPOINT_OUT , endpoint_direction , find_descriptor
27+
2228from ..ble .lwp3 .bytecodes import HubKind
2329from ..ble .nus import NUS_RX_UUID , NUS_TX_UUID
2430from ..ble .pybricks import (
31+ DEVICE_NAME_UUID ,
2532 FW_REV_UUID ,
2633 PNP_ID_UUID ,
2734 PYBRICKS_COMMAND_EVENT_UUID ,
3845from ..compile import compile_file , compile_multi_file
3946from ..tools import chunk
4047from ..tools .checksum import xor_bytes
48+ from ..usb import LegoUsbMsg , LegoUsbPid
4149from . import ConnectionState
4250
4351logger = logging .getLogger (__name__ )
@@ -138,6 +146,152 @@ def handler(_, data):
138146 await self ._client .start_notify (NUS_TX_UUID , handler )
139147
140148
149+ class _USBTransport (_Transport ):
150+ _device : USBDevice
151+ _disconnected_callback : Callable
152+ _ep_in : Endpoint
153+ _ep_out : Endpoint
154+ _notify_callbacks = {}
155+ _monitor_task : asyncio .Task
156+ _response : asyncio .Future
157+
158+ def __init__ (self , device : USBDevice ):
159+ self ._device = device
160+ self ._notify_callbacks [
161+ LegoUsbMsg .USB_PYBRICKS_MSG_COMMAND_RESPONSE
162+ ] = self ._response_handler
163+
164+ async def connect (self , disconnected_callback : Callable ) -> None :
165+ self ._disconnected_callback = disconnected_callback
166+ self ._device .set_configuration ()
167+
168+ # Save input and output endpoints
169+ cfg = self ._device .get_active_configuration ()
170+ intf = cfg [(0 , 0 )]
171+ self ._ep_in = find_descriptor (
172+ intf ,
173+ custom_match = lambda e : endpoint_direction (e .bEndpointAddress )
174+ == ENDPOINT_IN ,
175+ )
176+ self ._ep_out = find_descriptor (
177+ intf ,
178+ custom_match = lambda e : endpoint_direction (e .bEndpointAddress )
179+ == ENDPOINT_OUT ,
180+ )
181+
182+ # Get length of BOS descriptor
183+ bos_descriptor = get_descriptor (self ._device , 5 , 0x0F , 0 )
184+ (ofst , bos_len ) = struct .unpack ("<BxHx" , bos_descriptor )
185+
186+ # Get full BOS descriptor
187+ bos_descriptor = get_descriptor (self ._device , bos_len , 0x0F , 0 )
188+
189+ while ofst < bos_len :
190+ (len , desc_type , cap_type ) = struct .unpack_from (
191+ "<BBB" , bos_descriptor , offset = ofst
192+ )
193+
194+ if desc_type != 0x10 :
195+ raise Exception ("Expected Device Capability descriptor" )
196+
197+ # Look for platform descriptors
198+ if cap_type == 0x05 :
199+ uuid_bytes = bos_descriptor [ofst + 4 : ofst + 4 + 16 ]
200+ uuid_str = str (UUID (bytes_le = bytes (uuid_bytes )))
201+
202+ if uuid_str == DEVICE_NAME_UUID :
203+ self ._device_name = bytes (
204+ bos_descriptor [ofst + 20 : ofst + len ]
205+ ).decode ()
206+ print ("Connected to hub '" + self ._device_name + "'" )
207+
208+ elif uuid_str == FW_REV_UUID :
209+ fw_version = bytes (bos_descriptor [ofst + 20 : ofst + len ])
210+ self ._fw_version = Version (fw_version .decode ())
211+
212+ elif uuid_str == SW_REV_UUID :
213+ protocol_version = bytes (bos_descriptor [ofst + 20 : ofst + len ])
214+ self ._protocol_version = semver .VersionInfo .parse (
215+ protocol_version .decode ()
216+ )
217+
218+ elif uuid_str == PYBRICKS_HUB_CAPABILITIES_UUID :
219+ caps = bytes (bos_descriptor [ofst + 20 : ofst + len ])
220+ (
221+ self ._max_write_size ,
222+ self ._capability_flags ,
223+ self ._max_user_program_size ,
224+ ) = unpack_hub_capabilities (caps )
225+
226+ ofst += len
227+
228+ self ._monitor_task = asyncio .create_task (self ._monitor_usb ())
229+
230+ async def disconnect (self ) -> None :
231+ # FIXME: Need to make sure this is called when the USB cable is unplugged
232+ self ._monitor_task .cancel ()
233+ self ._disconnected_callback ()
234+
235+ async def get_firmware_version (self ) -> Version :
236+ return self ._fw_version
237+
238+ async def get_protocol_version (self ) -> Version :
239+ return self ._protocol_version
240+
241+ async def get_hub_type (self ) -> Tuple [HubKind , int ]:
242+ hub_types = {
243+ LegoUsbPid .SPIKE_PRIME : (HubKind .TECHNIC_LARGE , 0 ),
244+ LegoUsbPid .ROBOT_INVENTOR : (HubKind .TECHNIC_LARGE , 1 ),
245+ LegoUsbPid .SPIKE_ESSENTIAL : (HubKind .TECHNIC_SMALL , 0 ),
246+ }
247+
248+ return hub_types [self ._device .idProduct ]
249+
250+ async def get_hub_capabilities (self ) -> Tuple [int , HubCapabilityFlag , int ]:
251+ return (
252+ self ._max_write_size ,
253+ self ._capability_flags ,
254+ self ._max_user_program_size ,
255+ )
256+
257+ async def send_command (self , command : bytes ) -> None :
258+ self ._response = asyncio .Future ()
259+ self ._ep_out .write (
260+ struct .pack ("B" , LegoUsbMsg .USB_PYBRICKS_MSG_COMMAND ) + command
261+ )
262+ try :
263+ await asyncio .wait_for (self ._response , 1 )
264+ if self ._response .result () != 0 :
265+ print (f"Received error response for command: { self ._response .result ()} " )
266+ except asyncio .TimeoutError :
267+ print ("Timed out waiting for a response" )
268+
269+ async def set_service_handler (self , callback : Callable ) -> None :
270+ self ._notify_callbacks [LegoUsbMsg .USB_PYBRICKS_MSG_EVENT ] = callback
271+
272+ async def _monitor_usb (self ):
273+ loop = asyncio .get_running_loop ()
274+
275+ while True :
276+ msg = await loop .run_in_executor (None , self ._read_usb )
277+
278+ if msg is None or len (msg ) == 0 :
279+ continue
280+
281+ callback = self ._notify_callbacks .get (msg [0 ])
282+ if callback is not None :
283+ callback (bytes (msg [1 :]))
284+
285+ def _read_usb (self ):
286+ with contextlib .suppress (USBTimeoutError ):
287+ msg = self ._ep_in .read (self ._ep_in .wMaxPacketSize )
288+ return msg
289+
290+ def _response_handler (self , data : bytes ) -> None :
291+ (response ,) = struct .unpack ("<I" , data )
292+ self ._response .set_result (response )
293+
294+
141295class PybricksHub :
142296 EOL = b"\r \n " # MicroPython EOL
143297
@@ -326,11 +480,12 @@ def _pybricks_service_handler(self, data: bytes) -> None:
326480 if self ._enable_line_handler :
327481 self ._handle_line_data (payload )
328482
329- async def connect (self , device : BLEDevice ):
483+ async def connect (self , device : Union [ BLEDevice , USBDevice ] ):
330484 """Connects to a device that was discovered with :meth:`pybricksdev.ble.find_device`
485+ or :meth:`usb.core.find`
331486
332487 Args:
333- device: The device to connect to.
488+ device: The device to connect to (`BLEDevice` or `USBDevice`) .
334489
335490 Raises:
336491 BleakError: if connecting failed (or old firmware without Device
@@ -350,7 +505,12 @@ async def connect(self, device: BLEDevice):
350505 self .connection_state_observable .on_next , ConnectionState .DISCONNECTED
351506 )
352507
353- self ._transport = _BLETransport (device )
508+ if isinstance (device , BLEDevice ):
509+ self ._transport = _BLETransport (device )
510+ elif isinstance (device , USBDevice ):
511+ self ._transport = _USBTransport (device )
512+ else :
513+ raise TypeError ("Unsupported device type" )
354514
355515 def handle_disconnect ():
356516 logger .info ("Disconnected!" )
0 commit comments