@@ -717,24 +717,20 @@ async def race_power_button_press(self, awaitable: Awaitable[T]) -> T:
717717 power_button_press_event = asyncio .Event ()
718718 power_button_press_task = asyncio .ensure_future (power_button_press_event .wait ())
719719
720- disconnect_event = asyncio .Event ()
721- disconnect_task = asyncio .ensure_future (disconnect_event .wait ())
722-
723- def handle_disconnect (state : ConnectionState ):
724- if state == ConnectionState .DISCONNECTED :
725- disconnect_event .set ()
726-
727720 def handle_power_button_press (status : StatusFlag ):
728721 if status .value & StatusFlag .POWER_BUTTON_PRESSED :
729722 power_button_press_event .set ()
730723
731- with self .status_observable .subscribe (
732- handle_power_button_press
733- ), self .connection_state_observable .subscribe (handle_disconnect ):
734- done , pending = await asyncio .wait (
735- {awaitable_task , power_button_press_task , disconnect_task },
736- return_when = asyncio .FIRST_COMPLETED ,
737- )
724+ with self .status_observable .subscribe (handle_power_button_press ):
725+ try :
726+ done , pending = await asyncio .wait (
727+ {awaitable_task , power_button_press_task },
728+ return_when = asyncio .FIRST_COMPLETED ,
729+ )
730+ except BaseException :
731+ awaitable_task .cancel ()
732+ power_button_press_task .cancel ()
733+ raise
738734
739735 for t in pending :
740736 t .cancel ()
@@ -743,11 +739,11 @@ def handle_power_button_press(status: StatusFlag):
743739 raise HubPowerButtonPressedError (
744740 "the hub's power button was pressed during operation"
745741 )
746- if disconnect_task in done :
747- raise HubDisconnectError ("the hub was disconnected during operation" )
748742 return awaitable_task .result ()
749743
750- async def _wait_for_user_program_stop (self , program_start_timeout = 1 ):
744+ async def _wait_for_user_program_stop (
745+ self , program_start_timeout = 1 , raise_error_on_timeout = False
746+ ):
751747 user_program_running : asyncio .Queue [bool ] = asyncio .Queue ()
752748
753749 with self .status_observable .pipe (
@@ -768,6 +764,8 @@ async def _wait_for_user_program_stop(self, program_start_timeout=1):
768764 program_start_timeout ,
769765 )
770766 except asyncio .TimeoutError :
767+ if raise_error_on_timeout :
768+ raise
771769 # if it doesn't start, assume it was a very short lived
772770 # program and we just missed the status message
773771 logger .debug (
0 commit comments