Skip to content

Commit 4a462a0

Browse files
committed
Adapt after merge the message protocol PR
1 parent b9dd887 commit 4a462a0

File tree

3 files changed

+38
-12
lines changed

3 files changed

+38
-12
lines changed

src/plumpy/process_comms.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -229,15 +229,17 @@ async def play_process(self, pid: 'PID_TYPE') -> 'ProcessResult':
229229
result = await asyncio.wrap_future(future)
230230
return result
231231

232-
async def kill_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> 'ProcessResult':
232+
async def kill_process(
233+
self, pid: 'PID_TYPE', msg_text: Optional[str] = None, force_kill: bool = False
234+
) -> 'ProcessResult':
233235
"""
234236
Kill the process
235237
236238
:param pid: the pid of the process to kill
237239
:param msg: optional kill message
238240
:return: True if killed, False otherwise
239241
"""
240-
msg = MessageBuilder.kill(text=msg_text)
242+
msg = MessageBuilder.kill(text=msg_text, force_kill=force_kill)
241243

242244
# Wait for the communication to go through
243245
kill_future = self._communicator.rpc_send(pid, msg)
@@ -401,15 +403,15 @@ def play_all(self) -> None:
401403
"""
402404
self._communicator.broadcast_send(None, subject=Intent.PLAY)
403405

404-
def kill_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> kiwipy.Future:
406+
def kill_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None, force_kill: bool = False) -> kiwipy.Future:
405407
"""
406408
Kill the process
407409
408410
:param pid: the pid of the process to kill
409411
:param msg: optional kill message
410412
:return: a response future from the process to be killed
411413
"""
412-
msg = MessageBuilder.kill(text=msg_text)
414+
msg = MessageBuilder.kill(text=msg_text, force_kill=force_kill)
413415
return self._communicator.rpc_send(pid, msg)
414416

415417
def kill_all(self, msg_text: Optional[str]) -> None:

src/plumpy/process_states.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,11 @@ def __init__(self, msg_text: str | None):
6161

6262

6363
class ForceKillInterruption(Interruption):
64-
pass
64+
def __init__(self, msg_text: str | None):
65+
super().__init__()
66+
msg = MessageBuilder.kill(text=msg_text)
67+
68+
self.msg: MessageType = msg
6569

6670

6771
class PauseInterruption(Interruption):

src/plumpy/processes.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -965,7 +965,11 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: MessageType) -> Any:
965965
if intent == process_comms.Intent.PAUSE:
966966
return self._schedule_rpc(self.pause, msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None))
967967
if intent == process_comms.Intent.KILL:
968-
return self._schedule_rpc(self.kill, msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None))
968+
return self._schedule_rpc(
969+
self.kill,
970+
msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None),
971+
force_kill=msg.get(process_comms.FORCE_KILL_KEY),
972+
)
969973
if intent == process_comms.Intent.STATUS:
970974
status_info: Dict[str, Any] = {}
971975
self.get_status_info(status_info)
@@ -998,7 +1002,11 @@ def broadcast_receive(
9981002
if subject == process_comms.Intent.PAUSE:
9991003
return self._schedule_rpc(self.pause, msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None))
10001004
if subject == process_comms.Intent.KILL:
1001-
return self._schedule_rpc(self.kill, msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None))
1005+
return self._schedule_rpc(
1006+
self.kill,
1007+
msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None),
1008+
force_kill=msg.get(process_comms.FORCE_KILL_KEY),
1009+
)
10021010
return None
10031011

10041012
def _schedule_rpc(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> kiwipy.Future:
@@ -1222,12 +1230,12 @@ def fail(self, exception: Optional[BaseException], trace_back: Optional[Tracebac
12221230
)
12231231
self.transition_to(new_state)
12241232

1225-
def kill(self, msg_text: Optional[str] = None) -> Union[bool, asyncio.Future]:
1233+
def kill(self, msg_text: Optional[str] = None, force_kill: bool = False) -> Union[bool, asyncio.Future]:
12261234
"""
12271235
Kill the process
12281236
:param msg: An optional kill message
1237+
:param force_kill: An optional whether force kill the process
12291238
"""
1230-
force_kill = isinstance(msg, str) and '-F' in msg
12311239

12321240
if self.state == process_states.ProcessState.KILLED:
12331241
# Already killed
@@ -1243,20 +1251,32 @@ def kill(self, msg_text: Optional[str] = None) -> Union[bool, asyncio.Future]:
12431251

12441252
if force_kill:
12451253
# Skip interrupting the state and go straight to killed
1246-
interrupt_exception = process_states.ForceKillInterruption(msg)
1254+
interrupt_exception = process_states.ForceKillInterruption(msg_text)
1255+
# XXX: this line was not in ali's PR but to make the change align with _stepping,
1256+
# it seems it is needed to set the _interrupt_action to be used line after.
1257+
# Requires more check to test with aiida-core's PR.
1258+
#
1259+
# self._set_interrupt_action_from_exception(interrupt_exception)
1260+
#
12471261
self._killing = self._interrupt_action
12481262
self._state.interrupt(interrupt_exception)
12491263

1250-
elif self._stepping:
1264+
msg = MessageBuilder.kill(msg_text, force_kill=True)
1265+
new_state = self._create_state_instance(process_states.ProcessState.KILLED, msg=msg)
1266+
self.transition_to(new_state)
1267+
return True
1268+
1269+
if self._stepping:
12511270
# Ask the step function to pause by setting this flag and giving the
12521271
# caller back a future
12531272
interrupt_exception = process_states.KillInterruption(msg_text)
12541273
self._set_interrupt_action_from_exception(interrupt_exception)
12551274
self._killing = self._interrupt_action
12561275
self._state.interrupt(interrupt_exception)
1276+
12571277
return cast(futures.CancellableAction, self._interrupt_action)
12581278

1259-
msg = MessageBuilder.kill(msg_text)
1279+
msg = MessageBuilder.kill(msg_text, force_kill=False)
12601280
new_state = self._create_state_instance(process_states.ProcessState.KILLED, msg=msg)
12611281
self.transition_to(new_state)
12621282
return True

0 commit comments

Comments
 (0)