Skip to content

Commit b9dd887

Browse files
authored
Merge branch 'master' into force-kill
2 parents f1f8095 + ecef9b9 commit b9dd887

File tree

6 files changed

+87
-59
lines changed

6 files changed

+87
-59
lines changed

pyproject.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,9 @@ classifiers = [
2929
keywords = ['workflow', 'multithreaded', 'rabbitmq']
3030
requires-python = '>=3.8'
3131
dependencies = [
32-
'kiwipy[rmq]~=0.8.3',
32+
'kiwipy[rmq]~=0.8.5',
3333
'nest_asyncio~=1.5,>=1.5.1',
3434
'pyyaml~=6.0',
35-
# XXX: workaround for https://github.com/mosquito/aio-pika/issues/649
36-
'typing-extensions~=4.12',
3735
]
3836

3937
[project.urls]

src/plumpy/process_comms.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
ProcessStatus = Any
2929

3030
INTENT_KEY = 'intent'
31-
MESSAGE_KEY = 'message'
31+
MESSAGE_TEXT_KEY = 'message'
3232
FORCE_KILL_KEY = 'force_kill'
3333

3434

@@ -52,23 +52,23 @@ def play(cls, text: str | None = None) -> MessageType:
5252
"""The play message send over communicator."""
5353
return {
5454
INTENT_KEY: Intent.PLAY,
55-
MESSAGE_KEY: text,
55+
MESSAGE_TEXT_KEY: text,
5656
}
5757

5858
@classmethod
5959
def pause(cls, text: str | None = None) -> MessageType:
6060
"""The pause message send over communicator."""
6161
return {
6262
INTENT_KEY: Intent.PAUSE,
63-
MESSAGE_KEY: text,
63+
MESSAGE_TEXT_KEY: text,
6464
}
6565

6666
@classmethod
6767
def kill(cls, text: str | None = None, force_kill: bool = False) -> MessageType:
6868
"""The kill message send over communicator."""
6969
return {
7070
INTENT_KEY: Intent.KILL,
71-
MESSAGE_KEY: text,
71+
MESSAGE_TEXT_KEY: text,
7272
FORCE_KILL_KEY: force_kill,
7373
}
7474

@@ -77,7 +77,7 @@ def status(cls, text: str | None = None) -> MessageType:
7777
"""The status message send over communicator."""
7878
return {
7979
INTENT_KEY: Intent.STATUS,
80-
MESSAGE_KEY: text,
80+
MESSAGE_TEXT_KEY: text,
8181
}
8282

8383

@@ -200,15 +200,15 @@ async def get_status(self, pid: 'PID_TYPE') -> 'ProcessStatus':
200200
result = await asyncio.wrap_future(future)
201201
return result
202202

203-
async def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> 'ProcessResult':
203+
async def pause_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> 'ProcessResult':
204204
"""
205205
Pause the process
206206
207207
:param pid: the pid of the process to pause
208208
:param msg: optional pause message
209209
:return: True if paused, False otherwise
210210
"""
211-
msg = MessageBuilder.pause(text=msg)
211+
msg = MessageBuilder.pause(text=msg_text)
212212

213213
pause_future = self._communicator.rpc_send(pid, msg)
214214
# rpc_send return a thread future from communicator
@@ -229,16 +229,15 @@ async def play_process(self, pid: 'PID_TYPE') -> 'ProcessResult':
229229
result = await asyncio.wrap_future(future)
230230
return result
231231

232-
async def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> 'ProcessResult':
232+
async def kill_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> 'ProcessResult':
233233
"""
234234
Kill the process
235235
236236
:param pid: the pid of the process to kill
237237
:param msg: optional kill message
238238
:return: True if killed, False otherwise
239239
"""
240-
if msg is None:
241-
msg = MessageBuilder.kill()
240+
msg = MessageBuilder.kill(text=msg_text)
242241

243242
# Wait for the communication to go through
244243
kill_future = self._communicator.rpc_send(pid, msg)
@@ -364,7 +363,7 @@ def get_status(self, pid: 'PID_TYPE') -> kiwipy.Future:
364363
"""
365364
return self._communicator.rpc_send(pid, MessageBuilder.status())
366365

367-
def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Future:
366+
def pause_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> kiwipy.Future:
368367
"""
369368
Pause the process
370369
@@ -373,16 +372,17 @@ def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Fu
373372
:return: a response future from the process to be paused
374373
375374
"""
376-
msg = MessageBuilder.pause(text=msg)
375+
msg = MessageBuilder.pause(text=msg_text)
377376

378377
return self._communicator.rpc_send(pid, msg)
379378

380-
def pause_all(self, msg: Any) -> None:
379+
def pause_all(self, msg_text: Optional[str]) -> None:
381380
"""
382381
Pause all processes that are subscribed to the same communicator
383382
384383
:param msg: an optional pause message
385384
"""
385+
msg = MessageBuilder.pause(text=msg_text)
386386
self._communicator.broadcast_send(msg, subject=Intent.PAUSE)
387387

388388
def play_process(self, pid: 'PID_TYPE') -> kiwipy.Future:
@@ -401,28 +401,24 @@ def play_all(self) -> None:
401401
"""
402402
self._communicator.broadcast_send(None, subject=Intent.PLAY)
403403

404-
def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> kiwipy.Future:
404+
def kill_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> kiwipy.Future:
405405
"""
406406
Kill the process
407407
408408
:param pid: the pid of the process to kill
409409
:param msg: optional kill message
410410
:return: a response future from the process to be killed
411-
412411
"""
413-
if msg is None:
414-
msg = MessageBuilder.kill()
415-
412+
msg = MessageBuilder.kill(text=msg_text)
416413
return self._communicator.rpc_send(pid, msg)
417414

418-
def kill_all(self, msg: Optional[MessageType]) -> None:
415+
def kill_all(self, msg_text: Optional[str]) -> None:
419416
"""
420417
Kill all processes that are subscribed to the same communicator
421418
422419
:param msg: an optional pause message
423420
"""
424-
if msg is None:
425-
msg = MessageBuilder.kill()
421+
msg = MessageBuilder.kill(msg_text)
426422

427423
self._communicator.broadcast_send(msg, subject=Intent.KILL)
428424

src/plumpy/process_states.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,9 @@ class Interruption(Exception): # noqa: N818
5353

5454

5555
class KillInterruption(Interruption):
56-
def __init__(self, msg: MessageType | None):
56+
def __init__(self, msg_text: str | None):
5757
super().__init__()
58-
if msg is None:
59-
msg = MessageBuilder.kill()
58+
msg = MessageBuilder.kill(text=msg_text)
6059

6160
self.msg: MessageType = msg
6261

@@ -66,7 +65,11 @@ class ForceKillInterruption(Interruption):
6665

6766

6867
class PauseInterruption(Interruption):
69-
pass
68+
def __init__(self, msg_text: str | None):
69+
super().__init__()
70+
msg = MessageBuilder.pause(text=msg_text)
71+
72+
self.msg: MessageType = msg
7073

7174

7275
# region Commands

src/plumpy/processes.py

Lines changed: 57 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
from .base.state_machine import StateEntryFailed, StateMachine, TransitionFailed, event
5555
from .base.utils import call_with_super_check, super_check
5656
from .event_helper import EventHelper
57-
from .process_comms import MESSAGE_KEY, MessageBuilder, MessageType
57+
from .process_comms import MESSAGE_TEXT_KEY, MessageBuilder, MessageType
5858
from .process_listener import ProcessListener
5959
from .process_spec import ProcessSpec
6060
from .utils import PID_TYPE, SAVED_STATE_TYPE, protected
@@ -344,8 +344,7 @@ def init(self) -> None:
344344

345345
def try_killing(future: futures.Future) -> None:
346346
if future.cancelled():
347-
msg = MessageBuilder.kill(text='Killed by future being cancelled')
348-
if not self.kill(msg):
347+
if not self.kill('Killed by future being cancelled'):
349348
self.logger.warning(
350349
'Process<%s>: Failed to kill process on future cancel',
351350
self.pid,
@@ -903,7 +902,7 @@ def on_kill(self, msg: Optional[MessageType]) -> None:
903902
if msg is None:
904903
msg_txt = ''
905904
else:
906-
msg_txt = msg[MESSAGE_KEY] or ''
905+
msg_txt = msg[MESSAGE_TEXT_KEY] or ''
907906

908907
self.set_status(msg_txt)
909908
self.future().set_exception(exceptions.KilledError(msg_txt))
@@ -944,7 +943,7 @@ def _fire_event(self, evt: Callable[..., Any], *args: Any, **kwargs: Any) -> Non
944943

945944
# region Communication
946945

947-
def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> Any:
946+
def message_receive(self, _comm: kiwipy.Communicator, msg: MessageType) -> Any:
948947
"""
949948
Coroutine called when the process receives a message from the communicator
950949
@@ -964,9 +963,9 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An
964963
if intent == process_comms.Intent.PLAY:
965964
return self._schedule_rpc(self.play)
966965
if intent == process_comms.Intent.PAUSE:
967-
return self._schedule_rpc(self.pause, msg=msg.get(process_comms.MESSAGE_KEY, None))
966+
return self._schedule_rpc(self.pause, msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None))
968967
if intent == process_comms.Intent.KILL:
969-
return self._schedule_rpc(self.kill, msg=msg)
968+
return self._schedule_rpc(self.kill, msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None))
970969
if intent == process_comms.Intent.STATUS:
971970
status_info: Dict[str, Any] = {}
972971
self.get_status_info(status_info)
@@ -976,7 +975,7 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An
976975
raise RuntimeError('Unknown intent')
977976

978977
def broadcast_receive(
979-
self, _comm: kiwipy.Communicator, body: Any, sender: Any, subject: Any, correlation_id: Any
978+
self, _comm: kiwipy.Communicator, msg: MessageType, sender: Any, subject: Any, correlation_id: Any
980979
) -> Optional[kiwipy.Future]:
981980
"""
982981
Coroutine called when the process receives a message from the communicator
@@ -990,16 +989,16 @@ def broadcast_receive(
990989
self.pid,
991990
subject,
992991
_comm,
993-
body,
992+
msg,
994993
)
995994

996995
# If we get a message we recognise then action it, otherwise ignore
997996
if subject == process_comms.Intent.PLAY:
998997
return self._schedule_rpc(self.play)
999998
if subject == process_comms.Intent.PAUSE:
1000-
return self._schedule_rpc(self.pause, msg=body)
999+
return self._schedule_rpc(self.pause, msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None))
10011000
if subject == process_comms.Intent.KILL:
1002-
return self._schedule_rpc(self.kill, msg=body)
1001+
return self._schedule_rpc(self.kill, msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None))
10031002
return None
10041003

