Skip to content

Commit 305cdf0

Browse files
authored
Merge pull request #634 from Bodo-inc/sahil/px-keyboard-int
Allow sending signals to engines on Keyboard Interrupt in %%px
2 parents 6cd55b0 + 579e6ff commit 305cdf0

File tree

2 files changed

+150
-42
lines changed

2 files changed

+150
-42
lines changed

ipyparallel/client/magics.py

Lines changed: 83 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def nullcontext():
4949
from IPython.core.magic import Magics
5050
from IPython.core import magic_arguments
5151
from .. import error
52+
import sys
5253

5354
# -----------------------------------------------------------------------------
5455
# Definitions of magic functions for use with IPython
@@ -156,6 +157,16 @@ def exec_args(f):
156157
Use -1 for no progress, 0 for always showing progress immediately.
157158
""",
158159
),
160+
magic_arguments.argument(
161+
'--signal-on-interrupt',
162+
dest='signal_on_interrupt',
163+
type=str,
164+
default=None,
165+
help="""Send signal to engines on Keyboard Interrupt. By default a SIGINT is sent.
166+
Note that this is only applicable when running in blocking mode.
167+
Choices: SIGINT, 2, SIGKILL, 9, 0 (nop), etc.
168+
""",
169+
),
159170
]
160171
for a in args:
161172
f = a(f)
@@ -235,6 +246,8 @@ class ParallelMagics(Magics):
235246
stream_ouput = True
236247
# seconds to wait before showing progress bar for blocking execution
237248
progress_after_seconds = 2
249+
# signal to send to engines on keyboard-interrupt
250+
signal_on_interrupt = "SIGINT"
238251

239252
def __init__(self, shell, view, suffix=''):
240253
self.view = view
@@ -267,6 +280,11 @@ def _eval_target_str(self, ts):
267280
targets = eval(ts)
268281
return targets
269282

283+
def _eval_signal_str(self, sig_str: str):
284+
if sig_str.isdigit():
285+
return int(sig_str)
286+
return sig_str
287+
270288
@magic_arguments.magic_arguments()
271289
@exec_args
272290
def pxconfig(self, line):
@@ -280,6 +298,8 @@ def pxconfig(self, line):
280298
self.verbose = args.set_verbose
281299
if args.stream is not None:
282300
self.stream_ouput = args.stream
301+
if args.signal_on_interrupt is not None:
302+
self.signal_on_interrupt = self._eval_signal_str(args.signal_on_interrupt)
283303

284304
if args.progress_after_seconds is not None:
285305
self.progress_after_seconds = args.progress_after_seconds
@@ -339,12 +359,18 @@ def parallel_execute(
339359
save_name=None,
340360
stream_output=None,
341361
progress_after=None,
362+
signal_on_interrupt=None,
342363
):
343364
"""implementation used by %px and %%parallel"""
344365

345366
# defaults:
346367
block = self.view.block if block is None else block
347368
stream_output = self.stream_ouput if stream_output is None else stream_output
369+
signal_on_interrupt = (
370+
self.signal_on_interrupt
371+
if signal_on_interrupt is None
372+
else signal_on_interrupt
373+
)
348374

349375
base = "Parallel" if block else "Async parallel"
350376

@@ -364,49 +390,60 @@ def parallel_execute(
364390
self.shell.user_ns[save_name] = result
365391

366392
if block:
367-
368-
if progress_after is None:
369-
progress_after = self.progress_after_seconds
370-
371-
cm = result.stream_output() if stream_output else nullcontext()
372-
with cm:
373-
finished_waiting = False
374-
if progress_after > 0:
375-
# finite progress-after timeout
376-
# wait for 'quick' results before showing progress
377-
tic = time.perf_counter()
378-
deadline = tic + progress_after
379-
try:
380-
result.get(timeout=progress_after)
381-
remaining = max(deadline - time.perf_counter(), 0)
382-
result.wait_for_output(timeout=remaining)
383-
except TimeoutError:
384-
pass
385-
except error.CompositeError as e:
386-
if stream_output:
387-
# already streamed, show an abbreviated result
388-
raise error.AlreadyDisplayedError(e) from None
389-
else:
390-
raise
391-
else:
392-
finished_waiting = True
393-
394-
if not finished_waiting:
395-
if progress_after >= 0:
396-
# not an immediate result, start interactive progress
397-
result.wait_interactive()
398-
result.wait_for_output()
399-
try:
400-
result.get()
401-
except error.CompositeError as e:
402-
if stream_output:
403-
# already streamed, show an abbreviated result
404-
raise error.AlreadyDisplayedError(e) from None
393+
try:
394+
if progress_after is None:
395+
progress_after = self.progress_after_seconds
396+
397+
cm = result.stream_output() if stream_output else nullcontext()
398+
with cm:
399+
finished_waiting = False
400+
if progress_after > 0:
401+
# finite progress-after timeout
402+
# wait for 'quick' results before showing progress
403+
tic = time.perf_counter()
404+
deadline = tic + progress_after
405+
try:
406+
result.get(timeout=progress_after)
407+
remaining = max(deadline - time.perf_counter(), 0)
408+
result.wait_for_output(timeout=remaining)
409+
except TimeoutError:
410+
pass
411+
except error.CompositeError as e:
412+
if stream_output:
413+
# already streamed, show an abbreviated result
414+
raise error.AlreadyDisplayedError(e) from None
415+
else:
416+
raise
405417
else:
406-
raise
407-
# Skip redisplay if streaming output
408-
if not stream_output:
409-
result.display_outputs(groupby)
418+
finished_waiting = True
419+
420+
if not finished_waiting:
421+
if progress_after >= 0:
422+
# not an immediate result, start interactive progress
423+
result.wait_interactive()
424+
result.wait_for_output()
425+
try:
426+
result.get()
427+
except error.CompositeError as e:
428+
if stream_output:
429+
# already streamed, show an abbreviated result
430+
raise error.AlreadyDisplayedError(e) from None
431+
else:
432+
raise
433+
# Skip redisplay if streaming output
434+
if not stream_output:
435+
result.display_outputs(groupby)
436+
except KeyboardInterrupt:
437+
if signal_on_interrupt is not None:
438+
print(
439+
f"Received Keyboard Interrupt. Sending signal {signal_on_interrupt} to engines...",
440+
file=sys.stderr,
441+
)
442+
self.view.client.send_signal(
443+
signal_on_interrupt, targets=targets, block=True
444+
)
445+
else:
446+
raise
410447
else:
411448
# return AsyncResult only on non-blocking submission
412449
return result
@@ -438,6 +475,9 @@ def cell_px(self, line='', cell=None):
438475
if args.targets:
439476
save_targets = self.view.targets
440477
self.view.targets = self._eval_target_str(args.targets)
478+
signal_on_interrupt = None
479+
if args.signal_on_interrupt:
480+
signal_on_interrupt = self._eval_signal_str(args.signal_on_interrupt)
441481
# if running local, don't block until after local has run
442482
block = False if args.local else args.block
443483
try:
@@ -448,6 +488,7 @@ def cell_px(self, line='', cell=None):
448488
save_name=args.save_name,
449489
stream_output=args.stream,
450490
progress_after=args.progress_after_seconds,
491+
signal_on_interrupt=signal_on_interrupt,
451492
)
452493
finally:
453494
if args.targets:

ipyparallel/tests/test_magics.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Test Parallel magics"""
22
import re
3+
import signal
34
import sys
45
import time
56

