diff --git a/pyproject.toml b/pyproject.toml
index 0370771ff7..56b497e47f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,6 +28,7 @@ dependencies = [
     'docstring-parser',
     'get-annotations~=0.1;python_version<"3.10"',
     'graphviz~=0.19',
+    'plumpy',
     'ipython>=7',
     'jedi<0.19',
     'jinja2~=3.0',
@@ -35,7 +36,6 @@ dependencies = [
     'importlib-metadata~=6.0',
     'numpy~=1.21',
     'paramiko~=3.0',
-    'plumpy~=0.24.0',
     'pgsu~=0.3.0',
     'psutil~=5.6',
     'psycopg[binary]~=3.0',
@@ -522,3 +522,6 @@ commands = molecule {posargs:test}
 # .github/actions/install-aiida-core/action.yml
 # .readthedocs.yml
 required-version = ">=0.5.21"
+
+[tool.uv.sources]
+plumpy = {git = "https://github.com/aiidateam/plumpy.git", branch = "force-kill-v2"}
diff --git a/src/aiida/cmdline/commands/cmd_process.py b/src/aiida/cmdline/commands/cmd_process.py
index 5ad7c5d53c..f3710389f9 100644
--- a/src/aiida/cmdline/commands/cmd_process.py
+++ b/src/aiida/cmdline/commands/cmd_process.py
@@ -320,8 +320,12 @@ def process_status(call_link_label, most_recent_node, max_depth, processes):
 @options.ALL(help='Kill all processes if no specific processes are specified.')
 @options.TIMEOUT()
 @options.WAIT()
+@options.FORCE_KILL(
+    help='Force kill the process if it does not respond to the initial kill signal.\n'
+    ' Note: This may lead to orphaned jobs on your HPC and should be used with caution.'
+)
 @decorators.with_dbenv()
-def process_kill(processes, all_entries, timeout, wait):
+def process_kill(processes, all_entries, timeout, wait, force_kill):
     """Kill running processes.

     Kill one or multiple running processes."""
@@ -338,11 +342,17 @@ def process_kill(processes, all_entries, timeout, wait):
     if all_entries:
         click.confirm('Are you sure you want to kill all processes?', abort=True)

+    if force_kill:
+        echo.echo_warning('Force kill is enabled. This may lead to orphaned jobs on your HPC.')
+        msg_text = 'Force killed through `verdi process kill`'
+    else:
+        msg_text = 'Killed through `verdi process kill`'
     with capture_logging() as stream:
         try:
             control.kill_processes(
                 processes,
-                msg_text='Killed through `verdi process kill`',
+                msg_text=msg_text,
+                force_kill=force_kill,
                 all_entries=all_entries,
                 timeout=timeout,
                 wait=wait,
diff --git a/src/aiida/cmdline/params/options/main.py b/src/aiida/cmdline/params/options/main.py
index c2ce719375..ae3cd3a3a3 100644
--- a/src/aiida/cmdline/params/options/main.py
+++ b/src/aiida/cmdline/params/options/main.py
@@ -61,6 +61,7 @@
     'EXPORT_FORMAT',
     'FAILED',
     'FORCE',
+    'FORCE_KILL',
     'FORMULA_MODE',
     'FREQUENCY',
     'GROUP',
@@ -329,6 +330,14 @@ def set_log_level(ctx, _param, value):

 FORCE = OverridableOption('-f', '--force', is_flag=True, default=False, help='Do not ask for confirmation.')

+FORCE_KILL = OverridableOption(
+    '-F',
+    '--force-kill',
+    is_flag=True,
+    default=False,
+    help='Kill the process immediately, without waiting for confirmation that the job has been killed on the remote.',
+)
+
 SILENT = OverridableOption('-s', '--silent', is_flag=True, default=False, help='Suppress any output printed to stdout.')

 VISUALIZATION_FORMAT = OverridableOption(
diff --git a/src/aiida/engine/daemon/client.py b/src/aiida/engine/daemon/client.py
index ffdd83d30c..bd9ec5ecee 100644
--- a/src/aiida/engine/daemon/client.py
+++ b/src/aiida/engine/daemon/client.py
@@ -477,6 +477,10 @@ def get_worker_info(self, timeout: int | None = None) -> dict[str, t.Any]:
         command = {'command': 'stats', 'properties': {'name': self.daemon_name}}
         return self.call_client(command, timeout=timeout)

+    def get_number_of_workers(self, timeout: int | None = None) -> int:
+        """Return the number of daemon workers that are currently running."""
+        return len(self.get_worker_info(timeout).get('info', []))
+
     def get_daemon_info(self, timeout: int | None = None) -> dict[str, t.Any]:
         """Get statistics about this daemon itself.
diff --git a/src/aiida/engine/daemon/worker.py b/src/aiida/engine/daemon/worker.py
index 913e44d9b7..fccba72d63 100644
--- a/src/aiida/engine/daemon/worker.py
+++ b/src/aiida/engine/daemon/worker.py
@@ -37,13 +37,14 @@ async def shutdown_worker(runner: Runner) -> None:
     LOGGER.info('Daemon worker stopped')


-def start_daemon_worker(foreground: bool = False) -> None:
+def start_daemon_worker(foreground: bool = False, profile_name: str | None = None) -> None:
     """Start a daemon worker for the currently configured profile.

     :param foreground: If true, the logging will be configured to write to stdout, otherwise
         it will be configured to write to the daemon log file.
+    :param profile_name: The name of the profile to load; falls back to the currently configured profile.
""" - daemon_client = get_daemon_client() + + daemon_client = get_daemon_client(profile_name) configure_logging(with_orm=True, daemon=not foreground, daemon_log_file=daemon_client.daemon_log_file) LOGGER.debug(f'sys.executable: {sys.executable}') @@ -68,6 +69,7 @@ def start_daemon_worker(foreground: bool = False) -> None: try: LOGGER.info('Starting a daemon worker') + runner.start() except SystemError as exception: LOGGER.info('Received a SystemError: %s', exception) diff --git a/src/aiida/engine/processes/calcjobs/tasks.py b/src/aiida/engine/processes/calcjobs/tasks.py index 748f366a0f..0c1e99a0f7 100644 --- a/src/aiida/engine/processes/calcjobs/tasks.py +++ b/src/aiida/engine/processes/calcjobs/tasks.py @@ -27,7 +27,8 @@ from aiida.engine.daemon import execmanager from aiida.engine.processes.exit_code import ExitCode from aiida.engine.transports import TransportQueue -from aiida.engine.utils import InterruptableFuture, exponential_backoff_retry, interruptable_task +from aiida.engine import utils +from aiida.engine.utils import InterruptableFuture, interruptable_task from aiida.manage.configuration import get_config_option from aiida.orm.nodes.process.calculation.calcjob import CalcJobNode from aiida.schedulers.datastructures import JobState @@ -59,7 +60,7 @@ async def task_upload_job(process: 'CalcJob', transport_queue: TransportQueue, c """Transport task that will attempt to upload the files of a job calculation to the remote. The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager - function is called, wrapped in the exponential_backoff_retry coroutine, which, in case of a caught exception, will + function is called, wrapped in the utils.exponential_backoff_retry coroutine, which, in case of a caught exception, will retry after an interval that increases exponentially with the number of retries, for a maximum number of retries. If all retries fail, the task will raise a TransportTaskException @@ -102,7 +103,7 @@ async def do_upload(): try: logger.info(f'scheduled request to upload CalcJob<{node.pk}>') ignore_exceptions = (plumpy.futures.CancelledError, PreSubmitException, plumpy.process_states.Interruption) - skip_submit = await exponential_backoff_retry( + skip_submit = await utils.exponential_backoff_retry( do_upload, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions ) except PreSubmitException: @@ -122,7 +123,7 @@ async def task_submit_job(node: CalcJobNode, transport_queue: TransportQueue, ca """Transport task that will attempt to submit a job calculation. The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager - function is called, wrapped in the exponential_backoff_retry coroutine, which, in case of a caught exception, will + function is called, wrapped in the utils.exponential_backoff_retry coroutine, which, in case of a caught exception, will retry after an interval that increases exponentially with the number of retries, for a maximum number of retries. 
If all retries fail, the task will raise a TransportTaskException @@ -150,7 +151,7 @@ async def do_submit(): try: logger.info(f'scheduled request to submit CalcJob<{node.pk}>') ignore_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption) - result = await exponential_backoff_retry( + result = await utils.exponential_backoff_retry( do_submit, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions ) except (plumpy.futures.CancelledError, plumpy.process_states.Interruption): @@ -168,7 +169,7 @@ async def task_update_job(node: CalcJobNode, job_manager, cancellable: Interrupt """Transport task that will attempt to update the scheduler status of the job calculation. The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager - function is called, wrapped in the exponential_backoff_retry coroutine, which, in case of a caught exception, will + function is called, wrapped in the utils.exponential_backoff_retry coroutine, which, in case of a caught exception, will retry after an interval that increases exponentially with the number of retries, for a maximum number of retries. If all retries fail, the task will raise a TransportTaskException @@ -208,7 +209,7 @@ async def do_update(): try: logger.info(f'scheduled request to update CalcJob<{node.pk}>') ignore_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption) - job_done = await exponential_backoff_retry( + job_done = await utils.exponential_backoff_retry( do_update, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions ) except (plumpy.futures.CancelledError, plumpy.process_states.Interruption): @@ -230,7 +231,7 @@ async def task_monitor_job( """Transport task that will monitor the job calculation if any monitors have been defined. The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager - function is called, wrapped in the exponential_backoff_retry coroutine, which, in case of a caught exception, will + function is called, wrapped in the utils.exponential_backoff_retry coroutine, which, in case of a caught exception, will retry after an interval that increases exponentially with the number of retries, for a maximum number of retries. If all retries fail, the task will raise a TransportTaskException @@ -258,7 +259,7 @@ async def do_monitor(): try: logger.info(f'scheduled request to monitor CalcJob<{node.pk}>') ignore_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption) - monitor_result = await exponential_backoff_retry( + monitor_result = await utils.exponential_backoff_retry( do_monitor, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions ) except (plumpy.futures.CancelledError, plumpy.process_states.Interruption): @@ -279,7 +280,7 @@ async def task_retrieve_job( ): """Transport task that will attempt to retrieve all files of a completed job calculation. The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager - function is called, wrapped in the exponential_backoff_retry coroutine, which, in case of a caught exception, will + function is called, wrapped in the utils.exponential_backoff_retry coroutine, which, in case of a caught exception, will retry after an interval that increases exponentially with the number of retries, for a maximum number of retries. 
If all retries fail, the task will raise a TransportTaskException

     :param process: the job calculation
@@ -326,7 +327,7 @@ async def do_retrieve():

     try:
         logger.info(f'scheduled request to retrieve CalcJob<{node.pk}>')
         ignore_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption)
-        result = await exponential_backoff_retry(
+        result = await utils.exponential_backoff_retry(
             do_retrieve, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions
         )
     except (plumpy.futures.CancelledError, plumpy.process_states.Interruption):
@@ -344,7 +345,7 @@ async def task_stash_job(node: CalcJobNode, transport_queue: TransportQueue, can
     """Transport task that will optionally stash files of a completed job calculation on the remote.

     The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager
-    function is called, wrapped in the exponential_backoff_retry coroutine, which, in case of a caught exception, will
+    function is called, wrapped in the utils.exponential_backoff_retry coroutine, which, in case of a caught exception, will
     retry after an interval that increases exponentially with the number of retries, for a maximum number of retries.

     If all retries fail, the task will raise a TransportTaskException
@@ -371,7 +372,7 @@ async def do_stash():
             return await execmanager.stash_calculation(node, transport)

     try:
-        await exponential_backoff_retry(
+        await utils.exponential_backoff_retry(
             do_stash,
             initial_interval,
             max_attempts,
@@ -389,7 +390,7 @@ async def do_stash():
     return


-async def task_kill_job(node: CalcJobNode, transport_queue: TransportQueue, cancellable: InterruptableFuture):
+async def task_kill_job(
+    node: CalcJobNode, transport_queue: TransportQueue, force_kill: bool, cancellable: InterruptableFuture
+):
     """Transport task that will attempt to kill a job calculation.

     The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager
@@ -403,8 +404,9 @@ async def task_kill_job(node: CalcJobNode, transport_queue: TransportQueue, canc
     :raises: TransportTaskException if after the maximum number of retries the transport task still excepted
     """
     initial_interval = get_config_option(RETRY_INTERVAL_OPTION)
-    max_attempts = get_config_option(MAX_ATTEMPTS_OPTION)
+    max_attempts = get_config_option(MAX_ATTEMPTS_OPTION) if not force_kill else 1

     if node.get_state() in [CalcJobState.UPLOADING, CalcJobState.SUBMITTING]:
         logger.warning(f'CalcJob<{node.pk}> killed, it was in the {node.get_state()} state')
@@ -419,7 +421,7 @@ async def do_kill():

     try:
         logger.info(f'scheduled request to kill CalcJob<{node.pk}>')
-        result = await exponential_backoff_retry(do_kill, initial_interval, max_attempts, logger=node.logger)
+        result = await utils.exponential_backoff_retry(do_kill, initial_interval, max_attempts, logger=node.logger)
     except plumpy.process_states.Interruption:
         raise
     except Exception as exception:
@@ -558,7 +560,8 @@ async def execute(self) -> plumpy.process_states.State:  # type: ignore[override
         except TransportTaskException as exception:
             raise plumpy.process_states.PauseInterruption(f'Pausing after failed transport task: {exception}')
         except plumpy.process_states.KillInterruption as exception:
-            await self._kill_job(node, transport_queue)
+            await self._kill_job(node, transport_queue, exception.force_kill)
             node.set_process_status(str(exception))
             return self.retrieve(monitor_result=self._monitor_result)
         except (plumpy.futures.CancelledError, asyncio.CancelledError):
@@ -598,9 +601,10 @@ async def _monitor_job(self, node, transport_queue, monitors) -> CalcJobMonitorR

         return monitor_result

-    async def _kill_job(self, node, transport_queue) -> None:
+    async def _kill_job(self, node, transport_queue, force_kill: bool) -> None:
         """Kill the job."""
-        await self._launch_task(task_kill_job, node, transport_queue)
+        await self._launch_task(task_kill_job, node, transport_queue, force_kill)
         if self._killing is not None:
             self._killing.set_result(True)
         else:
diff --git a/src/aiida/engine/processes/control.py b/src/aiida/engine/processes/control.py
index 2ecc8477df..f6b0172225 100644
--- a/src/aiida/engine/processes/control.py
+++ b/src/aiida/engine/processes/control.py
@@ -173,6 +173,7 @@ def kill_processes(
     processes: list[ProcessNode] | None = None,
     *,
     msg_text: str = 'Killed through `aiida.engine.processes.control.kill_processes`',
+    force_kill: bool = False,
     all_entries: bool = False,
     timeout: float = 5.0,
     wait: bool = False,
@@ -201,7 +202,7 @@ def kill_processes(
         return

     controller = get_manager().get_process_controller()
-    action = functools.partial(controller.kill_process, msg_text=msg_text)
+    action = functools.partial(controller.kill_process, msg_text=msg_text, force_kill=force_kill)

     _perform_actions(processes, action, 'kill', 'killing', timeout, wait)
diff --git a/src/aiida/engine/processes/process.py b/src/aiida/engine/processes/process.py
index f29d426770..edbeca8704 100644
--- a/src/aiida/engine/processes/process.py
+++ b/src/aiida/engine/processes/process.py
@@ -329,7 +329,7 @@ def load_instance_state(

         self.node.logger.info(f'Loaded process<{self.node.pk}> from saved state')

-    def kill(self, msg_text: str | None = None) -> Union[bool, plumpy.futures.Future]:
+    def kill(self, msg_text: str | None = None, force_kill: bool = False) -> Union[bool, plumpy.futures.Future]:
         """Kill the process and all the children calculations it called

-        :param msg: message
+        :param msg_text: the message to set on the process node
+        :param force_kill: whether to kill the process immediately, without waiting for the kill transport task
@@ -338,7 +338,7 @@ def kill(self, msg_text: str | None = None) -> Union[bool, plumpy.futures.Future

         had_been_terminated = self.has_terminated()

-        result = super().kill(msg_text)
+        result = super().kill(msg_text, force_kill)

         # Only kill children if we could be killed ourselves
         if result is not False and not had_been_terminated:
diff --git a/src/aiida/tools/pytest_fixtures/daemon.py b/src/aiida/tools/pytest_fixtures/daemon.py
index d838e05833..1a032b3fc6 100644
--- a/src/aiida/tools/pytest_fixtures/daemon.py
+++ b/src/aiida/tools/pytest_fixtures/daemon.py
@@ -155,6 +155,7 @@ def factory(
                 f'Daemon <{started_daemon_client.profile.name}|{daemon_status}> log file content: \n'
                 f'{daemon_log_file}'
             )
+        # Give the daemon worker a moment to persist the final process state before returning.
+        time.sleep(0.1)

         return node
diff --git a/tests/cmdline/commands/test_process.py b/tests/cmdline/commands/test_process.py
index 15ff8911eb..638171a765 100644
--- a/tests/cmdline/commands/test_process.py
+++ b/tests/cmdline/commands/test_process.py
@@ -13,6 +13,8 @@
 import time
 import typing as t
 import uuid
+from contextlib import contextmanager
+from pathlib import Path

 import pytest

@@ -27,14 +29,211 @@
 from tests.utils.processes import WaitProcess


-def await_condition(condition: t.Callable, timeout: int = 1):
+def start_daemon_worker_in_foreground_and_redirect_streams(aiida_profile, log_dir: Path):
+    """Start a daemon worker and redirect its stdout and stderr streams to files in the daemon log directory."""
+    import os
+    import sys
+
+    from aiida.engine.daemon.worker import start_daemon_worker
+
+    original_stdout = sys.stdout
+    original_stderr = sys.stderr
+
+    try:
+        pid = os.getpid()
+        sys.stdout = open(log_dir / f'worker-{pid}.out', 'w')
+        sys.stderr = open(log_dir / f'worker-{pid}.err', 'w')
+        start_daemon_worker(False, aiida_profile.name)
+    finally:
+        # Close the redirected streams and restore the originals.
+        if sys.stdout != original_stdout:
+            sys.stdout.close()
+            sys.stdout = original_stdout
+        if sys.stderr != original_stderr:
+            sys.stderr.close()
+            sys.stderr = original_stderr
+
+
+@pytest.fixture()
+def fork_worker_context(aiida_profile, started_daemon_client):
+    """Run a daemon worker in a forked process with redirected stdout and stderr streams."""
+    import multiprocessing
+
+    from aiida.engine.daemon.client import get_daemon_client
+
+    client = get_daemon_client(aiida_profile.name)
+    nb_workers = client.get_number_of_workers()
+    client.decrease_workers(nb_workers)
+    daemon_log_dir = Path(client.daemon_log_file).parent
+
+    @contextmanager
+    def fork_worker():
+        ctx = multiprocessing.get_context('fork')
+        # Pass the AiiDA profile so the forked worker uses the same configuration.
+        process = ctx.Process(
+            target=start_daemon_worker_in_foreground_and_redirect_streams, args=(aiida_profile, daemon_log_dir)
+        )
+        process.start()
+
+        yield process
+
+        # TODO: according to the ``start_daemon_worker`` code, ``process.terminate()`` should suffice,
+        # but in practice it does not shut the worker down, so kill it outright.
+        process.kill()
+        process.join()
+
+    yield fork_worker
+
+    client.increase_workers(nb_workers)
+
+
+def await_condition(condition: t.Callable, timeout: int = 1) -> t.Any:
     """Wait for the ``condition`` to evaluate to ``True`` within the ``timeout`` or raise."""
     start_time = time.time()

-    while not condition():
+    while not (result := condition()):
         if time.time() - start_time > timeout:
             raise RuntimeError(f'waiting for {condition} to evaluate to `True` timed out after {timeout} seconds.')
+        time.sleep(0.1)
+
+    return result
+
+
+# TODO: this test fails if a daemon-related test runs before it.
+@pytest.mark.requires_rmq
+@pytest.mark.usefixtures('started_daemon_client')
+def test_process_kill_failing_transport_force_kill(
+    fork_worker_context, submit_and_await, aiida_code_installed, run_cli_command, monkeypatch
+):
+    """Test that a process that is unable to open a transport connection can be force killed.
+
+    A failure to open a transport connection triggers the exponential backoff mechanism (EBM),
+    which blocks a regular kill command. A force kill ignores the EBM and kills the process in any case."""
+    from aiida.cmdline.utils.common import get_process_function_report
+    from aiida.orm import Int
+
+    code = aiida_code_installed(default_calc_job_plugin='core.arithmetic.add', filepath_executable='/bin/bash')
+
+    def make_a_builder(sleep_seconds=0):
+        builder = code.get_builder()
+        builder.x = Int(1)
+        builder.y = Int(1)
+        builder.metadata.options.sleep = sleep_seconds
+        return builder
+
+    kill_timeout = 5
+
+    # Patch the transport so that opening a connection always fails.
+    def mock_open(_):
+        raise Exception('Mock open exception')
+
+    monkeypatch.setattr('aiida.transports.plugins.local.LocalTransport.open', mock_open)
+
+    # We fork after the monkeypatching so the worker process inherits the patch.
+    # 7) *Force* kill a process that is stuck in the EBM, something a regular *kill* cannot do.
+    # `verdi process kill -F` as the first kill attempt.
+    with fork_worker_context():
+        node = submit_and_await(make_a_builder(100), ProcessState.WAITING)
+        result = await_condition(lambda: get_process_function_report(node), timeout=kill_timeout)
+        assert 'Mock open exception' in result
+        assert 'exponential_backoff_retry' in result
+
+        # Force kill the process.
+        run_cli_command(cmd_process.process_kill, [str(node.pk), '-F', '--wait'])
+        await_condition(lambda: node.is_killed, timeout=kill_timeout)
+        assert node.process_status == 'Force killed through `verdi process kill`'
+
+
+@pytest.mark.requires_rmq
+@pytest.mark.usefixtures('started_daemon_client')
+def test_process_kill_failing_transport_failed_kill(
+    fork_worker_context, submit_and_await, aiida_code_installed, run_cli_command, monkeypatch
+):
+    """Test that a process stuck in the EBM survives a regular kill but can still be force killed afterwards.
+
+    A failure to open a transport connection triggers the exponential backoff mechanism (EBM), which blocks a
+    regular kill command. Even with a history of failed kill attempts, a force kill should still succeed."""
+    from aiida.cmdline.utils.common import get_process_function_report
+    from aiida.orm import Int
+
+    code = aiida_code_installed(default_calc_job_plugin='core.arithmetic.add', filepath_executable='/bin/bash')
+
+    def make_a_builder(sleep_seconds=0):
+        builder = code.get_builder()
+        builder.x = Int(1)
+        builder.y = Int(1)
+        builder.metadata.options.sleep = sleep_seconds
+        return builder
+
+    kill_timeout = 5
+
+    # Patch the transport so that opening a connection always fails.
+    def mock_open(_):
+        raise Exception('Mock open exception')
+
+    monkeypatch.setattr('aiida.transports.plugins.local.LocalTransport.open', mock_open)
+    # 8) A process that is stuck in the EBM cannot be killed directly by `verdi process kill`.
+    # Such a process, with a history of failed kill attempts, should still be force killable.
+    # `verdi process kill -F` as the second kill attempt.
+    with fork_worker_context():
+        node = submit_and_await(make_a_builder(5), ProcessState.WAITING)
+
+        # Assert that the process is stuck in the EBM.
+        result = await_condition(lambda: get_process_function_report(node), timeout=kill_timeout)
+        assert 'Mock open exception' in result
+        assert 'exponential_backoff_retry' in result
+
+        # Attempt a regular kill, which should time out.
+        result = run_cli_command(cmd_process.process_kill, [str(node.pk), '--wait', '--timeout', '1.0'])
+        assert f'Error: call to kill Process<{node.pk}> timed out' in result.stdout
+        # Force kill the process.
+        result = run_cli_command(cmd_process.process_kill, [str(node.pk), '-F', '--wait'])
+        await_condition(lambda: node.is_killed, timeout=kill_timeout)
+        assert node.process_status == 'Force killed through `verdi process kill`'
+
+
+@pytest.mark.requires_rmq
+@pytest.mark.usefixtures('started_daemon_client')
+def test_process_kill_failing_ebm(
+    fork_worker_context, submit_and_await, aiida_code_installed, run_cli_command, monkeypatch
+):
+    """9) Kill a process that is paused after the EBM failed five times.
+
+    It should be possible to kill it normally (e.g. in scenarios where the transport is working again).
+    """
+    from aiida.cmdline.utils.common import get_process_function_report
+    from aiida.common.exceptions import TransportTaskException
+    from aiida.orm import Int
+
+    code = aiida_code_installed(default_calc_job_plugin='core.arithmetic.add', filepath_executable='/bin/bash')
+
+    def make_a_builder(sleep_seconds=0):
+        builder = code.get_builder()
+        builder.x = Int(1)
+        builder.y = Int(1)
+        builder.metadata.options.sleep = sleep_seconds
+        return builder
+
+    kill_timeout = 10
+
+    async def mock_exponential_backoff_retry(*_, **__):
+        raise TransportTaskException
+
+    # Patch the EBM so that it fails immediately.
+    monkeypatch.setattr('aiida.engine.utils.exponential_backoff_retry', mock_exponential_backoff_retry)
+    with fork_worker_context():
+        node = submit_and_await(make_a_builder(), ProcessState.WAITING)
+        await_condition(
+            lambda: node.process_status
+            == 'Pausing after failed transport task: upload_calculation failed 5 times consecutively',
+            timeout=kill_timeout,
+        )
+
+        run_cli_command(cmd_process.process_kill, [str(node.pk), '--wait'])
+        await_condition(lambda: node.is_killed, timeout=kill_timeout)


 class TestVerdiProcess:
     """Tests for `verdi process`."""
@@ -537,9 +736,29 @@ def test_process_play_all(submit_and_await, run_cli_command):

 @pytest.mark.requires_rmq
 @pytest.mark.usefixtures('started_daemon_client')
-def test_process_kill(submit_and_await, run_cli_command):
-    """Test the ``verdi process kill`` command."""
-    node = submit_and_await(WaitProcess, ProcessState.WAITING)
+def test_process_kill_uni(submit_and_await, run_cli_command, aiida_code_installed):
+    """Test the ``verdi process kill`` command.
+
+    It aims to cover all the relevant scenarios of killing a process.
+    """
+    from aiida.orm import Int
+
+    kill_timeout = 5
+
+    # 0) Running without identifiers should except and print something.
+    result = run_cli_command(cmd_process.process_kill, raises=True)
+    assert result.exit_code == ExitCode.USAGE_ERROR
+    assert len(result.output_lines) > 0
+
+    code = aiida_code_installed(default_calc_job_plugin='core.arithmetic.add', filepath_executable='/bin/bash')
+    builder = code.get_builder()
+    builder.x = Int(2)
+    builder.y = Int(3)
+    builder.metadata.options.sleep = 10
+
+    # 1) Kill a paused process.
+    node = submit_and_await(builder, ProcessState.WAITING)

     run_cli_command(cmd_process.process_pause, [str(node.pk), '--wait'])
     await_condition(lambda: node.paused)
@@ -549,11 +768,41 @@
     await_condition(lambda: node.is_killed)
     assert node.process_status == 'Killed through `verdi process kill`'

-    # Running without identifiers should except and print something
-    options = []
-    result = run_cli_command(cmd_process.process_kill, options, raises=True)
-    assert result.exit_code == ExitCode.USAGE_ERROR
-    assert len(result.output_lines) > 0
+    # 2) Force kill a paused process.
+    node = submit_and_await(builder, ProcessState.WAITING)
+
+    run_cli_command(cmd_process.process_pause, [str(node.pk), '--wait'])
+    await_condition(lambda: node.paused)
+    assert node.process_status == 'Paused through `verdi process pause`'
+
+    run_cli_command(cmd_process.process_kill, [str(node.pk), '-F', '--wait'])
+    await_condition(lambda: node.is_killed)
+    assert node.process_status == 'Force killed through `verdi process kill`'
+
+    # TODO: the remaining scenarios take very long to run.
+    # 5) `verdi process kill --all` should kill all processes.
+    node_1 = submit_and_await(builder, ProcessState.WAITING)
+    run_cli_command(cmd_process.process_pause, [str(node_1.pk), '--wait'])
+    await_condition(lambda: node_1.paused)
+    node_2 = submit_and_await(builder, ProcessState.WAITING)
+
+    run_cli_command(cmd_process.process_kill, ['--all', '--wait'], user_input='y')
+    await_condition(lambda: node_1.is_killed, timeout=kill_timeout)
+    await_condition(lambda: node_2.is_killed, timeout=kill_timeout)
+    assert node_1.process_status == 'Killed through `verdi process kill`'
+    assert node_2.process_status == 'Killed through `verdi process kill`'
+
+    # 6) `verdi process kill --all -F` should force kill all processes (both paused and running).
+    node_1 = submit_and_await(builder, ProcessState.WAITING)
+    run_cli_command(cmd_process.process_pause, [str(node_1.pk), '--wait'])
+    await_condition(lambda: node_1.paused)
+    node_2 = submit_and_await(builder, ProcessState.WAITING)
+
+    run_cli_command(cmd_process.process_kill, ['--all', '--wait', '-F'], user_input='y')
+    await_condition(lambda: node_1.is_killed, timeout=kill_timeout)
+    await_condition(lambda: node_2.is_killed, timeout=kill_timeout)
+    assert node_1.process_status == 'Force killed through `verdi process kill`'
+    assert node_2.process_status == 'Force killed through `verdi process kill`'


 @pytest.mark.requires_rmq
diff --git a/tests/utils/processes.py b/tests/utils/processes.py
index 170f313211..db91fa8e3e 100644
--- a/tests/utils/processes.py
+++ b/tests/utils/processes.py
@@ -81,6 +81,17 @@ def next_step(self):
         pass


+class RunningProcess(Process):
+    """A process that stays in the running state by sleeping for a long time."""
+
+    _node_class = WorkflowNode
+
+    async def run(self):
+        import asyncio
+
+        await asyncio.sleep(100)
+        # TODO: the ``Continue`` below does not work, but at the moment this point is never reached.
+        return plumpy.Continue(self.run)
+
+
 class InvalidateCaching(Process):
     """A process which invalidates cache for some exit codes."""
diff --git a/uv.lock b/uv.lock
index 93b34c4cc5..f1ed338b7e 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1,4 +1,5 @@
 version = 1
+revision = 1
 requires-python = ">=3.9"
 resolution-markers = [
     "python_full_version >= '3.12' and sys_platform == 'win32'",
@@ -203,7 +204,7 @@
     { name = "pg8000", marker = "extra == 'tests'", specifier = "~=1.13" },
     { name = "pgsu", specifier = "~=0.3.0" },
     { name = "pgtest", marker = "extra == 'tests'", specifier = "~=1.3,>=1.3.1" },
-    { name = "plumpy", specifier = "~=0.24.0" },
+    { name = "plumpy", git = "https://github.com/aiidateam/plumpy.git?branch=force-kill-v2" },
     { name = "pre-commit", marker = "extra == 'pre-commit'", specifier = "~=3.5" },
     { name = "psutil", specifier = "~=5.6" },
     { name = "psycopg", extras = ["binary"], specifier = "~=3.0" },
@@ -249,6 +250,7 @@
     { name = "upf-to-json", specifier = "~=0.9.2" },
     { name = "wrapt", specifier = "~=1.11" },
 ]
+provides-extras = ["atomic-tools", "bpython", "docs", "notebook", "pre-commit", "rest", "ssh-kerberos", "tests", "tui"]

 [[package]]
 name = "aiida-export-migration-tests"
@@ -3267,16 +3269,12 @@
 [[package]]
 name = "plumpy"
 version = "0.24.0"
-source = { registry = "https://pypi.org/simple" }
+source = { git = "https://github.com/aiidateam/plumpy.git?branch=force-kill-v2#f6b466d0b3a3eb79213e3372cc88e5f2d5c4ab68" }
 dependencies = [
     { name = "kiwipy", extra = ["rmq"] },
     { name = "nest-asyncio" },
     { name = "pyyaml" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/d9/0c/0bb568982e461f5e428606ccbdfe6d43c11dab0e3f5a8090298feb321172/plumpy-0.24.0.tar.gz", hash = "sha256:c17c8efbd124d7f5ec2f27cb1f2c3de7901143e61551ce81f3ee22bf7e2ed42d", size = 75634 }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/99/d3/68c83d4774f7a4f8e8dd4e30ce34e46071706a4b4dc40d3a1ad77de793fc/plumpy-0.24.0-py3-none-any.whl", hash = "sha256:09efafe97c88c8928e73f1dc08cf02a2c4737fa767920bff23dfa26226252cc6", size = 74955 },
-]

 [[package]]
 name = "ply"
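Usage sketch (illustrative, not part of the patch): how the new `force_kill` flag flows from the CLI to the process controller. The pk passed to `load_node` is hypothetical; `kill_processes` and its keyword arguments are taken from the `control.py` changes above, and forwarding `force_kill` down to `Process.kill` assumes the pinned `force-kill-v2` plumpy branch.

    # Programmatic equivalent of `verdi process kill --force-kill --wait <PK>`.
    from aiida import load_profile
    from aiida.engine.processes import control
    from aiida.orm import load_node

    load_profile()

    node = load_node(1234)  # hypothetical pk of a CalcJob stuck in the EBM

    control.kill_processes(
        [node],
        msg_text='Force killed through `verdi process kill`',
        force_kill=True,  # bypass the exponential backoff retry when killing
        timeout=5.0,
        wait=True,
    )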