10051004
def _schedule_rpc(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> kiwipy.Future:
@@ -1021,11 +1020,37 @@ def _schedule_rpc(self, callback: Callable[..., Any], *args: Any, **kwargs: Any)
10211020

10221021
async def run_callback() -> None:
10231022
with kiwipy.capture_exceptions(kiwi_future):
1024-
result = callback(*args, **kwargs)
1025-
while asyncio.isfuture(result):
1026-
result = await result
1023+
try:
1024+
result = callback(*args, **kwargs)
1025+
except Exception as exc:
1026+
import inspect
1027+
import traceback
1028+
1029+
# Get traceback as a string
1030+
tb_str = ''.join(traceback.format_exception(type(exc), exc, exc.__traceback__))
1031+
1032+
# Attempt to get file and line number where the callback is defined
1033+
# Note: This might fail for certain built-in or dynamically generated functions.
1034+
# If it fails, just skip that part.
1035+
try:
1036+
source_file = inspect.getfile(callback)
1037+
# getsourcelines returns a tuple (list_of_source_lines, starting_line_number)
1038+
_, start_line = inspect.getsourcelines(callback)
1039+
callback_location = f'{source_file}:{start_line}'
1040+
except Exception:
1041+
callback_location = '<unknown location>'
1042+
1043+
# Include the callback name, file/line info, and the full traceback in the message
1044+
raise RuntimeError(
1045+
f"Error invoking callback '{callback.__name__}' at {callback_location}.\n"
1046+
f'Exception: {type(exc).__name__}: {exc}\n\n'
1047+
f'Full Traceback:\n{tb_str}'
1048+
) from exc
1049+
else:
1050+
while asyncio.isfuture(result):
1051+
result = await result
10271052

1028-
kiwi_future.set_result(result)
1053+
kiwi_future.set_result(result)
10291054

10301055
# Schedule the task and give back a kiwi future
10311056
asyncio.run_coroutine_threadsafe(run_callback(), self.loop)
@@ -1071,7 +1096,7 @@ def transition_failed(
10711096
)
10721097
self.transition_to(new_state)
10731098

1074-
def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.CancellableAction]:
1099+
def pause(self, msg_text: Optional[str] = None) -> Union[bool, futures.CancellableAction]:
10751100
"""Pause the process.
10761101
10771102
:param msg: an optional message to set as the status. The current status will be saved in the private
@@ -1095,22 +1120,29 @@ def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.Cancellable
10951120
if self._stepping:
10961121
# Ask the step function to pause by setting this flag and giving the
10971122
# caller back a future
1098-
interrupt_exception = process_states.PauseInterruption(msg)
1123+
interrupt_exception = process_states.PauseInterruption(msg_text)
10991124
self._set_interrupt_action_from_exception(interrupt_exception)
11001125
self._pausing = self._interrupt_action
11011126
# Try to interrupt the state
11021127
self._state.interrupt(interrupt_exception)
11031128
return cast(futures.CancellableAction, self._interrupt_action)
11041129

1105-
return self._do_pause(msg)
1130+
msg = MessageBuilder.pause(msg_text)
1131+
return self._do_pause(state_msg=msg)
11061132

1107-
def _do_pause(self, state_msg: Optional[str], next_state: Optional[process_states.State] = None) -> bool:
1133+
def _do_pause(self, state_msg: Optional[MessageType], next_state: Optional[process_states.State] = None) -> bool:
11081134
"""Carry out the pause procedure, optionally transitioning to the next state first"""
11091135
try:
11101136
if next_state is not None:
11111137
self.transition_to(next_state)
1112-
call_with_super_check(self.on_pausing, state_msg)
1113-
call_with_super_check(self.on_paused, state_msg)
1138+
1139+
if state_msg is None:
1140+
msg_text = ''
1141+
else:
1142+
msg_text = state_msg[MESSAGE_TEXT_KEY]
1143+
1144+
call_with_super_check(self.on_pausing, msg_text)
1145+
call_with_super_check(self.on_paused, msg_text)
11141146
finally:
11151147
self._pausing = None
11161148

@@ -1125,7 +1157,7 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu
11251157
11261158
"""
11271159
if isinstance(exception, process_states.PauseInterruption):
1128-
do_pause = functools.partial(self._do_pause, str(exception))
1160+
do_pause = functools.partial(self._do_pause, exception.msg)
11291161
return futures.CancellableAction(do_pause, cookie=exception)
11301162

