Skip to content

Commit 2b2ace0

Browse files
authored
Merge pull request #1 from shaggysa/stay-connected
Fixes to the stay-connected feature
2 parents f76600c + 888aa70 commit 2b2ace0

File tree

2 files changed

+147
-21
lines changed

2 files changed

+147
-21
lines changed

pybricksdev/cli/__init__.py

Lines changed: 91 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020

2121
from pybricksdev import __name__ as MODULE_NAME
2222
from pybricksdev import __version__ as MODULE_VERSION
23+
from pybricksdev.ble.pybricks import StatusFlag
24+
from pybricksdev.connections import ConnectionState
2325

2426
PROG_NAME = (
2527
f"{path.basename(sys.executable)} -m {MODULE_NAME}"
@@ -221,25 +223,95 @@ def is_pybricks_usb(dev):
221223
# Connect to the address and run the script
222224
await hub.connect()
223225
try:
224-
while True:
225-
with _get_script_path(args.file) as script_path:
226-
if args.start:
227-
await hub.run(script_path, args.wait)
228-
else:
229-
await hub.download(script_path)
230-
231-
if not args.wait or not args.stay_connected:
232-
break
233-
234-
resend = await questionary.select(
235-
"Would you like to resend your code?", choices=["Resend", "Exit"]
236-
).ask_async()
237-
238-
if resend == "Exit":
239-
break
240-
241-
except RuntimeError:
242-
print("The hub is no longer connected.")
226+
with _get_script_path(args.file) as script_path:
227+
if args.start:
228+
await hub.run(script_path, args.wait or args.stay_connected)
229+
else:
230+
if args.stay_connected:
231+
hub.print_output = True
232+
hub._enable_line_handler = True
233+
await hub.download(script_path)
234+
235+
if args.stay_connected:
236+
response_options = [
237+
"Recompile and Run",
238+
"Recompile and Download",
239+
"Exit",
240+
]
241+
while True:
242+
try:
243+
response = await hub.race_power_button_press(
244+
questionary.select(
245+
"Would you like to re-compile your code?",
246+
response_options,
247+
).ask_async()
248+
)
249+
except RuntimeError as e:
250+
251+
async def reconnect_hub():
252+
if await questionary.confirm(
253+
"\nThe hub has been disconnected. Would you like to re-connect?"
254+
).ask_async():
255+
if args.conntype == "ble":
256+
print(
257+
f"Searching for {args.name or 'any hub with Pybricks service'}..."
258+
)
259+
device_or_address = await find_ble(args.name)
260+
hub = PybricksHubBLE(device_or_address)
261+
elif args.conntype == "usb":
262+
device_or_address = find_usb(
263+
custom_match=is_pybricks_usb
264+
)
265+
hub = PybricksHubUSB(device_or_address)
266+
267+
await hub.connect()
268+
# re-enable echoing of the hub's stdout
269+
hub.print_output = True
270+
hub._enable_line_handler = True
271+
return hub
272+
273+
else:
274+
exit()
275+
276+
if (
277+
hub.status_observable.value
278+
& StatusFlag.POWER_BUTTON_PRESSED
279+
):
280+
try:
281+
await hub._wait_for_user_program_stop(2.1)
282+
continue
283+
284+
except RuntimeError as e:
285+
if (
286+
hub.connection_state_observable.value
287+
== ConnectionState.DISCONNECTED
288+
):
289+
hub = await reconnect_hub()
290+
continue
291+
292+
else:
293+
raise e
294+
295+
elif (
296+
hub.connection_state_observable.value
297+
== ConnectionState.DISCONNECTED
298+
):
299+
# let terminal cool off before making a new prompt
300+
await asyncio.sleep(0.3)
301+
302+
hub = await reconnect_hub()
303+
continue
304+
305+
else:
306+
raise e
307+
308+
with _get_script_path(args.file) as script_path:
309+
if response == response_options[0]:
310+
await hub.run(script_path, True)
311+
elif response == response_options[1]:
312+
await hub.download(script_path)
313+
else:
314+
exit(1)
243315

244316
finally:
245317
await hub.disconnect()

pybricksdev/connections/pybricks.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,60 @@ async def send_block(data: bytes) -> None:
678678
if wait:
679679
await self._wait_for_user_program_stop()
680680

681-
async def _wait_for_user_program_stop(self):
681+
async def race_power_button_press(self, awaitable: Awaitable[T]) -> T:
682+
"""
683+
Races an awaitable against the user pressing the power button of the hub.
684+
685+
If the power button is pressed or the hub becomes disconnected before the awaitable is complete, a
686+
``RuntimeError`` is raised and the awaitable is canceled.
687+
688+
Otherwise, the result of the awaitable is returned. If the awaitable
689+
raises an exception, that exception will be raised.
690+
691+
Args:
692+
awaitable: Any awaitable such as a coroutine.
693+
694+
Returns:
695+
The result of the awaitable.
696+
697+
Raises:
698+
RuntimeError:
699+
Thrown if the hub's power button is pressed or the hub is disconnected.
700+
"""
701+
awaitable_task = asyncio.ensure_future(awaitable)
702+
703+
power_button_press_event = asyncio.Event()
704+
power_button_press_task = asyncio.ensure_future(power_button_press_event.wait())
705+
706+
disconnect_event = asyncio.Event()
707+
disconnect_task = asyncio.ensure_future(disconnect_event.wait())
708+
709+
def handle_disconnect(state: ConnectionState):
710+
if state == ConnectionState.DISCONNECTED:
711+
disconnect_event.set()
712+
713+
def handle_power_button_press(status: StatusFlag):
714+
if status.value & StatusFlag.POWER_BUTTON_PRESSED:
715+
power_button_press_event.set()
716+
717+
with self.status_observable.subscribe(
718+
handle_power_button_press
719+
) and self.connection_state_observable.subscribe(handle_disconnect):
720+
done, pending = await asyncio.wait(
721+
{awaitable_task, power_button_press_task, disconnect_task},
722+
return_when=asyncio.FIRST_COMPLETED,
723+
)
724+
725+
for t in pending:
726+
t.cancel()
727+
728+
if power_button_press_task in done:
729+
raise RuntimeError("the hub's power button was pressed during operation")
730+
elif disconnect_task in done:
731+
raise RuntimeError("the hub was disconnected during operation")
732+
return awaitable_task.result()
733+
734+
async def _wait_for_user_program_stop(self, program_start_timeout=1):
682735
user_program_running: asyncio.Queue[bool] = asyncio.Queue()
683736

684737
with self.status_observable.pipe(
@@ -695,7 +748,8 @@ async def _wait_for_user_program_stop(self):
695748
# for it to start
696749
try:
697750
await asyncio.wait_for(
698-
self.race_disconnect(user_program_running.get()), 1
751+
self.race_disconnect(user_program_running.get()),
752+
program_start_timeout,
699753
)
700754
except asyncio.TimeoutError:
701755
# if it doesn't start, assume it was a very short lived

0 commit comments

Comments
 (0)