@@ -480,3 +481,69 @@ def test_cellpx_block(self):
480481
pass
481482
self.assertNotIn('Async', io.stdout)
482483
self.assertEqual(view.block, False)
484+
485+
def cellpx_keyboard_interrupt_test_helper(self, sig=None):
486+
"""%%px with Keyboard Interrupt on blocking execution"""
487+
488+
ip = get_ipython()
489+
v = self.client[:]
490+
v.block = True
491+
v.activate()
492+
493+
def _sigalarm(sig, frame):
494+
raise KeyboardInterrupt
495+
496+
signal.signal(signal.SIGALRM, _sigalarm)
497+
signal.alarm(2)
498+
with capture_output(display=False) as io:
499+
ip.run_cell_magic(
500+
"px",
501+
"" if sig is None else f"--signal-on-interrupt {sig}",
502+
"print('Entering...'); import time; time.sleep(5); print('Exiting...');",
503+
)
504+
505+
print(io.stdout)
506+
print(io.stderr, file=sys.stderr)
507+
assert (
508+
'Received Keyboard Interrupt. Sending signal {} to engines...'.format(
509+
"SIGINT" if sig is None else sig
510+
)
511+
in io.stderr
512+
)
513+
assert 'Exiting...' not in io.stdout
514+
515+
@pytest.mark.skipif(
516+
sys.platform.startswith("win"), reason="Signal tests don't pass on Windows yet"
517+
)
518+
def test_cellpx_keyboard_interrupt_default(self):
519+
self.cellpx_keyboard_interrupt_test_helper()
520+
521+
@pytest.mark.skipif(
522+
sys.platform.startswith("win"), reason="Signal tests don't pass on Windows yet"
523+
)
524+
def test_cellpx_keyboard_interrupt_SIGINT(self):
525+
self.cellpx_keyboard_interrupt_test_helper("SIGINT")
526+
527+
@pytest.mark.skipif(
528+
sys.platform.startswith("win"), reason="Signal tests don't pass on Windows yet"
529+
)
530+
def test_cellpx_keyboard_interrupt_signal_2(self):
531+
self.cellpx_keyboard_interrupt_test_helper("2")
532+
533+
@pytest.mark.skipif(
534+
sys.platform.startswith("win"), reason="Signal tests don't pass on Windows yet"
535+
)
536+
def test_cellpx_keyboard_interrupt_signal_0(self):
537+
self.cellpx_keyboard_interrupt_test_helper("0")
538+
539+
@pytest.mark.skipif(
540+
sys.platform.startswith("win"), reason="Signal tests don't pass on Windows yet"
541+
)
542+
def test_cellpx_keyboard_interrupt_SIGKILL(self):
543+
self.cellpx_keyboard_interrupt_test_helper("SIGKILL")
544+
545+
@pytest.mark.skipif(
546+
sys.platform.startswith("win"), reason="Signal tests don't pass on Windows yet"
547+
)
548+
def test_cellpx_keyboard_interrupt_signal_9(self):
549+
self.cellpx_keyboard_interrupt_test_helper("9")

0 commit comments

Comments
 (0)