16 changes: 9 additions & 7 deletions src/aiida/engine/processes/calcjobs/tasks.py
@@ -543,7 +543,7 @@ async def execute(self) -> plumpy.process_states.State: # type: ignore[override
monitor_result = await self._monitor_job(node, transport_queue, self.monitors)

if monitor_result and monitor_result.action is CalcJobMonitorAction.KILL:
await self._kill_job(node, transport_queue)
await self.kill_job(node, transport_queue)
job_done = True

if monitor_result and not monitor_result.retrieve:
@@ -582,7 +582,6 @@ async def execute(self) -> plumpy.process_states.State: # type: ignore[override
except TransportTaskException as exception:
raise plumpy.process_states.PauseInterruption(f'Pausing after failed transport task: {exception}')
except plumpy.process_states.KillInterruption as exception:
await self._kill_job(node, transport_queue)
node.set_process_status(str(exception))
return self.retrieve(monitor_result=self._monitor_result)
except (plumpy.futures.CancelledError, asyncio.CancelledError):
@@ -594,10 +593,13 @@ async def execute(self) -> plumpy.process_states.State: # type: ignore[override
else:
node.set_process_status(None)
return result
finally:
# If we were trying to kill but we didn't deal with it, make sure it's set here
if self._killing and not self._killing.done():
self._killing.set_result(False)
# PR_COMMENT We no longer use the KillInterruption to kill the job here, since the job is now killed where the KillInterruption is sent
# TODO remove
# finally:
# # If we were trying to kill but we didn't deal with it, make sure it's set here
# #if self._killing and not self._killing.done():
# # self._killing.set_result(False)
# pass

async def _monitor_job(self, node, transport_queue, monitors) -> CalcJobMonitorResult | None:
"""Process job monitors if any were specified as inputs."""
@@ -622,7 +624,7 @@ async def _monitor_job(self, node, transport_queue, monitors) -> CalcJobMonitorR

return monitor_result

async def _kill_job(self, node, transport_queue) -> None:
async def kill_job(self, node, transport_queue) -> None:
"""Kill the job."""
await self._launch_task(task_kill_job, node, transport_queue)
if self._killing is not None:
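The hunk above makes kill_job public and calls it directly from the polling loop when a monitor requests a kill, rather than relying on the KillInterruption handler. A minimal sketch of that monitor-driven pattern in plain asyncio (check_job, run_monitors and kill_job are hypothetical stand-ins, not the AiiDA API):

import asyncio
from dataclasses import dataclass
from enum import Enum, auto


class MonitorAction(Enum):
    NONE = auto()
    KILL = auto()


@dataclass
class MonitorResult:
    action: MonitorAction = MonitorAction.NONE
    retrieve: bool = True


async def poll_job(check_job, run_monitors, kill_job, interval=1.0):
    """Poll until the job is done; kill it when a monitor asks for it."""
    while True:
        if await check_job():  # the scheduler reports the job as finished
            return 'done'
        result = await run_monitors()  # returns a MonitorResult or None
        if result is not None and result.action is MonitorAction.KILL:
            await kill_job()  # kill where the decision is taken, as in the diff above
            return 'killed'
        await asyncio.sleep(interval)
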
185 changes: 149 additions & 36 deletions src/aiida/engine/processes/process.py
@@ -52,6 +52,7 @@
from aiida.common.links import LinkType
from aiida.common.log import LOG_LEVEL_REPORT
from aiida.orm.implementation.utils import clean_value
from aiida.orm.nodes.process.calculation.calcjob import CalcJobNode
from aiida.orm.utils import serialize

from .builder import ProcessBuilder
@@ -329,50 +330,162 @@ def load_instance_state(

self.node.logger.info(f'Loaded process<{self.node.pk}> from saved state')

async def _launch_task(self, coro, *args, **kwargs):
"""Launch a coroutine as a task, making sure to make it interruptable."""
import functools

from aiida.engine.utils import interruptable_task

task_fn = functools.partial(coro, *args, **kwargs)
try:
self._task = interruptable_task(task_fn)
result = await self._task
return result
finally:
self._task = None

def kill(self, msg_text: str | None = None, force_kill: bool = False) -> Union[bool, plumpy.futures.Future]:
"""Kill the process and all the children calculations it called

:param msg: message
"""
self.node.logger.info(f'Request to kill Process<{self.node.pk}>')

had_been_terminated = self.has_terminated()

result = super().kill(msg_text, force_kill)
# breakpoint()
if self.killed():
self.node.logger.info(f'Request to kill Process<{self.node.pk}> but process has already been killed.')
return True
elif self.has_terminated():
self.node.logger.info(f'Request to kill Process<{self.node.pk}> but process has already terminated.')
return False
self.node.logger.info(f'Request to kill Process<{self.node.pk}>.')

# PR_COMMENT We need to kill the children first, because we transition to the killed state after the first kill.
# This became buggy in the last PR, which let the user resend kill commands (if _killing, do
# nothing). Since we now want to allow the user to resend kill commands with different options, we
# have to kill the children first, so that the children are still killed even when this process has
# already been killed. Otherwise we get the problematic scenario: the process is killed but has not
# killed its children yet, the kill times out, we kill again, but the parent process is already
# killed so it will never enter this code.
#
# TODO if the tests just pass it could mean that this is not well tested; check whether there is a test

# TODO
# this blocks the worker and it cannot be unblocked; may need async/await

killing = []
# breakpoint()
for child in self.node.called:
if self.runner.controller is None:
self.logger.info('no controller available to kill child<%s>', child.pk)
continue
try:
# we block for sending message

# Only kill children if we could be killed ourselves
if result is not False and not had_been_terminated:
killing = []
for child in self.node.called:
if self.runner.controller is None:
self.logger.info('no controller available to kill child<%s>', child.pk)
continue
try:
result = self.runner.controller.kill_process(child.pk, msg_text=f'Killed by parent<{self.node.pk}>')
result = asyncio.wrap_future(result) # type: ignore[arg-type]
if asyncio.isfuture(result):
killing.append(result)
except ConnectionClosed:
self.logger.info('no connection available to kill child<%s>', child.pk)
except UnroutableError:
self.logger.info('kill signal was unable to reach child<%s>', child.pk)

if asyncio.isfuture(result):
# We ourselves are waiting to be killed so add it to the list
killing.append(result)

if killing:
# result = self.loop.run_until_complete(coro)
# breakpoint()
result = self.runner.controller.kill_process(
child.pk, msg_text=f'Killed by parent<{self.node.pk}>', force_kill=force_kill
)
from plumpy.futures import unwrap_kiwi_future

killing.append(unwrap_kiwi_future(result))
breakpoint()
# result = unwrapped_future.result(timeout=5)
# result = asyncio.wrap_future(result) # type: ignore[arg-type]
# PR_COMMENT Commented out: we already wrap it into an asyncio future beforehand, so why the if check?
# if asyncio.isfuture(result):
# killing.append(result)
except ConnectionClosed:
self.logger.info('no connection available to kill child<%s>', child.pk)
except UnroutableError:
self.logger.info('kill signal was unable to reach child<%s>', child.pk)

# TODO need to check this part, might be overengineered
# if asyncio.isfuture(result):
# # We ourselves are waiting to be killed so add it to the list
# killing.append(result)

####### KILL TWO
if not force_kill:
# asyncio.send(continue_kill)
# return
for pending_future in killing:
# breakpoint()
result = pending_future.result()
# We are waiting for things to be killed, so return the 'gathered' future
kill_future = plumpy.futures.gather(*killing)
result = self.loop.create_future()

def done(done_future: plumpy.futures.Future):
is_all_killed = all(done_future.result())
result.set_result(is_all_killed)

kill_future.add_done_callback(done)

return result
# kill_future = plumpy.futures.gather(*killing)
# result = self.loop.create_future()
# breakpoint()

# def done(done_future: plumpy.futures.Future):
# is_all_killed = all(done_future.result())
# result.set_result(is_all_killed)

# kill_future.add_done_callback(done)

# PR_COMMENT We do not do this anymore. The original idea was to resend the kill interruption so the state
# could continue freeing its resources using an EBM with new parameters, since the user can change these
# between kills by changing the config parameters. However, this was not working properly because the
# process state enters the EBM only the first time it receives a KillInterruption. This is because
# the EBM is activated within a try-except block:
# try:
# do_work() # <-- now we send the interrupt exception
# except KillInterruption:
# cancel_scheduler_job_in_ebm # <-- if we cancel it will just stop this
#
# Not sure why I did not detect this during my earlier tries. We could also do a while loop of interrupts,
# but I think it is generally not good design for the process state to cancel the scheduler job while
# we kill the process here. It adds another actor responsible for killing the process correctly,
# making it more complex than necessary.
#
# Cancel any old killing command to send a new one
# if self._killing:
# self._killing.cancel()

# Send kill interruption to the tasks in the event loop so they stop
# This is not blocking, so the interruption is happening concurrently
if self._stepping:
# Ask the step function to pause by setting this flag and giving the
# caller back a future
interrupt_exception = plumpy.process_states.KillInterruption(msg_text, force_kill)
# PR COMMENT we do not set the interrupt action because plumpy uses the interrupt action to set the
# next state during stepping; we do not want to step to the next state through the plumpy
# state machine, we want to control this here and only here
# self._set_interrupt_action_from_exception(interrupt_exception)
# self._killing = self._interrupt_action
self._state.interrupt(interrupt_exception)
# return cast(plumpy.futures.CancellableAction, self._interrupt_action)

# Kill jobs from scheduler associated with this process.
# This is blocking so we only continue when the scheduler job has been killed.
if not force_kill and isinstance(self.node, CalcJobNode):
# TODO put this function into a more common place
from .calcjobs.tasks import task_kill_job

# if already killing we have triggered the Interruption
coro = self._launch_task(task_kill_job, self.node, self.runner.transport)
task = asyncio.create_task(coro)
# task_kill_job raises an error if not successful, e.g. if the EBM fails.
# PR COMMENT we just return False and log why the kill failed; it does not make sense to me to put the
# process into the excepted state. Maybe you fix your internet connection and want to try again.
# We now have force-kill if the user wants to enforce the kill
try:
# breakpoint()
self.loop.run_until_complete(task)
# breakpoint()
except Exception as exc:
self.node.logger.error(f'While cancelling job error was raised: {exc!s}')
# breakpoint()
return False

# Transition to killed process state
# This is blocking so we only continue when we are in killed state
msg = plumpy.process_comms.MessageBuilder.kill(text=msg_text, force_kill=force_kill)
new_state = self._create_state_instance(plumpy.process_states.ProcessState.KILLED, msg=msg)
self.transition_to(new_state)

return True

@override
def out(self, output_port: str, value: Any = None) -> None:
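The rewritten Process.kill above first sends kill requests to every called child, collects the unwrapped futures, and reduces them to a single boolean before interrupting its own state and cancelling the scheduler job. A rough sketch of that gather-and-reduce step in plain asyncio (kill_child stands in for runner.controller.kill_process; the real code goes through kiwipy futures and unwrap_kiwi_future):

import asyncio


async def kill_children(child_pks, kill_child):
    """Send a kill request to every child and reduce the outcomes to one bool."""
    futures = []
    for pk in child_pks:
        try:
            futures.append(asyncio.ensure_future(kill_child(pk)))
        except ConnectionError:
            # mirror the diff above: children we cannot reach are logged and skipped
            continue
    if not futures:
        return True
    results = await asyncio.gather(*futures, return_exceptions=True)
    return all(result is True for result in results)
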
34 changes: 34 additions & 0 deletions src/aiida/workchain.py
@@ -0,0 +1,34 @@
# TODO this class needs to be removed

from aiida.engine import ToContext, WorkChain
from aiida.orm import Bool


class MainWorkChain(WorkChain):
@classmethod
def define(cls, spec):
super().define(spec)
spec.input('kill', default=lambda: Bool(False))
spec.outline(cls.submit_child, cls.check)

def submit_child(self):
return ToContext(child=self.submit(SubWorkChain, kill=self.inputs.kill))

def check(self):
raise RuntimeError('should have been aborted by now')


class SubWorkChain(WorkChain):
@classmethod
def define(cls, spec):
super().define(spec)
spec.input('kill', default=lambda: Bool(False))
spec.outline(cls.begin, cls.check)

def begin(self):
"""If the Main should be killed, pause the child to give the Main a chance to call kill on its children"""
if self.inputs.kill:
self.pause()

def check(self):
raise RuntimeError('should have been aborted by now')
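
The module above is throwaway test scaffolding (its own TODO says so): the child pauses itself so the parent is still running when the kill arrives and has to propagate it to a paused child. A hedged usage sketch, assuming a configured profile and a running daemon:

from aiida.engine import submit
from aiida.orm import Bool
from aiida.workchain import MainWorkChain  # module added above, marked for removal

# Submit with kill=True so SubWorkChain pauses itself in `begin`; then kill the
# parent (e.g. `verdi process kill <PK>`) and verify that the kill reaches the
# paused child before either `check` step can raise.
node = submit(MainWorkChain, kill=Bool(True))
print(f'submitted MainWorkChain<{node.pk}>')
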
22 changes: 8 additions & 14 deletions tests/cmdline/commands/test_process.py
@@ -10,7 +10,6 @@

import functools
import re
import time
import typing as t
import uuid
from contextlib import contextmanager
@@ -26,6 +25,7 @@
from aiida.engine import Process, ProcessState
from aiida.engine.processes import control as process_control
from aiida.orm import CalcJobNode, Group, WorkChainNode, WorkflowNode, WorkFunctionNode
from tests.conftest import await_condition
from tests.utils.processes import WaitProcess

FuncArgs = tuple[t.Any, ...]
@@ -116,18 +116,6 @@ def fork_worker(func, func_args):
client.increase_workers(nb_workers)


def await_condition(condition: t.Callable, timeout: int = 1) -> t.Any:
"""Wait for the ``condition`` to evaluate to ``True`` within the ``timeout`` or raise."""
start_time = time.time()

while not (result := condition()):
if time.time() - start_time > timeout:
raise RuntimeError(f'waiting for {condition} to evaluate to `True` timed out after {timeout} seconds.')
time.sleep(0.1)

return result


@pytest.mark.requires_rmq
@pytest.mark.usefixtures('started_daemon_client')
def test_process_kill_failing_transport(
@@ -213,7 +201,7 @@ def make_a_builder(sleep_seconds=0):

@pytest.mark.requires_rmq
@pytest.mark.usefixtures('started_daemon_client')
def test_process_kill_failng_ebm(
def test_process_kill_failing_ebm(
fork_worker_context, submit_and_await, aiida_code_installed, run_cli_command, monkeypatch
):
"""9) Kill a process that is paused after EBM (5 times failed). It should be possible to kill it normally.
@@ -232,6 +220,7 @@ def make_a_builder(sleep_seconds=0):

kill_timeout = 10

# TODO instead of mocking it, why didn't we just set the parameters to 1 second?
monkeypatch_args = ('aiida.engine.utils.exponential_backoff_retry', MockFunctions.mock_exponential_backoff_retry)
with fork_worker_context(monkeypatch.setattr, monkeypatch_args):
node = submit_and_await(make_a_builder(), ProcessState.WAITING)
@@ -242,6 +231,11 @@ def make_a_builder(sleep_seconds=0):
)

run_cli_command(cmd_process.process_kill, [str(node.pk), '--wait'])
# It should *not* be killable after the EBM has failed
await_condition(lambda: not node.is_killed, timeout=kill_timeout)

# It should be killable with the force kill option
run_cli_command(cmd_process.process_kill, [str(node.pk), '-F', '--wait'])
await_condition(lambda: node.is_killed, timeout=kill_timeout)


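The updated test expects the regular kill to fail because the monkeypatched EBM gives up, and only the -F (force) kill to succeed. The mock itself is not shown in this diff; a hypothetical stand-in could simply be a coroutine that raises immediately:

async def mock_exponential_backoff_retry(fct, *args, **kwargs):
    """Hypothetical replacement for aiida.engine.utils.exponential_backoff_retry that fails at once."""
    raise RuntimeError('exponential backoff retry disabled for this test')

# applied inside the forked worker, mirroring the monkeypatch_args tuple above:
# monkeypatch.setattr('aiida.engine.utils.exponential_backoff_retry', mock_exponential_backoff_retry)
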
14 changes: 14 additions & 0 deletions tests/conftest.py
@@ -953,3 +953,17 @@ def cat_path() -> Path:
run_process = subprocess.run(['which', 'cat'], capture_output=True, check=True)
path = run_process.stdout.decode('utf-8').strip()
return Path(path)


def await_condition(condition: t.Callable, timeout: int = 1) -> t.Any:
"""Wait for the ``condition`` to evaluate to ``True`` within the ``timeout`` or raise."""
import time

start_time = time.time()

while not (result := condition()):
if time.time() - start_time > timeout:
raise RuntimeError(f'waiting for {condition} to evaluate to `True` timed out after {timeout} seconds.')
time.sleep(0.1)

return result
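
await_condition is now shared through tests.conftest: it polls a predicate every 0.1 s until it returns something truthy or the timeout expires. Typical usage, as in the kill tests above (node being whatever process node the test just acted on):

from tests.conftest import await_condition

# Block until the node reports it was killed, or fail after 10 seconds.
await_condition(lambda: node.is_killed, timeout=10)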