Skip to content

Commit f76c985

Browse files
committed
wip
1 parent e257b3c commit f76c985

File tree

1 file changed

+72
-36
lines changed

1 file changed

+72
-36
lines changed

src/aiida/engine/processes/process.py

Lines changed: 72 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -329,51 +329,87 @@ def load_instance_state(
329329

330330
self.node.logger.info(f'Loaded process<{self.node.pk}> from saved state')
331331

332+
async def _launch_task(self, coro, *args, **kwargs):
333+
"""Launch a coroutine as a task, making sure to make it interruptable."""
334+
import functools
335+
from aiida.engine.utils import interruptable_task
336+
task_fn = functools.partial(coro, *args, **kwargs)
337+
try:
338+
self._task = interruptable_task(task_fn)
339+
result = await self._task
340+
return result
341+
finally:
342+
self._task = None
343+
332344
def kill(self, msg_text: str | None = None, force_kill: bool = False) -> Union[bool, plumpy.futures.Future]:
333345
"""Kill the process and all the children calculations it called
334346
335347
:param msg: message
336348
"""
337-
self.node.logger.info(f'Request to kill Process<{self.node.pk}>')
338-
339-
had_been_terminated = self.has_terminated()
340-
349+
from .calcjobs.tasks import task_kill_job
350+
351+
if self._killing and not force_kill:
352+
# if already killing we have triggered the Interruption
353+
coro = self._launch_task(task_kill_job, self.node, self.runner.transport)
354+
task = asyncio.create_task(coro)
355+
_ = self.loop.run_until_complete(task)
356+
# the parent class invokes the transition to being killed
341357
result = super().kill(msg_text, force_kill)
342358

343-
# Only kill children if we could be killed ourselves
344-
if result is not False and not had_been_terminated:
345-
killing = []
346-
for child in self.node.called:
347-
if self.runner.controller is None:
348-
self.logger.info('no controller available to kill child<%s>', child.pk)
349-
continue
350-
try:
351-
result = self.runner.controller.kill_process(child.pk, msg_text=f'Killed by parent<{self.node.pk}>')
352-
result = asyncio.wrap_future(result) # type: ignore[arg-type]
353-
if asyncio.isfuture(result):
354-
killing.append(result)
355-
except ConnectionClosed:
356-
self.logger.info('no connection available to kill child<%s>', child.pk)
357-
except UnroutableError:
358-
self.logger.info('kill signal was unable to reach child<%s>', child.pk)
359-
360-
if asyncio.isfuture(result):
361-
# We ourselves are waiting to be killed so add it to the list
362-
killing.append(result)
363-
364-
if killing:
365-
# We are waiting for things to be killed, so return the 'gathered' future
366-
kill_future = plumpy.futures.gather(*killing)
367-
result = self.loop.create_future()
368-
369-
def done(done_future: plumpy.futures.Future):
370-
is_all_killed = all(done_future.result())
371-
result.set_result(is_all_killed)
372-
373-
kill_future.add_done_callback(done)
374-
375359
return result
376360

361+
# TODO might need to be merged with `task_kill_job`
362+
#if self._killing is not None:
363+
# self._killing.set_result(True)
364+
#else:
365+
# self.node.logger.info(f'killed CalcJob<{self.node.pk}> but async future was None')
366+
367+
#self.node.logger.info(f'Request to kill Process<{self.node.pk}>')
368+
#from .
369+
#await self._launch_task(task_kill_job, node, transport_queue)
370+
#if self._killing is not None:
371+
# self._killing.set_result(True)
372+
#else:
373+
# logger.info(f'killed CalcJob<{node.pk}> but async future was None')
374+
375+
#had_been_terminated = self.has_terminated()
376+
377+
#result = super().kill(msg_text, force_kill)
378+
379+
## Only kill children if we could be killed ourselves
380+
#if result is not False and not had_been_terminated:
381+
# killing = []
382+
# for child in self.node.called:
383+
# if self.runner.controller is None:
384+
# self.logger.info('no controller available to kill child<%s>', child.pk)
385+
# continue
386+
# try:
387+
# result = self.runner.controller.kill_process(child.pk, msg_text=f'Killed by parent<{self.node.pk}>')
388+
# result = asyncio.wrap_future(result) # type: ignore[arg-type]
389+
# if asyncio.isfuture(result):
390+
# killing.append(result)
391+
# except ConnectionClosed:
392+
# self.logger.info('no connection available to kill child<%s>', child.pk)
393+
# except UnroutableError:
394+
# self.logger.info('kill signal was unable to reach child<%s>', child.pk)
395+
396+
# if asyncio.isfuture(result):
397+
# # We ourselves are waiting to be killed so add it to the list
398+
# killing.append(result)
399+
400+
# if killing:
401+
# # We are waiting for things to be killed, so return the 'gathered' future
402+
# kill_future = plumpy.futures.gather(*killing)
403+
# result = self.loop.create_future()
404+
405+
# def done(done_future: plumpy.futures.Future):
406+
# is_all_killed = all(done_future.result())
407+
# result.set_result(is_all_killed)
408+
409+
# kill_future.add_done_callback(done)
410+
411+
#return result
412+
377413
@override
378414
def out(self, output_port: str, value: Any = None) -> None:
379415
"""Attach output to output port.

0 commit comments

Comments
 (0)