diff --git a/docs/source/reference-lowlevel.rst b/docs/source/reference-lowlevel.rst index 82bd8537d..15289ff85 100644 --- a/docs/source/reference-lowlevel.rst +++ b/docs/source/reference-lowlevel.rst @@ -344,6 +344,8 @@ Spawning threads .. autofunction:: start_thread_soon +.. _ki-handling: + Safer KeyboardInterrupt handling ================================ @@ -355,10 +357,21 @@ correctness invariants. On the other, if the user accidentally writes an infinite loop, we do want to be able to break out of that. Our solution is to install a default signal handler which checks whether it's safe to raise :exc:`KeyboardInterrupt` at the place where the -signal is received. If so, then we do; otherwise, we schedule a -:exc:`KeyboardInterrupt` to be delivered to the main task at the next -available opportunity (similar to how :exc:`~trio.Cancelled` is -delivered). +signal is received. If so, then we do. Otherwise, we cancel all tasks +and add `KeyboardInterrupt` as the result of :func:`trio.run`. + +.. note:: This behavior means it's not a good idea to try to catch + `KeyboardInterrupt` within a Trio task. Most Trio + programs are I/O-bound, so most interrupts will be received while + no task is running (because Trio is waiting for I/O). There's no + task that should obviously receive the interrupt in such cases, so + Trio doesn't raise it within a task at all: every task gets cancelled, + then `KeyboardInterrupt` is raised once that's complete. + + If you want to handle Ctrl+C by doing something other than "cancel + all tasks", then you should use :func:`~trio.open_signal_receiver` to + install a handler for `signal.SIGINT`. If you do that, then Ctrl+C will + go to your handler, and it can do whatever it wants. So that's great, but – how do we know whether we're in one of the sensitive parts of the program or not? diff --git a/newsfragments/733.breaking.rst b/newsfragments/733.breaking.rst new file mode 100644 index 000000000..cf87925ce --- /dev/null +++ b/newsfragments/733.breaking.rst @@ -0,0 +1,42 @@ +:ref:`Sometimes `, a Trio program receives an interrupt +signal (Ctrl+C) at a time when Python's default response (raising +`KeyboardInterrupt` immediately) might corrupt Trio's internal +state. Previously, Trio would handle this situation by raising the +`KeyboardInterrupt` at the next :ref:`checkpoint ` executed +by the main task (the one running the function you passed to :func:`trio.run`). +This was responsible for a lot of internal complexity and sometimes led to +surprising behavior. + +With this release, such a "deferred" `KeyboardInterrupt` is handled in a +different way: Trio will first cancel all running tasks, then raise +`KeyboardInterrupt` directly out of the call to :func:`trio.run`. +The difference is relevant if you have code that tries to catch +`KeyboardInterrupt` within Trio. This was never entirely robust, but it +previously might have worked in many cases, whereas now it will never +catch the interrupt. + +An example of code that mostly worked on previous releases, but won't +work on this release:: + + async def main(): + try: + await trio.sleep_forever() + except KeyboardInterrupt: + print("interrupted") + trio.run(main) + +The fix is to catch `KeyboardInterrupt` outside Trio:: + + async def main(): + await trio.sleep_forever() + try: + trio.run(main) + except KeyboardInterrupt: + print("interrupted") + +If that doesn't work for you (because you want to respond to +`KeyboardInterrupt` by doing something other than cancelling all +tasks), then you can start a task that uses +`trio.open_signal_receiver` to receive the interrupt signal ``SIGINT`` +directly and handle it however you wish. Such a task takes precedence +over Trio's default interrupt handling. diff --git a/src/trio/_core/_run.py b/src/trio/_core/_run.py index 564409963..dc82fb275 100644 --- a/src/trio/_core/_run.py +++ b/src/trio/_core/_run.py @@ -1662,17 +1662,6 @@ def raise_cancel() -> NoReturn: self._attempt_abort(raise_cancel) - def _attempt_delivery_of_pending_ki(self) -> None: - assert self._runner.ki_pending - if self._abort_func is None: - return - - def raise_cancel() -> NoReturn: - self._runner.ki_pending = False - raise KeyboardInterrupt - - self._attempt_abort(raise_cancel) - ################################################################ # The central Runner object @@ -1803,6 +1792,7 @@ class Runner: # type: ignore[explicit-any] system_context: contextvars.Context = attrs.field(kw_only=True) main_task: Task | None = None main_task_outcome: Outcome[object] | None = None + main_task_nursery: Nursery | None = None entry_queue: EntryQueue = attrs.Factory(EntryQueue) trio_token: TrioToken | None = None @@ -2137,12 +2127,12 @@ async def init( # All other system tasks run here: async with open_nursery() as self.system_nursery: # Only the main task runs here: - async with open_nursery() as main_task_nursery: + async with open_nursery() as self.main_task_nursery: try: self.main_task = self.spawn_impl( async_fn, args, - main_task_nursery, + self.main_task_nursery, None, ) except BaseException as exc: @@ -2199,30 +2189,13 @@ def current_trio_token(self) -> TrioToken: ki_pending: bool = False - # deliver_ki is broke. Maybe move all the actual logic and state into - # RunToken, and we'll only have one instance per runner? But then we can't - # have a public constructor. Eh, but current_run_token() returning a - # unique object per run feels pretty nice. Maybe let's just go for it. And - # keep the class public so people can isinstance() it if they want. - # This gets called from signal context def deliver_ki(self) -> None: self.ki_pending = True - with suppress(RunFinishedError): - self.entry_queue.run_sync_soon(self._deliver_ki_cb) + assert self.main_task_nursery is not None - def _deliver_ki_cb(self) -> None: - if not self.ki_pending: - return - # Can't happen because main_task and run_sync_soon_task are created at - # the same time -- so even if KI arrives before main_task is created, - # we won't get here until afterwards. - assert self.main_task is not None - if self.main_task_outcome is not None: - # We're already in the process of exiting -- leave ki_pending set - # and we'll check it again on our way out of run(). - return - self.main_task._attempt_delivery_of_pending_ki() + with suppress(RunFinishedError): + self.entry_queue.run_sync_soon(self.main_task_nursery.cancel_scope.cancel) ################ # Quiescing @@ -2884,10 +2857,6 @@ def unrolled_run( elif type(msg) is WaitTaskRescheduled: task._cancel_points += 1 task._abort_func = msg.abort_func - # KI is "outside" all cancel scopes, so check for it - # before checking for regular cancellation: - if runner.ki_pending and task is runner.main_task: - task._attempt_delivery_of_pending_ki() task._attempt_delivery_of_any_pending_cancel() elif type(msg) is PermanentlyDetachCoroutineObject: # Pretend the task just exited with the given outcome @@ -3020,9 +2989,7 @@ async def checkpoint() -> None: await cancel_shielded_checkpoint() task = current_task() task._cancel_points += 1 - if task._cancel_status.effectively_cancelled or ( - task is task._runner.main_task and task._runner.ki_pending - ): + if task._cancel_status.effectively_cancelled: cs = CancelScope(deadline=-inf) if ( task._cancel_status._scope._cancel_reason is None diff --git a/src/trio/_core/_tests/test_cancelled.py b/src/trio/_core/_tests/test_cancelled.py index 0c144c37f..1be9f2c84 100644 --- a/src/trio/_core/_tests/test_cancelled.py +++ b/src/trio/_core/_tests/test_cancelled.py @@ -199,14 +199,12 @@ async def child() -> None: raise ValueError -async def test_reason_delayed_ki() -> None: +def test_reason_delayed_ki() -> None: # simplified version of test_ki.test_ki_protection_works check #2 - parent_task = current_task() - async def sleeper(name: str) -> None: with pytest.raises( Cancelled, - match=rf"^cancelled due to KeyboardInterrupt from task {parent_task!r}$", + match=r"^cancelled due to KeyboardInterrupt$", ): while True: await trio.lowlevel.checkpoint() @@ -214,9 +212,11 @@ async def sleeper(name: str) -> None: async def raiser(name: str) -> None: ki_self() - with RaisesGroup(KeyboardInterrupt): + async def main() -> None: async with trio.open_nursery() as nursery: nursery.start_soon(sleeper, "s1") nursery.start_soon(sleeper, "s2") nursery.start_soon(trio.lowlevel.enable_ki_protection(raiser), "r1") - # __aexit__ blocks, and then receives the KI + + with pytest.raises(KeyboardInterrupt): + trio.run(main) diff --git a/src/trio/_core/_tests/test_guest_mode.py b/src/trio/_core/_tests/test_guest_mode.py index 81b7a07d8..3624a9d96 100644 --- a/src/trio/_core/_tests/test_guest_mode.py +++ b/src/trio/_core/_tests/test_guest_mode.py @@ -639,7 +639,8 @@ async def trio_main(in_host: InHost) -> None: with pytest.raises(KeyboardInterrupt) as excinfo: trivial_guest_run(trio_main) - assert excinfo.value.__context__ is None + assert isinstance(excinfo.value.__context__, trio.Cancelled) + assert excinfo.value.__context__.__context__ is None # Signal handler should be restored properly on exit assert signal.getsignal(signal.SIGINT) is signal.default_int_handler diff --git a/src/trio/_core/_tests/test_ki.py b/src/trio/_core/_tests/test_ki.py index 07a755872..5ff44ec80 100644 --- a/src/trio/_core/_tests/test_ki.py +++ b/src/trio/_core/_tests/test_ki.py @@ -313,8 +313,8 @@ async def check_unprotected_kill() -> None: _core.run(check_unprotected_kill) assert record_set == {"s1 ok", "s2 ok", "r1 raise ok"} - # simulated control-C during raiser, which is *protected*, so the KI gets - # delivered to the main task instead + # simulated control-C during raiser, which is *protected*, so the run + # gets cancelled instead. print("check 2") record_set = set() @@ -325,9 +325,13 @@ async def check_protected_kill() -> None: nursery.start_soon(_core.enable_ki_protection(raiser), "r1", record_set) # __aexit__ blocks, and then receives the KI - # raises inside a nursery, so the KeyboardInterrupt is wrapped in an ExceptionGroup - with RaisesGroup(KeyboardInterrupt): + # KeyboardInterrupt is inserted from the trio.run + with pytest.raises(KeyboardInterrupt) as excinfo: _core.run(check_protected_kill) + + # TODO: consider ensuring `__context__` is `None` in all cases above + # and below if the tree of `Cancelled`s is very spammy. + assert excinfo.value.__context__ is None assert record_set == {"s1 ok", "s2 ok", "r1 cancel ok"} # kill at last moment still raises (run_sync_soon until it raises an @@ -373,10 +377,11 @@ async def main_1() -> None: async def main_2() -> None: assert _core.currently_ki_protected() ki_self() - with pytest.raises(KeyboardInterrupt): + with pytest.raises(_core.Cancelled): await _core.checkpoint_if_cancelled() - _core.run(main_2) + with pytest.raises(KeyboardInterrupt): + _core.run(main_2) # KI arrives while main task is not abortable, b/c already scheduled print("check 6") @@ -388,10 +393,11 @@ async def main_3() -> None: await _core.cancel_shielded_checkpoint() await _core.cancel_shielded_checkpoint() await _core.cancel_shielded_checkpoint() - with pytest.raises(KeyboardInterrupt): + with pytest.raises(_core.Cancelled): await _core.checkpoint() - _core.run(main_3) + with pytest.raises(KeyboardInterrupt): + _core.run(main_3) # KI arrives while main task is not abortable, b/c refuses to be aborted print("check 7") @@ -407,10 +413,11 @@ def abort(_: RaiseCancelT) -> Abort: return _core.Abort.FAILED assert await _core.wait_task_rescheduled(abort) == 1 - with pytest.raises(KeyboardInterrupt): + with pytest.raises(_core.Cancelled): await _core.checkpoint() - _core.run(main_4) + with pytest.raises(KeyboardInterrupt): + _core.run(main_4) # KI delivered via slow abort print("check 8") @@ -426,11 +433,12 @@ def abort(raise_cancel: RaiseCancelT) -> Abort: _core.reschedule(task, result) return _core.Abort.FAILED - with pytest.raises(KeyboardInterrupt): + with pytest.raises(_core.Cancelled): assert await _core.wait_task_rescheduled(abort) await _core.checkpoint() - _core.run(main_5) + with pytest.raises(KeyboardInterrupt): + _core.run(main_5) # KI arrives just before main task exits, so the run_sync_soon machinery # is still functioning and will accept the callback to deliver the KI, but @@ -457,10 +465,11 @@ async def main_7() -> None: # ...but even after the KI, we keep running uninterrupted... record_list.append("ok") # ...until we hit a checkpoint: - with pytest.raises(KeyboardInterrupt): + with pytest.raises(_core.Cancelled): await sleep(10) - _core.run(main_7, restrict_keyboard_interrupt_to_checkpoints=True) + with pytest.raises(KeyboardInterrupt): + _core.run(main_7, restrict_keyboard_interrupt_to_checkpoints=True) assert record_list == ["ok"] record_list = [] # Exact same code raises KI early if we leave off the argument, doesn't @@ -469,25 +478,6 @@ async def main_7() -> None: _core.run(main_7) assert record_list == [] - # KI arrives while main task is inside a cancelled cancellation scope - # the KeyboardInterrupt should take priority - print("check 11") - - @_core.enable_ki_protection - async def main_8() -> None: - assert _core.currently_ki_protected() - with _core.CancelScope() as cancel_scope: - cancel_scope.cancel() - with pytest.raises(_core.Cancelled): - await _core.checkpoint() - ki_self() - with pytest.raises(KeyboardInterrupt): - await _core.checkpoint() - with pytest.raises(_core.Cancelled): - await _core.checkpoint() - - _core.run(main_8) - def test_ki_is_good_neighbor() -> None: # in the unlikely event someone overwrites our signal handler, we leave diff --git a/src/trio/_repl.py b/src/trio/_repl.py index 8be5af8fb..c5cc30596 100644 --- a/src/trio/_repl.py +++ b/src/trio/_repl.py @@ -12,14 +12,18 @@ import trio import trio.lowlevel +from trio._core._run_context import GLOBAL_RUN_CONTEXT from trio._util import final @final class TrioInteractiveConsole(InteractiveConsole): + runner: trio._core._run.Runner | None + def __init__(self, repl_locals: dict[str, object] | None = None) -> None: super().__init__(locals=repl_locals) self.compile.compiler.flags |= ast.PyCF_ALLOW_TOP_LEVEL_AWAIT + self.runner = None def runcode(self, code: types.CodeType) -> None: # https://github.com/python/typeshed/issues/13768 @@ -28,6 +32,17 @@ def runcode(self, code: types.CodeType) -> None: result = trio.from_thread.run(outcome.acapture, func) else: result = trio.from_thread.run_sync(outcome.capture, func) + + # clear ki_pending + assert self.runner is not None + ki_pending = self.runner.ki_pending + self.runner.ki_pending = False + + if ki_pending: + exc: BaseException | None = KeyboardInterrupt() + else: + exc = None + if isinstance(result, outcome.Error): # If it is SystemExit, quit the repl. Otherwise, print the traceback. # If there is a SystemExit inside a BaseExceptionGroup, it probably isn't @@ -37,21 +52,28 @@ def runcode(self, code: types.CodeType) -> None: if isinstance(result.error, SystemExit): raise result.error else: - # Inline our own version of self.showtraceback that can use - # outcome.Error.error directly to print clean tracebacks. - # This also means overriding self.showtraceback does nothing. - sys.last_type, sys.last_value = type(result.error), result.error - sys.last_traceback = result.error.__traceback__ - # see https://docs.python.org/3/library/sys.html#sys.last_exc - if sys.version_info >= (3, 12): - sys.last_exc = result.error + if exc: + exc.__context__ = result.error + else: + exc = result.error + + if exc is not None: + # Inline our own version of self.showtraceback that can use + # outcome.Error.error directly to print clean tracebacks. + # This also means overriding self.showtraceback does nothing. + sys.last_type, sys.last_value = type(exc), exc + sys.last_traceback = exc.__traceback__ + # see https://docs.python.org/3/library/sys.html#sys.last_exc + if sys.version_info >= (3, 12): + sys.last_exc = exc - # We always use sys.excepthook, unlike other implementations. - # This means that overriding self.write also does nothing to tbs. - sys.excepthook(sys.last_type, sys.last_value, sys.last_traceback) + # We always use sys.excepthook, unlike other implementations. + # This means that overriding self.write also does nothing to tbs. + sys.excepthook(sys.last_type, sys.last_value, sys.last_traceback) async def run_repl(console: TrioInteractiveConsole) -> None: + console.runner = GLOBAL_RUN_CONTEXT.runner banner = ( f"trio REPL {sys.version} on {sys.platform}\n" f'Use "await" directly instead of "trio.run()".\n'