@@ -329,51 +329,87 @@ def load_instance_state(
329329
330330 self .node .logger .info (f'Loaded process<{ self .node .pk } > from saved state' )
331331
332+ async def _launch_task (self , coro , * args , ** kwargs ):
333+ """Launch a coroutine as a task, making sure to make it interruptable."""
334+ import functools
335+ from aiida .engine .utils import interruptable_task
336+ task_fn = functools .partial (coro , * args , ** kwargs )
337+ try :
338+ self ._task = interruptable_task (task_fn )
339+ result = await self ._task
340+ return result
341+ finally :
342+ self ._task = None
343+
332344 def kill (self , msg_text : str | None = None , force_kill : bool = False ) -> Union [bool , plumpy .futures .Future ]:
333345 """Kill the process and all the children calculations it called
334346
335347 :param msg: message
336348 """
337- self .node .logger .info (f'Request to kill Process<{ self .node .pk } >' )
338-
339- had_been_terminated = self .has_terminated ()
340-
349+ from .calcjobs .tasks import task_kill_job
350+
351+ if self ._killing and not force_kill :
352+ # if already killing we have triggered the Interruption
353+ coro = self ._launch_task (task_kill_job , self .node , self .runner .transport )
354+ task = asyncio .create_task (coro )
355+ _ = self .loop .run_until_complete (task )
356+ # the parent class invokes the transition to being killed
341357 result = super ().kill (msg_text , force_kill )
342358
343- # Only kill children if we could be killed ourselves
344- if result is not False and not had_been_terminated :
345- killing = []
346- for child in self .node .called :
347- if self .runner .controller is None :
348- self .logger .info ('no controller available to kill child<%s>' , child .pk )
349- continue
350- try :
351- result = self .runner .controller .kill_process (child .pk , msg_text = f'Killed by parent<{ self .node .pk } >' )
352- result = asyncio .wrap_future (result ) # type: ignore[arg-type]
353- if asyncio .isfuture (result ):
354- killing .append (result )
355- except ConnectionClosed :
356- self .logger .info ('no connection available to kill child<%s>' , child .pk )
357- except UnroutableError :
358- self .logger .info ('kill signal was unable to reach child<%s>' , child .pk )
359-
360- if asyncio .isfuture (result ):
361- # We ourselves are waiting to be killed so add it to the list
362- killing .append (result )
363-
364- if killing :
365- # We are waiting for things to be killed, so return the 'gathered' future
366- kill_future = plumpy .futures .gather (* killing )
367- result = self .loop .create_future ()
368-
369- def done (done_future : plumpy .futures .Future ):
370- is_all_killed = all (done_future .result ())
371- result .set_result (is_all_killed )
372-
373- kill_future .add_done_callback (done )
374-
375359 return result
376360
361+ # TODO might need to be merged with `task_kill_job`
362+ #if self._killing is not None:
363+ # self._killing.set_result(True)
364+ #else:
365+ # self.node.logger.info(f'killed CalcJob<{self.node.pk}> but async future was None')
366+
367+ #self.node.logger.info(f'Request to kill Process<{self.node.pk}>')
368+ #from .
369+ #await self._launch_task(task_kill_job, node, transport_queue)
370+ #if self._killing is not None:
371+ # self._killing.set_result(True)
372+ #else:
373+ # logger.info(f'killed CalcJob<{node.pk}> but async future was None')
374+
375+ #had_been_terminated = self.has_terminated()
376+
377+ #result = super().kill(msg_text, force_kill)
378+
379+ ## Only kill children if we could be killed ourselves
380+ #if result is not False and not had_been_terminated:
381+ # killing = []
382+ # for child in self.node.called:
383+ # if self.runner.controller is None:
384+ # self.logger.info('no controller available to kill child<%s>', child.pk)
385+ # continue
386+ # try:
387+ # result = self.runner.controller.kill_process(child.pk, msg_text=f'Killed by parent<{self.node.pk}>')
388+ # result = asyncio.wrap_future(result) # type: ignore[arg-type]
389+ # if asyncio.isfuture(result):
390+ # killing.append(result)
391+ # except ConnectionClosed:
392+ # self.logger.info('no connection available to kill child<%s>', child.pk)
393+ # except UnroutableError:
394+ # self.logger.info('kill signal was unable to reach child<%s>', child.pk)
395+
396+ # if asyncio.isfuture(result):
397+ # # We ourselves are waiting to be killed so add it to the list
398+ # killing.append(result)
399+
400+ # if killing:
401+ # # We are waiting for things to be killed, so return the 'gathered' future
402+ # kill_future = plumpy.futures.gather(*killing)
403+ # result = self.loop.create_future()
404+
405+ # def done(done_future: plumpy.futures.Future):
406+ # is_all_killed = all(done_future.result())
407+ # result.set_result(is_all_killed)
408+
409+ # kill_future.add_done_callback(done)
410+
411+ #return result
412+
377413 @override
378414 def out (self , output_port : str , value : Any = None ) -> None :
379415 """Attach output to output port.
0 commit comments