Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ dependencies = [
'docstring-parser',
'get-annotations~=0.1;python_version<"3.10"',
'graphviz~=0.19',
'plumpy',
'ipython>=7',
'jedi<0.19',
'jinja2~=3.0',
'kiwipy[rmq]~=0.8.4',
'importlib-metadata~=6.0',
'numpy~=1.21',
'paramiko~=3.0',
'plumpy~=0.24.0',
'pgsu~=0.3.0',
'psutil~=5.6',
'psycopg[binary]~=3.0',
Expand Down Expand Up @@ -522,3 +522,6 @@ commands = molecule {posargs:test}
# .github/actions/install-aiida-core/action.yml
# .readthedocs.yml
required-version = ">=0.5.21"

[tool.uv.sources]
plumpy = {git = "https://github.com/aiidateam/plumpy.git", branch = "force-kill-v2"}
14 changes: 12 additions & 2 deletions src/aiida/cmdline/commands/cmd_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,12 @@ def process_status(call_link_label, most_recent_node, max_depth, processes):
@options.ALL(help='Kill all processes if no specific processes are specified.')
@options.TIMEOUT()
@options.WAIT()
@options.FORCE_KILL(
help='Force kill the process if it does not respond to the initial kill signal.\n'
' Note: This may lead to orphaned jobs on your HPC and should be used with caution.'
)
@decorators.with_dbenv()
def process_kill(processes, all_entries, timeout, wait):
def process_kill(processes, all_entries, timeout, wait, force_kill):
"""Kill running processes.

Kill one or multiple running processes."""
Expand All @@ -338,11 +342,17 @@ def process_kill(processes, all_entries, timeout, wait):
if all_entries:
click.confirm('Are you sure you want to kill all processes?', abort=True)

if force_kill:
echo.echo_warning('Force kill is enabled. This may lead to orphaned jobs on your HPC.')
msg_text = 'Force killed through `verdi process kill`'
else:
msg_text = 'Killed through `verdi process kill`'
with capture_logging() as stream:
try:
control.kill_processes(
processes,
msg_text='Killed through `verdi process kill`',
msg_text=msg_text,
force_kill=force_kill,
all_entries=all_entries,
timeout=timeout,
wait=wait,
Expand Down
9 changes: 9 additions & 0 deletions src/aiida/cmdline/params/options/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
'EXPORT_FORMAT',
'FAILED',
'FORCE',
'FORCE_KILL',
'FORMULA_MODE',
'FREQUENCY',
'GROUP',
Expand Down Expand Up @@ -329,6 +330,14 @@ def set_log_level(ctx, _param, value):

FORCE = OverridableOption('-f', '--force', is_flag=True, default=False, help='Do not ask for confirmation.')

# Boolean flag (``-F`` / ``--force-kill``) that is off unless given on the command line.
# The ``help`` text below is the default wording; a command applying this option may pass its own ``help``.
FORCE_KILL = OverridableOption(
'-F',
'--force-kill',
is_flag=True,
default=False,
help='Kills the process without waiting for a confirmation if the job has been killed from remote.',
)

SILENT = OverridableOption('-s', '--silent', is_flag=True, default=False, help='Suppress any output printed to stdout.')

VISUALIZATION_FORMAT = OverridableOption(
Expand Down
4 changes: 4 additions & 0 deletions src/aiida/engine/daemon/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,10 @@ def get_worker_info(self, timeout: int | None = None) -> dict[str, t.Any]:
command = {'command': 'stats', 'properties': {'name': self.daemon_name}}
return self.call_client(command, timeout=timeout)

def get_number_of_workers(self, timeout: int | None = None) -> int:
"""Count the entries under the ``info`` key of the worker info response.

The response is obtained from :meth:`get_worker_info`; if it has no ``info`` key the count is zero.

:param timeout: Optional timeout, passed on unchanged to :meth:`get_worker_info`.
:return: The number of entries found under ``info``.
"""
return len(self.get_worker_info(timeout).get('info', []))

def get_daemon_info(self, timeout: int | None = None) -> dict[str, t.Any]:
"""Get statistics about this daemon itself.

Expand Down
6 changes: 4 additions & 2 deletions src/aiida/engine/daemon/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,14 @@ async def shutdown_worker(runner: Runner) -> None:
LOGGER.info('Daemon worker stopped')


def start_daemon_worker(foreground: bool = False) -> None:
def start_daemon_worker(foreground: bool = False, profile_name: str | None = None) -> None:
"""Start a daemon worker for the currently configured profile.

:param foreground: If true, the logging will be configured to write to stdout, otherwise it will be configured to
write to the daemon log file.
"""
daemon_client = get_daemon_client()

daemon_client = get_daemon_client(profile_name)
configure_logging(with_orm=True, daemon=not foreground, daemon_log_file=daemon_client.daemon_log_file)

LOGGER.debug(f'sys.executable: {sys.executable}')
Expand All @@ -68,6 +69,7 @@ def start_daemon_worker(foreground: bool = False) -> None:

try:
LOGGER.info('Starting a daemon worker')

runner.start()
except SystemError as exception:
LOGGER.info('Received a SystemError: %s', exception)
Expand Down
44 changes: 25 additions & 19 deletions src/aiida/engine/processes/calcjobs/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
from aiida.engine.daemon import execmanager
from aiida.engine.processes.exit_code import ExitCode
from aiida.engine.transports import TransportQueue
from aiida.engine.utils import InterruptableFuture, exponential_backoff_retry, interruptable_task
from aiida.engine import utils
from aiida.engine.utils import InterruptableFuture, interruptable_task
from aiida.manage.configuration import get_config_option
from aiida.orm.nodes.process.calculation.calcjob import CalcJobNode
from aiida.schedulers.datastructures import JobState
Expand Down Expand Up @@ -59,7 +60,7 @@ async def task_upload_job(process: 'CalcJob', transport_queue: TransportQueue, c
"""Transport task that will attempt to upload the files of a job calculation to the remote.

The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager
function is called, wrapped in the exponential_backoff_retry coroutine, which, in case of a caught exception, will
function is called, wrapped in the utils.exponential_backoff_retry coroutine, which, in case of a caught exception, will
retry after an interval that increases exponentially with the number of retries, for a maximum number of retries.
If all retries fail, the task will raise a TransportTaskException

Expand Down Expand Up @@ -102,7 +103,7 @@ async def do_upload():
try:
logger.info(f'scheduled request to upload CalcJob<{node.pk}>')
ignore_exceptions = (plumpy.futures.CancelledError, PreSubmitException, plumpy.process_states.Interruption)
skip_submit = await exponential_backoff_retry(
skip_submit = await utils.exponential_backoff_retry(
do_upload, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions
)
except PreSubmitException:
Expand All @@ -122,7 +123,7 @@ async def task_submit_job(node: CalcJobNode, transport_queue: TransportQueue, ca
"""Transport task that will attempt to submit a job calculation.

The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager
function is called, wrapped in the exponential_backoff_retry coroutine, which, in case of a caught exception, will
function is called, wrapped in the utils.exponential_backoff_retry coroutine, which, in case of a caught exception, will
retry after an interval that increases exponentially with the number of retries, for a maximum number of retries.
If all retries fail, the task will raise a TransportTaskException

Expand Down Expand Up @@ -150,7 +151,7 @@ async def do_submit():
try:
logger.info(f'scheduled request to submit CalcJob<{node.pk}>')
ignore_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption)
result = await exponential_backoff_retry(
result = await utils.exponential_backoff_retry(
do_submit, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions
)
except (plumpy.futures.CancelledError, plumpy.process_states.Interruption):
Expand All @@ -168,7 +169,7 @@ async def task_update_job(node: CalcJobNode, job_manager, cancellable: Interrupt
"""Transport task that will attempt to update the scheduler status of the job calculation.

The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager
function is called, wrapped in the exponential_backoff_retry coroutine, which, in case of a caught exception, will
function is called, wrapped in the utils.exponential_backoff_retry coroutine, which, in case of a caught exception, will
retry after an interval that increases exponentially with the number of retries, for a maximum number of retries.
If all retries fail, the task will raise a TransportTaskException

Expand Down Expand Up @@ -208,7 +209,7 @@ async def do_update():
try:
logger.info(f'scheduled request to update CalcJob<{node.pk}>')
ignore_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption)
job_done = await exponential_backoff_retry(
job_done = await utils.exponential_backoff_retry(
do_update, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions
)
except (plumpy.futures.CancelledError, plumpy.process_states.Interruption):
Expand All @@ -230,7 +231,7 @@ async def task_monitor_job(
"""Transport task that will monitor the job calculation if any monitors have been defined.

The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager
function is called, wrapped in the exponential_backoff_retry coroutine, which, in case of a caught exception, will
function is called, wrapped in the utils.exponential_backoff_retry coroutine, which, in case of a caught exception, will
retry after an interval that increases exponentially with the number of retries, for a maximum number of retries.
If all retries fail, the task will raise a TransportTaskException

Expand Down Expand Up @@ -258,7 +259,7 @@ async def do_monitor():
try:
logger.info(f'scheduled request to monitor CalcJob<{node.pk}>')
ignore_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption)
monitor_result = await exponential_backoff_retry(
monitor_result = await utils.exponential_backoff_retry(
do_monitor, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions
)
except (plumpy.futures.CancelledError, plumpy.process_states.Interruption):
Expand All @@ -279,7 +280,7 @@ async def task_retrieve_job(
):
"""Transport task that will attempt to retrieve all files of a completed job calculation.
The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager
function is called, wrapped in the exponential_backoff_retry coroutine, which, in case of a caught exception, will
function is called, wrapped in the utils.exponential_backoff_retry coroutine, which, in case of a caught exception, will
retry after an interval that increases exponentially with the number of retries, for a maximum number of retries.
If all retries fail, the task will raise a TransportTaskException
:param process: the job calculation
Expand Down Expand Up @@ -326,7 +327,7 @@ async def do_retrieve():
try:
logger.info(f'scheduled request to retrieve CalcJob<{node.pk}>')
ignore_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption)
result = await exponential_backoff_retry(
result = await utils.exponential_backoff_retry(
do_retrieve, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions
)
except (plumpy.futures.CancelledError, plumpy.process_states.Interruption):
Expand All @@ -344,7 +345,7 @@ async def task_stash_job(node: CalcJobNode, transport_queue: TransportQueue, can
"""Transport task that will optionally stash files of a completed job calculation on the remote.

The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager
function is called, wrapped in the exponential_backoff_retry coroutine, which, in case of a caught exception, will
function is called, wrapped in the utils.exponential_backoff_retry coroutine, which, in case of a caught exception, will
retry after an interval that increases exponentially with the number of retries, for a maximum number of retries.
If all retries fail, the task will raise a TransportTaskException

Expand All @@ -371,7 +372,7 @@ async def do_stash():
return await execmanager.stash_calculation(node, transport)

try:
await exponential_backoff_retry(
await utils.exponential_backoff_retry(
do_stash,
initial_interval,
max_attempts,
Expand All @@ -389,7 +390,7 @@ async def do_stash():
return


async def task_kill_job(node: CalcJobNode, transport_queue: TransportQueue, cancellable: InterruptableFuture):
async def task_kill_job(node: CalcJobNode, transport_queue: TransportQueue, force_kill: bool, cancellable: InterruptableFuture):
"""Transport task that will attempt to kill a job calculation.

The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager
Expand All @@ -403,8 +404,9 @@ async def task_kill_job(node: CalcJobNode, transport_queue: TransportQueue, canc

:raises: TransportTaskException if after the maximum number of retries the transport task still excepted
"""
breakpoint()
initial_interval = get_config_option(RETRY_INTERVAL_OPTION)
max_attempts = get_config_option(MAX_ATTEMPTS_OPTION)
max_attempts = get_config_option(MAX_ATTEMPTS_OPTION)# if not force_kill else 1

if node.get_state() in [CalcJobState.UPLOADING, CalcJobState.SUBMITTING]:
logger.warning(f'CalcJob<{node.pk}> killed, it was in the {node.get_state()} state')
Expand All @@ -419,7 +421,7 @@ async def do_kill():

try:
logger.info(f'scheduled request to kill CalcJob<{node.pk}>')
result = await exponential_backoff_retry(do_kill, initial_interval, max_attempts, logger=node.logger)
result = await utils.exponential_backoff_retry(do_kill, initial_interval, max_attempts, logger=node.logger)
except plumpy.process_states.Interruption:
raise
except Exception as exception:
Expand Down Expand Up @@ -558,7 +560,8 @@ async def execute(self) -> plumpy.process_states.State: # type: ignore[override
except TransportTaskException as exception:
raise plumpy.process_states.PauseInterruption(f'Pausing after failed transport task: {exception}')
except plumpy.process_states.KillInterruption as exception:
await self._kill_job(node, transport_queue)
breakpoint()
await self._kill_job(node, transport_queue, exception.force_kill)
node.set_process_status(str(exception))
return self.retrieve(monitor_result=self._monitor_result)
except (plumpy.futures.CancelledError, asyncio.CancelledError):
Expand Down Expand Up @@ -598,9 +601,10 @@ async def _monitor_job(self, node, transport_queue, monitors) -> CalcJobMonitorR

return monitor_result

async def _kill_job(self, node, transport_queue) -> None:
async def _kill_job(self, node, transport_queue, force_kill: bool) -> None:
"""Kill the job."""
await self._launch_task(task_kill_job, node, transport_queue)
breakpoint()
await self._launch_task(task_kill_job, node, transport_queue, force_kill)
if self._killing is not None:
self._killing.set_result(True)
else:
Expand Down Expand Up @@ -664,8 +668,10 @@ def parse(

def interrupt(self, reason: Any) -> Optional[plumpy.futures.Future]: # type: ignore[override]
"""Interrupt the `Waiting` state by calling interrupt on the transport task `InterruptableFuture`."""
breakpoint()
if self._task is not None:
self._task.interrupt(reason)
breakpoint()

if isinstance(reason, plumpy.process_states.KillInterruption):
if self._killing is None:
Expand Down
5 changes: 3 additions & 2 deletions src/aiida/engine/processes/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def kill_processes(
processes: list[ProcessNode] | None = None,
*,
msg_text: str = 'Killed through `aiida.engine.processes.control.kill_processes`',
force_kill: bool = False,
all_entries: bool = False,
timeout: float = 5.0,
wait: bool = False,
Expand Down Expand Up @@ -201,7 +202,7 @@ def kill_processes(
return

controller = get_manager().get_process_controller()
action = functools.partial(controller.kill_process, msg_text=msg_text)
action = functools.partial(controller.kill_process, msg_text=msg_text, force_kill=force_kill)
_perform_actions(processes, action, 'kill', 'killing', timeout, wait)


Expand Down Expand Up @@ -282,7 +283,7 @@ def handle_result(result):
try:
# unwrap is need here since LoopCommunicator will also wrap a future
unwrapped = unwrap_kiwi_future(future)
result = unwrapped.result()
result = unwrapped.result()#timeout)
except communications.TimeoutError:
LOGGER.error(f'call to {infinitive} Process<{process.pk}> timed out')
except Exception as exception:
Expand Down
4 changes: 2 additions & 2 deletions src/aiida/engine/processes/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def load_instance_state(

self.node.logger.info(f'Loaded process<{self.node.pk}> from saved state')

def kill(self, msg_text: str | None = None) -> Union[bool, plumpy.futures.Future]:
def kill(self, msg_text: str | None = None, force_kill: bool = False) -> Union[bool, plumpy.futures.Future]:
"""Kill the process and all the children calculations it called

:param msg: message
Expand All @@ -338,7 +338,7 @@ def kill(self, msg_text: str | None = None) -> Union[bool, plumpy.futures.Future

had_been_terminated = self.has_terminated()

result = super().kill(msg_text)
result = super().kill(msg_text, force_kill)

# Only kill children if we could be killed ourselves
if result is not False and not had_been_terminated:
Expand Down
2 changes: 2 additions & 0 deletions src/aiida/engine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def interruptable_task(
"""
loop = loop or asyncio.get_event_loop()
future = InterruptableFuture()
breakpoint()

async def execute_coroutine():
"""Coroutine that wraps the original coroutine and sets it result on the future only if not already set."""
Expand Down Expand Up @@ -193,6 +194,7 @@ async def exponential_backoff_retry(
:param ignore_exceptions: exceptions to ignore, i.e. when caught do nothing and simply re-raise
:return: result if the ``coro`` call completes within ``max_attempts`` retries without raising
"""
#breakpoint()
if logger is None:
logger = LOGGER

Expand Down
2 changes: 2 additions & 0 deletions src/aiida/manage/tests/pytest_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,8 @@ def _factory(
f'Daemon <{started_daemon_client.profile.name}|{daemon_status}> log file content: \n'
f'{daemon_log_file}'
)
time.sleep(1)
print(node.process_state)

return node

Expand Down
1 change: 1 addition & 0 deletions src/aiida/tools/pytest_fixtures/daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def factory(
f'Daemon <{started_daemon_client.profile.name}|{daemon_status}> log file content: \n'
f'{daemon_log_file}'
)
time.sleep(0.1)

return node

Expand Down
Loading
Loading