@@ -329,50 +329,134 @@ def load_instance_state(
329329
330330 self .node .logger .info (f'Loaded process<{ self .node .pk } > from saved state' )
331331
async def _launch_task(self, coro, *args, **kwargs):
    """Run ``coro(*args, **kwargs)`` as an interruptable task and return its result.

    While the task is awaited it is stored on ``self._task``; the attribute is
    reset to ``None`` once the task finishes, whether it returned or raised.

    :param coro: callable that is bound to ``args``/``kwargs`` and handed to ``interruptable_task``.
    :param args: positional arguments bound to ``coro``.
    :param kwargs: keyword arguments bound to ``coro``.
    :return: whatever awaiting the interruptable task yields.
    """
    import functools

    from aiida.engine.utils import interruptable_task

    bound_call = functools.partial(coro, *args, **kwargs)
    try:
        self._task = interruptable_task(bound_call)
        return await self._task
    finally:
        # Always drop the reference so a finished task is never left on the instance.
        self._task = None
345+
def kill(self, msg_text: str | None = None, force_kill: bool = False) -> Union[bool, plumpy.futures.Future]:
    """Kill the process and all the children calculations it called.

    The steps are, in order:

    1. Return early if the process is already killed or otherwise terminated.
    2. Send a kill request to every child in ``self.node.called`` through the runner's controller.
    3. If the process is currently stepping, interrupt its state with a ``KillInterruption``.
    4. Unless ``force_kill`` is set, run ``task_kill_job`` to completion on ``self.loop``.
    5. Transition the process to the ``KILLED`` state.

    :param msg_text: optional message passed to the kill interruption and to the kill message of the new state.
    :param force_kill: if ``True``, skip the ``task_kill_job`` step and transition to ``KILLED`` directly.
    :return: ``True`` if the process was already killed or has now been transitioned to ``KILLED``; ``False`` if it
        had already terminated in another state or if ``task_kill_job`` raised.
    """
    if self.killed():
        self.node.logger.info(f'Request to kill Process<{self.node.pk}> but process has already been killed.')
        return True
    if self.has_terminated():
        self.node.logger.info(f'Request to kill Process<{self.node.pk}> but process has already terminated.')
        return False
    self.node.logger.info(f'Request to kill Process<{self.node.pk}>.')

    # Children are killed *before* this process transitions to killed. A kill can be re-sent (e.g. with different
    # options after a timeout); if the parent were killed first, a repeated request would return early above and
    # the children would never receive the kill signal.
    # TODO: check that this ordering is covered by a test.
    killing = []
    for child in self.node.called:
        if self.runner.controller is None:
            self.logger.info('no controller available to kill child<%s>', child.pk)
            continue
        try:
            child_future = self.runner.controller.kill_process(child.pk, msg_text=f'Killed by parent<{self.node.pk}>')
            # ``wrap_future`` always yields an asyncio future, so no further ``isfuture`` check is needed.
            killing.append(asyncio.wrap_future(child_future))  # type: ignore[arg-type]
        except ConnectionClosed:
            self.logger.info('no connection available to kill child<%s>', child.pk)
        except UnroutableError:
            self.logger.info('kill signal was unable to reach child<%s>', child.pk)

    if killing:
        # Gather the child kill futures and reduce their results to a single boolean.
        # NOTE(review): ``children_killed`` is neither returned nor awaited by this method, so the outcome of the
        # child kills does not influence the return value - confirm whether that is intended.
        kill_future = plumpy.futures.gather(*killing)
        children_killed = self.loop.create_future()

        def done(done_future: plumpy.futures.Future):
            children_killed.set_result(all(done_future.result()))

        kill_future.add_done_callback(done)

    # Interrupt the currently running step so it stops. This does not block. The interrupt action is deliberately
    # not registered on the process: the transition to the killed state is driven explicitly below rather than
    # through the interrupt action.
    if self._stepping:
        interrupt_exception = plumpy.process_states.KillInterruption(msg_text, force_kill)
        self._state.interrupt(interrupt_exception)

    # Cancel the scheduler job associated with this process. This blocks until the kill task has finished.
    if not force_kill:
        # TODO: move ``task_kill_job`` to a more common place.
        from .calcjobs.tasks import task_kill_job

        coro = self._launch_task(task_kill_job, self.node, self.runner.transport)
        try:
            # Hand the coroutine to ``self.loop`` directly so it is scheduled on the loop that drives it;
            # ``asyncio.create_task`` would require a running loop in the current thread and bind to that one.
            self.loop.run_until_complete(coro)
        except Exception as exc:
            # ``task_kill_job`` raises when it does not succeed. Report the failure and return ``False`` instead
            # of excepting the process, so the kill can be retried (or forced with ``force_kill``).
            self.node.logger.error(f'While cancelling job error was raised: {exc!s}')
            return False

    # Transition to the killed state; only return once the transition call has completed.
    msg = plumpy.process_comms.MessageBuilder.kill(text=msg_text, force_kill=force_kill)
    new_state = self._create_state_instance(plumpy.process_states.ProcessState.KILLED, msg=msg)
    self.transition_to(new_state)

    return True
376460
377461 @override
378462 def out (self , output_port : str , value : Any = None ) -> None :
0 commit comments