11311163
if isinstance(exception, (process_states.KillInterruption, process_states.ForceKillInterruption)):
@@ -1190,7 +1222,7 @@ def fail(self, exception: Optional[BaseException], trace_back: Optional[Tracebac
11901222
)
11911223
self.transition_to(new_state)
11921224

1193-
def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]:
1225+
def kill(self, msg_text: Optional[str] = None) -> Union[bool, asyncio.Future]:
11941226
"""
11951227
Kill the process
11961228
:param msg: An optional kill message
@@ -1218,12 +1250,13 @@ def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]
12181250
elif self._stepping:
12191251
# Ask the step function to pause by setting this flag and giving the
12201252
# caller back a future
1221-
interrupt_exception = process_states.KillInterruption(msg) # type: ignore
1253+
interrupt_exception = process_states.KillInterruption(msg_text)
12221254
self._set_interrupt_action_from_exception(interrupt_exception)
12231255
self._killing = self._interrupt_action
12241256
self._state.interrupt(interrupt_exception)
12251257
return cast(futures.CancellableAction, self._interrupt_action)
12261258

1259+
msg = MessageBuilder.kill(msg_text)
12271260
new_state = self._create_state_instance(process_states.ProcessState.KILLED, msg=msg)
12281261
self.transition_to(new_state)
12291262
return True

tests/rmq/test_process_comms.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,9 +195,7 @@ async def test_kill_all(self, thread_communicator, sync_controller):
195195
for _ in range(10):
196196
procs.append(utils.WaitForSignalProcess(communicator=thread_communicator))
197197

198-
msg = process_comms.MessageBuilder.kill(text='bang bang, I shot you down')
199-
200-
sync_controller.kill_all(msg)
198+
sync_controller.kill_all(msg_text='bang bang, I shot you down')
201199
await utils.wait_util(lambda: all([proc.killed() for proc in procs]))
202200
assert all([proc.state == plumpy.ProcessState.KILLED for proc in procs])
203201

0 commit comments

Comments
 (0)