Skip to content

Commit 11637be

Browse files
committed
Address a potential issue in the PybricksHubUSB.write_gatt_char method. Previously, there was a concern that this part of the code could get stuck if the USB hub disconnected or didn't send a response.
Here's how I've addressed it: - I've added a 5-second timeout for waiting for a response from the hub. - I'm now also monitoring for a hub disconnection while waiting for the response. If the hub disconnects, a `RuntimeError` will occur. If the operation times out, an `asyncio.TimeoutError` will occur. I've also included some checks in `tests/connections/test_pybricks.py` to ensure this new behavior works as expected in both disconnection and timeout situations.
1 parent e0c8dad commit 11637be

File tree

2 files changed

+103
-5
lines changed

2 files changed

+103
-5
lines changed

pybricksdev/connections/pybricks.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -837,10 +837,17 @@ async def write_gatt_char(self, uuid: str, data, response: bool) -> None:
837837
raise ValueError("Response is required for USB")
838838

839839
self._ep_out.write(bytes([PybricksUsbOutEpMessageType.COMMAND]) + data)
840-
# FIXME: This needs to race with hub disconnect, and could also use a
841-
# timeout, otherwise it blocks forever. Pyusb doesn't currently seem to
842-
# have any disconnect callback.
843-
reply = await self._response_queue.get()
840+
841+
try:
842+
reply = await asyncio.wait_for(
843+
self.race_disconnect(self._response_queue.get()),
844+
timeout=5.0, # 5-second timeout
845+
)
846+
except asyncio.TimeoutError:
847+
# Handle timeout specifically if needed, or let race_disconnect handle it
848+
# For now, let's make it explicit
849+
raise asyncio.TimeoutError("Timeout waiting for USB response")
850+
# race_disconnect will raise RuntimeError if disconnected
844851

845852
# REVISIT: could look up status error code and convert to string,
846853
# although BLE doesn't do that either.

tests/connections/test_pybricks.py

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,21 @@
44
import contextlib
55
import os
66
import tempfile
7-
from unittest.mock import AsyncMock, PropertyMock, patch
7+
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
88

99
import pytest
1010
from reactivex.subject import Subject
1111

12+
from pybricksdev.ble.pybricks import PYBRICKS_COMMAND_EVENT_UUID
1213
from pybricksdev.connections.pybricks import (
1314
ConnectionState,
1415
HubCapabilityFlag,
1516
HubKind,
1617
PybricksHubBLE,
18+
PybricksHubUSB,
1719
StatusFlag,
1820
)
21+
from pybricksdev.usb.pybricks import PybricksUsbOutEpMessageType
1922

2023

2124
class TestPybricksHub:
@@ -180,3 +183,91 @@ async def test_run_modern_protocol(self):
180183
# Verify the expected calls were made
181184
hub.download_user_program.assert_called_once()
182185
hub.start_user_program.assert_called_once()
186+
187+
188+
class TestPybricksHubUSB:
189+
"""Tests for the PybricksHubUSB class functionality."""
190+
191+
@pytest.mark.asyncio
192+
async def test_pybricks_hub_usb_write_gatt_char_disconnect(self):
193+
"""Test write_gatt_char when a disconnect event occurs."""
194+
hub = PybricksHubUSB(MagicMock())
195+
196+
hub._ep_out = MagicMock()
197+
# Simulate _response_queue.get() blocking indefinitely
198+
hub._response_queue = AsyncMock()
199+
hub._response_queue.get = AsyncMock(side_effect=asyncio.Event().wait)
200+
201+
mock_observable = MagicMock(
202+
spec=Subject
203+
) # Using Subject as a base for mock spec
204+
disconnect_callback_handler = None
205+
206+
def mock_subscribe_side_effect(on_next_callback, *args, **kwargs):
207+
nonlocal disconnect_callback_handler
208+
disconnect_callback_handler = on_next_callback
209+
mock_subscription = MagicMock()
210+
mock_subscription.dispose = MagicMock()
211+
return mock_subscription
212+
213+
mock_observable.subscribe = MagicMock(side_effect=mock_subscribe_side_effect)
214+
type(hub.connection_state_observable).value = PropertyMock(
215+
return_value=ConnectionState.CONNECTED
216+
)
217+
hub.connection_state_observable = mock_observable
218+
219+
async def trigger_disconnect_event():
220+
await asyncio.sleep(0.05)
221+
assert (
222+
disconnect_callback_handler is not None
223+
), "Subscribe was not called by race_disconnect"
224+
disconnect_callback_handler(ConnectionState.DISCONNECTED)
225+
226+
with pytest.raises(RuntimeError, match="disconnected during operation"):
227+
await asyncio.gather(
228+
hub.write_gatt_char(PYBRICKS_COMMAND_EVENT_UUID, b"test_data", True),
229+
trigger_disconnect_event(),
230+
)
231+
232+
hub._ep_out.write.assert_called_once_with(
233+
bytes([PybricksUsbOutEpMessageType.COMMAND]) + b"test_data"
234+
)
235+
236+
@pytest.mark.asyncio
237+
async def test_pybricks_hub_usb_write_gatt_char_timeout(self):
238+
"""Test write_gatt_char when a timeout occurs."""
239+
hub = PybricksHubUSB(MagicMock())
240+
241+
hub._ep_out = MagicMock()
242+
hub._response_queue = AsyncMock()
243+
# Make _response_queue.get() block indefinitely
244+
hub._response_queue.get = AsyncMock(side_effect=asyncio.Event().wait)
245+
246+
mock_observable = MagicMock(spec=Subject)
247+
248+
def mock_subscribe_side_effect(on_next_callback, *args, **kwargs):
249+
mock_subscription = MagicMock()
250+
mock_subscription.dispose = MagicMock()
251+
return mock_subscription
252+
253+
mock_observable.subscribe = MagicMock(side_effect=mock_subscribe_side_effect)
254+
type(hub.connection_state_observable).value = PropertyMock(
255+
return_value=ConnectionState.CONNECTED
256+
)
257+
hub.connection_state_observable = mock_observable
258+
259+
# The method has a hardcoded timeout of 5.0s.
260+
# We can patch asyncio.wait_for to speed up the test.
261+
with patch(
262+
"asyncio.wait_for", side_effect=asyncio.TimeoutError("Test-induced timeout")
263+
):
264+
with pytest.raises(
265+
asyncio.TimeoutError, match="Timeout waiting for USB response"
266+
):
267+
await hub.write_gatt_char(
268+
PYBRICKS_COMMAND_EVENT_UUID, b"test_data", True
269+
)
270+
271+
hub._ep_out.write.assert_called_once_with(
272+
bytes([PybricksUsbOutEpMessageType.COMMAND]) + b"test_data"
273+
)

0 commit comments

Comments
 (0)