diff --git a/src/uipath_mcp/_cli/_runtime/_runtime.py b/src/uipath_mcp/_cli/_runtime/_runtime.py index 24ab8de..f03111b 100644 --- a/src/uipath_mcp/_cli/_runtime/_runtime.py +++ b/src/uipath_mcp/_cli/_runtime/_runtime.py @@ -98,30 +98,39 @@ async def execute(self) -> Optional[UiPathRuntimeResult]: await self._register() run_task = asyncio.create_task(self._signalr_client.run()) - - # Set up a task to wait for cancellation cancel_task = asyncio.create_task(self._cancel_event.wait()) - self._keep_alive_task = asyncio.create_task(self._keep_alive()) - # Keep the runtime alive - # Wait for either the run to complete or cancellation - done, pending = await asyncio.wait( - [run_task, cancel_task], return_when=asyncio.FIRST_COMPLETED - ) - - # Cancel any pending tasks - for task in pending: - task.cancel() + try: + # Wait for either the run to complete or cancellation + done, pending = await asyncio.wait( + [run_task, cancel_task], return_when=asyncio.FIRST_COMPLETED + ) + except KeyboardInterrupt: + logger.info( + "Received keyboard interrupt, shutting down gracefully..." + ) + self._cancel_event.set() + finally: + # Cancel any pending tasks gracefully + for task in [run_task, cancel_task, self._keep_alive_task]: + if task and not task.done(): + task.cancel() + try: + await asyncio.wait_for(task, timeout=2.0) + except (asyncio.CancelledError, asyncio.TimeoutError): + pass output_result = {} if self._session_output: output_result["content"] = self._session_output self.context.result = UiPathRuntimeResult(output=output_result) - return self.context.result + except KeyboardInterrupt: + logger.info("Keyboard interrupt received") + return None except Exception as e: if isinstance(e, UiPathMcpRuntimeError): raise @@ -133,7 +142,9 @@ async def execute(self) -> Optional[UiPathRuntimeResult]: UiPathErrorCategory.USER, ) from e finally: - self.trace_provider.shutdown() + await self.cleanup() + if hasattr(self, "trace_provider") and self.trace_provider: + self.trace_provider.shutdown() async def validate(self) -> None: """Validate runtime inputs and load MCP server configuration.""" @@ -442,36 +453,49 @@ async def _keep_alive(self) -> None: """ Heartbeat to keep the runtime available. """ - while not self._cancel_event.is_set(): - try: + try: + while not self._cancel_event.is_set(): + try: + + async def on_keep_alive_response( + response: CompletionMessage, + ) -> None: + if response.error: + logger.error(f"Error during keep-alive: {response.error}") + return + session_ids = response.result + logger.info(f"Active sessions: {session_ids}") + # If there are no active sessions and this is a sandbox environment + # We need to cancel the runtime + # eg: when user kills the agent that triggered the runtime, before we subscribe to events + if ( + not session_ids + and self.sandboxed + and not self._cancel_event.is_set() + ): + logger.error( + "No active sessions, cancelling sandboxed runtime..." + ) + self._cancel_event.set() - async def on_keep_alive_response(response: CompletionMessage) -> None: - if response.error: - logger.error(f"Error during keep-alive: {response.error}") - return - session_ids = response.result - logger.info(f"Active sessions: {session_ids}") - # If there are no active sessions and this is a sandbox environment - # We need to cancel the runtime - # eg: when user kills the agent that triggered the runtime, before we subscribe to events - if ( - not session_ids - and self.sandboxed - and not self._cancel_event.is_set() - ): - logger.error( - "No active sessions, cancelling sandboxed runtime..." + if self._signalr_client: + await self._signalr_client.send( + method="OnKeepAlive", + arguments=[], + on_invocation=on_keep_alive_response, ) - self._cancel_event.set() + except Exception as e: + if not self._cancel_event.is_set(): + logger.error(f"Error during keep-alive: {e}") - await self._signalr_client.send( - method="OnKeepAlive", - arguments=[], - on_invocation=on_keep_alive_response, - ) - except Exception as e: - logger.error(f"Error during keep-alive: {e}") - await asyncio.sleep(60) + try: + await asyncio.wait_for(self._cancel_event.wait(), timeout=60) + break + except asyncio.TimeoutError: + continue + except asyncio.CancelledError: + logger.info("Keep-alive task cancelled") + raise async def _on_runtime_abort(self) -> None: """