diff --git a/Lib/test/test_with_signal_safety.py b/Lib/test/test_with_signal_safety.py
new file mode 100644
index 00000000000000..c72a06c3157225
--- /dev/null
+++ b/Lib/test/test_with_signal_safety.py
@@ -0,0 +1,200 @@
+"""Additional signal safety tests for "with" and "async with"
+"""
+
+from test.support import cpython_only, verbose
+from _testcapi import install_error_injection_hook
+import asyncio
+import dis
+import sys
+import threading
+import unittest
+
+class InjectedException(Exception):
+    """Exception injected into a running frame via a trace function"""
+    pass
+
+def raise_after_offset(target_function, target_offset):
+    """Sets a trace function to inject an exception into the given function
+
+    Relies on the ability to request that a trace function be called for
+    every executed opcode, not just every line
+    """
+    target_code = target_function.__code__
+    def inject_exception():
+        exc = InjectedException(f"Failing after {target_offset}")
+        raise exc
+    # This installs a trace hook that's implemented in C, and hence won't
+    # trigger any of the per-bytecode processing in the eval loop.
+    # This means it can register the pending call that raises the exception,
+    # and the pending call won't be processed until after the trace hook
+    # returns.
+    install_error_injection_hook(target_code, target_offset, inject_exception)
+
+# TODO: Add a test case that ensures raise_after_offset is working
+# properly (otherwise there's a risk the tests will pass due to the
+# exception not being injected properly)
+
+@cpython_only
+class CheckFunctionSignalSafety(unittest.TestCase):
+    """Ensure with statements are signal-safe.
+
+    Signal safety means that, regardless of when external signals (e.g.
+    KeyboardInterrupt) are received, if __enter__ succeeds, __exit__ will
+    be called.
+
+    See https://bugs.python.org/issue29988 for more details
+    """
+
+    def setUp(self):
+        old_trace = sys.gettrace()
+        self.addCleanup(sys.settrace, old_trace)
+        sys.settrace(None)
+
+    def assert_lock_released(self, test_lock, target_offset, traced_code):
+        just_acquired = test_lock.acquire(blocking=False)
+        # Either we just acquired the lock, or the test didn't release it
+        test_lock.release()
+        if not just_acquired:
+            msg = ("Context manager entered without exit due to "
+                   f"exception injected at offset {target_offset} in:\n"
+                   f"{traced_code.dis()}")
+            self.fail(msg)
+
+    def _check_CM_exits_correctly(self, traced_function):
+        # Must use a signal-safe CM, otherwise __exit__ will start
+        # but then fail to actually run as the pending call gets processed
+        test_lock = threading.Lock()
+        target_offset = -1
+        traced_code = dis.Bytecode(traced_function)
+        for instruction in traced_code:
+            if instruction.opname == "RETURN_VALUE":
+                break
+            max_offset = instruction.offset
+        while target_offset < max_offset:
+            target_offset += 1
+            raise_after_offset(traced_function, target_offset)
+            try:
+                traced_function(test_lock)
+            except InjectedException:
+                # key invariant: if we entered the CM, we exited it
+                self.assert_lock_released(test_lock, target_offset, traced_code)
+            else:
+                try:
+                    msg = (f"Exception wasn't raised @{target_offset} in:\n"
+                           f"{traced_code.dis()}")
+                    self.fail(msg)
+                except InjectedException:
+                    # The pending call was still active when we tried to
+                    # report that the exception wasn't raised by the traced
+                    # function
+                    msg = (f"Pending calls weren't processed after @{target_offset} in:\n"
+                           f"{traced_code.dis()}")
+                    self.fail(msg)
+
+    def test_with_statement_completed(self):
+        def traced_function(test_cm):
+            with test_cm:
+                1 + 1
+            return # Make implicit final return explicit
+        self._check_CM_exits_correctly(traced_function)
+
+    def test_with_statement_exited_via_return(self):
+        def traced_function(test_cm):
+            with test_cm:
+                1 + 1
+                return
+            return # Make implicit final return explicit
+        self._check_CM_exits_correctly(traced_function)
+
+    def test_with_statement_exited_via_continue(self):
+        def traced_function(test_cm):
+            for i in range(1):
+                with test_cm:
+                    1 + 1
+                    continue
+            return # Make implicit final return explicit
+        self._check_CM_exits_correctly(traced_function)
+
+    def test_with_statement_exited_via_break(self):
+        def traced_function(test_cm):
+            while True:
+                with test_cm:
+                    1 + 1
+                    break
+            return # Make implicit final return explicit
+        self._check_CM_exits_correctly(traced_function)
+
+    def test_with_statement_exited_via_raise(self):
+        def traced_function(test_cm):
+            try:
+                with test_cm:
+                    1 + 1
+                    1/0
+            except ZeroDivisionError:
+                pass
+            return # Make implicit final return explicit
+        self._check_CM_exits_correctly(traced_function)
+
+@cpython_only
+class CheckCoroutineSignalSafety(unittest.TestCase):
+    """Ensure async with statements are signal-safe.
+
+    Signal safety means that, regardless of when external signals (e.g.
+    KeyboardInterrupt) are received, if __aenter__ succeeds, __aexit__ will
+    be called *and* the resulting awaitable will be awaited.
+
+    See https://bugs.python.org/issue29988 for more details
+    """
+
+    def setUp(self):
+        old_trace = sys.gettrace()
+        self.addCleanup(sys.settrace, old_trace)
+        sys.settrace(None)
+
+    def assert_CM_balanced(self, test_cm, target_offset, traced_code):
+        if test_cm.enter_without_exit:
+            msg = ("Context manager entered without exit due to "
+                   f"exception injected at offset {target_offset} in:\n"
+                   f"{traced_code.dis()}")
+            self.fail(msg)
+
+    def _check_CM_exits_correctly(self, traced_coroutine):
+        # NOTE: to get this to work, we also needed to update ceval to ensure
+        # that at least one line in a frame is executed before signals are
+        # processed (otherwise __aexit__'s body doesn't run)
+        class AsyncTrackingCM():
+            def __init__(self):
+                self.enter_without_exit = None
+            async def __aenter__(self):
+                self.enter_without_exit = True
+            async def __aexit__(self, *args):
+                self.enter_without_exit = False
+        test_cm = AsyncTrackingCM()
+        target_offset = -1
+        traced_code = dis.Bytecode(traced_coroutine)
+        for instruction in traced_code:
+            if instruction.opname == "RETURN_VALUE":
+                break
+            max_offset = instruction.offset
+        loop = asyncio.get_event_loop()
+        while target_offset < max_offset:
+            target_offset += 1
+            raise_after_offset(traced_coroutine, target_offset)
+            try:
+                loop.run_until_complete(traced_coroutine(test_cm))
+            except InjectedException:
+                # key invariant: if we entered the CM, we exited it
+                self.assert_CM_balanced(test_cm, target_offset, traced_code)
+            else:
+                msg = (f"Exception wasn't raised @{target_offset} in:\n"
+                       f"{traced_code.dis()}")
+                self.fail(msg)
+
+    def test_async_with_statement_completed(self):
+        async def traced_coroutine(test_cm):
+            async with test_cm:
+                1 + 1
+            return # Make implicit final return explicit
+        self._check_CM_exits_correctly(traced_coroutine)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/Modules/_testcapimodule.c b/Modules/_testcapimodule.c
index 1a296214739c63..5a0b54b4a8f3d2 100644
--- a/Modules/_testcapimodule.c
+++ b/Modules/_testcapimodule.c
@@ -12,6 +12,7 @@
 #include "structmember.h"
 #include "datetime.h"
 #include "marshal.h"
+#include "frameobject.h"
 #include <signal.h>
 
 #ifdef MS_WINDOWS
@@ -2345,6 +2346,53 @@ PyObject *pending_threadfunc(PyObject *self, PyObject *arg)
     Py_RETURN_TRUE;
 }
 
+/* Helper for test_with_signal_safety that injects errors into the eval loop's
+ * pending call handling at a designated bytecode offset
+ *
+ * The hook args indicate the code object where the pending call should be
+ * injected, the offset where it should be registered, and the callback itself
+ */
+static int
+error_injection_trace(PyObject *hook_args, PyFrameObject *frame,
+                      int what, PyObject *event_arg)
+{
+    PyObject *target_code, *callback;
+    int target_offset;
+
+    if (!PyArg_ParseTuple(hook_args, "OiO:error_injection_trace",
+                          &target_code, &target_offset, &callback)) {
+        PyEval_SetTrace(NULL, NULL);
+        return -1;
+    }
+
+    if (((PyObject *) frame->f_code) == target_code) {
+        frame->f_trace_opcodes = 1;
+        if (what == PyTrace_OPCODE && frame->f_lasti >= target_offset) {
+            Py_INCREF(callback);
+            PyEval_SetTrace(NULL, NULL);
+            if (Py_AddPendingCall(&_pending_callback, callback) < 0) {
+                Py_DECREF(callback);
+                return -1;
+            }
+        }
+    }
+    return 0;
+}
+
+PyObject *install_error_injection_hook(PyObject *self, PyObject *args)
+{
+    PyObject *target_code, *target_offset, *callback;
+
+    /* Check the args are as expected */
+    if (!PyArg_UnpackTuple(args, "install_error_injection_hook", 3, 3,
+                           &target_code, &target_offset, &callback)) {
+        return NULL;
+    }
+
+    PyEval_SetTrace(error_injection_trace, args);
+    Py_RETURN_NONE;
+}
+
 /* Some tests of PyUnicode_FromFormat(). This needs more tests. */
 static PyObject *
 test_string_from_format(PyObject *self, PyObject *args)
@@ -4415,6 +4463,7 @@ static PyMethodDef TestMethods[] = {
     {"unicode_legacy_string", unicode_legacy_string, METH_VARARGS},
     {"_test_thread_state", test_thread_state, METH_VARARGS},
     {"_pending_threadfunc", pending_threadfunc, METH_VARARGS},
+    {"install_error_injection_hook", install_error_injection_hook, METH_VARARGS},
 #ifdef HAVE_GETTIMEOFDAY
     {"profile_int", profile_int, METH_NOARGS},
 #endif
@@ -4933,5 +4982,6 @@ PyInit__testcapi(void)
     TestError = PyErr_NewException("_testcapi.error", NULL, NULL);
     Py_INCREF(TestError);
     PyModule_AddObject(m, "error", TestError);
+
     return m;
 }
diff --git a/Python/ceval.c b/Python/ceval.c
index 5dd7cd9f03e35f..ee08be030e5013 100644
--- a/Python/ceval.c
+++ b/Python/ceval.c
@@ -92,6 +92,7 @@ static long dxp[256];
 #endif
 
 #define GIL_REQUEST _Py_atomic_load_relaxed(&_PyRuntime.ceval.gil_drop_request)
+#define _Py_PendingEvalState _PyRuntime.ceval.pending
 
 /* This can set eval_breaker to 0 even though gil_drop_request became 1.
    We believe this is all right because the eval loop will release
@@ -100,8 +101,8 @@ static long dxp[256];
     _Py_atomic_store_relaxed( \
         &_PyRuntime.ceval.eval_breaker, \
         GIL_REQUEST | \
-        _Py_atomic_load_relaxed(&_PyRuntime.ceval.pending.calls_to_do) | \
-        _PyRuntime.ceval.pending.async_exc)
+        _Py_atomic_load_relaxed(&_Py_PendingEvalState.calls_to_do) | \
+        _Py_PendingEvalState.async_exc)
 
 #define SET_GIL_DROP_REQUEST() \
     do { \
@@ -118,25 +119,25 @@ static long dxp[256];
 /* Pending calls are only modified under pending_lock */
 #define SIGNAL_PENDING_CALLS() \
     do { \
-        _Py_atomic_store_relaxed(&_PyRuntime.ceval.pending.calls_to_do, 1); \
+        _Py_atomic_store_relaxed(&_Py_PendingEvalState.calls_to_do, 1); \
         _Py_atomic_store_relaxed(&_PyRuntime.ceval.eval_breaker, 1); \
     } while (0)
 
 #define UNSIGNAL_PENDING_CALLS() \
     do { \
-        _Py_atomic_store_relaxed(&_PyRuntime.ceval.pending.calls_to_do, 0); \
+        _Py_atomic_store_relaxed(&_Py_PendingEvalState.calls_to_do, 0); \
         COMPUTE_EVAL_BREAKER(); \
     } while (0)
 
 #define SIGNAL_ASYNC_EXC() \
     do { \
-        _PyRuntime.ceval.pending.async_exc = 1; \
+        _Py_PendingEvalState.async_exc = 1; \
         _Py_atomic_store_relaxed(&_PyRuntime.ceval.eval_breaker, 1); \
     } while (0)
 
 #define UNSIGNAL_ASYNC_EXC() \
     do { \
-        _PyRuntime.ceval.pending.async_exc = 0; \
+        _Py_PendingEvalState.async_exc = 0; \
         COMPUTE_EVAL_BREAKER(); \
     } while (0)
 
@@ -160,9 +161,9 @@ PyEval_InitThreads(void)
         return;
     create_gil();
     take_gil(PyThreadState_GET());
-    _PyRuntime.ceval.pending.main_thread = PyThread_get_thread_ident();
-    if (!_PyRuntime.ceval.pending.lock)
-        _PyRuntime.ceval.pending.lock = PyThread_allocate_lock();
+    _Py_PendingEvalState.main_thread = PyThread_get_thread_ident();
+    if (!_Py_PendingEvalState.lock)
+        _Py_PendingEvalState.lock = PyThread_allocate_lock();
 }
 
 void
@@ -230,9 +231,9 @@ PyEval_ReInitThreads(void)
     if (!gil_created())
         return;
     recreate_gil();
-    _PyRuntime.ceval.pending.lock = PyThread_allocate_lock();
+    _Py_PendingEvalState.lock = PyThread_allocate_lock();
     take_gil(current_tstate);
-    _PyRuntime.ceval.pending.main_thread = PyThread_get_thread_ident();
+    _Py_PendingEvalState.main_thread = PyThread_get_thread_ident();
 
     /* Destroy all threads except the current one */
     _PyThreadState_DeleteExcept(current_tstate);
@@ -322,7 +323,7 @@ int
 Py_AddPendingCall(int (*func)(void *), void *arg)
 {
     int i, j, result=0;
-    PyThread_type_lock lock = _PyRuntime.ceval.pending.lock;
+    PyThread_type_lock lock = _Py_PendingEvalState.lock;
 
     /* try a few times for the lock.  Since this mechanism is used
      * for signal handling (on the main thread), there is a (slim)
@@ -344,14 +345,14 @@ Py_AddPendingCall(int (*func)(void *), void *arg)
         return -1;
     }
-    i = _PyRuntime.ceval.pending.last;
+    i = _Py_PendingEvalState.last;
     j = (i + 1) % NPENDINGCALLS;
-    if (j == _PyRuntime.ceval.pending.first) {
+    if (j == _Py_PendingEvalState.first) {
         result = -1; /* Queue full */
     } else {
-        _PyRuntime.ceval.pending.calls[i].func = func;
-        _PyRuntime.ceval.pending.calls[i].arg = arg;
-        _PyRuntime.ceval.pending.last = j;
+        _Py_PendingEvalState.calls[i].func = func;
+        _Py_PendingEvalState.calls[i].arg = arg;
+        _Py_PendingEvalState.last = j;
     }
     /* signal main loop */
     SIGNAL_PENDING_CALLS();
@@ -369,16 +370,16 @@ Py_MakePendingCalls(void)
 
     assert(PyGILState_Check());
 
-    if (!_PyRuntime.ceval.pending.lock) {
+    if (!_Py_PendingEvalState.lock) {
         /* initial allocation of the lock */
-        _PyRuntime.ceval.pending.lock = PyThread_allocate_lock();
-        if (_PyRuntime.ceval.pending.lock == NULL)
+        _Py_PendingEvalState.lock = PyThread_allocate_lock();
+        if (_Py_PendingEvalState.lock == NULL)
             return -1;
     }
 
     /* only service pending calls on main thread */
-    if (_PyRuntime.ceval.pending.main_thread &&
-        PyThread_get_thread_ident() != _PyRuntime.ceval.pending.main_thread)
+    if (_Py_PendingEvalState.main_thread &&
+        PyThread_get_thread_ident() != _Py_PendingEvalState.main_thread)
     {
         return 0;
     }
@@ -403,16 +404,16 @@
         void *arg = NULL;
 
         /* pop one item off the queue while holding the lock */
-        PyThread_acquire_lock(_PyRuntime.ceval.pending.lock, WAIT_LOCK);
-        j = _PyRuntime.ceval.pending.first;
-        if (j == _PyRuntime.ceval.pending.last) {
-            func = NULL; /* Queue empty */
-        } else {
-            func = _PyRuntime.ceval.pending.calls[j].func;
-            arg = _PyRuntime.ceval.pending.calls[j].arg;
-            _PyRuntime.ceval.pending.first = (j + 1) % NPENDINGCALLS;
+        PyThread_acquire_lock(_Py_PendingEvalState.lock, WAIT_LOCK);
+        j = _Py_PendingEvalState.first;
+        if (j == _Py_PendingEvalState.last) {
+            func = NULL; /* Queue empty */
+        } else {
+            func = _Py_PendingEvalState.calls[j].func;
+            arg = _Py_PendingEvalState.calls[j].arg;
+            _Py_PendingEvalState.first = (j + 1) % NPENDINGCALLS;
         }
-        PyThread_release_lock(_PyRuntime.ceval.pending.lock);
+        PyThread_release_lock(_Py_PendingEvalState.lock);
         /* having released the lock, perform the callback */
         if (func == NULL)
             break;
@@ -557,6 +558,7 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
 #endif
     PyObject **stack_pointer;  /* Next free slot in value stack */
     const _Py_CODEUNIT *next_instr;
+    const _Py_CODEUNIT *current_instr; /* Reliably detect loops jumping back */
    int opcode;        /* Current opcode */
    int oparg;         /* Current opcode argument, if any */
    enum why_code why; /* Reason for block stack unwind */
@@ -657,6 +659,10 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
 #ifdef LLTRACE
 #define FAST_DISPATCH() \
     { \
+        if (_Py_OPCODE(*next_instr) == RETURN_VALUE || \
+            _Py_OPCODE(*next_instr) == POP_BLOCK) { \
+            continue; /* Always check signals before returning */ \
+        } \
         if (!lltrace && !_Py_TracingPossible && !PyDTrace_LINE_ENABLED()) { \
             f->f_lasti = INSTR_OFFSET(); \
             NEXTOPARG(); \
@@ -667,6 +673,10 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
 #else
 #define FAST_DISPATCH() \
     { \
+        if (_Py_OPCODE(*next_instr) == RETURN_VALUE || \
+            _Py_OPCODE(*next_instr) == POP_BLOCK) { \
+            continue; /* Always check signals before returning */ \
+        } \
         if (!_Py_TracingPossible && !PyDTrace_LINE_ENABLED()) { \
             f->f_lasti = INSTR_OFFSET(); \
             NEXTOPARG(); \
@@ -702,6 +712,7 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
         _Py_CODEUNIT word = *next_instr; \
         opcode = _Py_OPCODE(word); \
         oparg = _Py_OPARG(word); \
+        current_instr = next_instr; \
         next_instr++; \
     } while (0)
 #define JUMPTO(x) (next_instr = first_instr + (x) / sizeof(_Py_CODEUNIT))
@@ -742,6 +753,7 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
         opcode = _Py_OPCODE(word); \
         if (opcode == op){ \
             oparg = _Py_OPARG(word); \
+            current_instr = next_instr; \
             next_instr++; \
             goto PRED_##op; \
         } \
@@ -905,6 +917,7 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
         assert(f->f_lasti % sizeof(_Py_CODEUNIT) == 0);
         next_instr += f->f_lasti / sizeof(_Py_CODEUNIT) + 1;
     }
+    current_instr = next_instr;
     stack_pointer = f->f_stacktop;
     assert(stack_pointer != NULL);
     f->f_stacktop = NULL;       /* remains NULL unless yield suspends frame */
@@ -951,31 +964,53 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
            Py_MakePendingCalls() above. */
 
         if (_Py_atomic_load_relaxed(&_PyRuntime.ceval.eval_breaker)) {
-            if (_Py_OPCODE(*next_instr) == SETUP_FINALLY ||
-                _Py_OPCODE(*next_instr) == YIELD_FROM) {
-                /* Two cases where we skip running signal handlers and other
-                   pending calls:
-                   - If we're about to enter the try: of a try/finally (not
-                     *very* useful, but might help in some cases and it's
-                     traditional)
-                   - If we're resuming a chain of nested 'yield from' or
-                     'await' calls, then each frame is parked with YIELD_FROM
-                     as its next opcode. If the user hit control-C we want to
-                     wait until we've reached the innermost frame before
-                     running the signal handler and raising KeyboardInterrupt
-                     (see bpo-30039).
+            if (_Py_OPCODE(*next_instr) == YIELD_FROM) {
+                /* If we're resuming a chain of nested 'yield from' or
+                   'await' calls, then each frame is parked with YIELD_FROM
+                   as its next opcode. If the user hit control-C we want to
+                   wait until we've reached the innermost frame before
+                   running the signal handler and raising KeyboardInterrupt
+                   (see bpo-30039).
                 */
                 goto fast_next_opcode;
             }
-            if (_Py_atomic_load_relaxed(
-                &_PyRuntime.ceval.pending.calls_to_do))
-            {
-                if (Py_MakePendingCalls() < 0)
+            /* We check for pending calls & async exceptions if the next
+             * instruction is *earlier* in the compiled code than the current
+             * opcode (as this indicates jumping back in a loop).
+             *
+             * We also ensure we check for signals if we're about to *return*
+             * from a function, or are about to pop the jump target block
+             * for a loop or try/except/finally statement.
+             *
+             * Checking at these locations generally avoids interfering with
+             * the correct execution of with and try/finally statements.
+             *
+             * We deliberately *don't* check for exceptions when *starting*
+             * a call, as this can lead to __exit__ methods unexpectedly
+             * failing, while checking at the end provides an opportunity for
+             * __enter__ methods to switch off the check in that frame
+             * (see bpo-29988).
+             */
+            if (next_instr < current_instr ||
+                _Py_OPCODE(*next_instr) == RETURN_VALUE ||
+                _Py_OPCODE(*next_instr) == POP_BLOCK) {
+                /* Check for pending calls */
+                if (_Py_atomic_load_relaxed(&_Py_PendingEvalState.calls_to_do)) {
+                    if (Py_MakePendingCalls() < 0)
+                        goto error;
+                }
+                /* Check for asynchronous exceptions. */
+                if (tstate->async_exc != NULL) {
+                    PyObject *exc = tstate->async_exc;
+                    tstate->async_exc = NULL;
+                    UNSIGNAL_ASYNC_EXC();
+                    PyErr_SetNone(exc);
+                    Py_DECREF(exc);
                     goto error;
+                }
             }
-            if (_Py_atomic_load_relaxed(
-                &_PyRuntime.ceval.gil_drop_request))
-            {
+            /* Check if we should surrender the GIL to another thread */
+            if (_Py_atomic_load_relaxed(&_PyRuntime.ceval.gil_drop_request)) {
                 /* Give another thread a chance */
                 if (PyThreadState_Swap(NULL) != tstate)
                     Py_FatalError("ceval: tstate mix-up");
@@ -996,15 +1031,6 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
                 if (PyThreadState_Swap(tstate) != NULL)
                     Py_FatalError("ceval: orphan tstate");
             }
-            /* Check for asynchronous exceptions. */
-            if (tstate->async_exc != NULL) {
-                PyObject *exc = tstate->async_exc;
-                tstate->async_exc = NULL;
-                UNSIGNAL_ASYNC_EXC();
-                PyErr_SetNone(exc);
-                Py_DECREF(exc);
-                goto error;
-            }
         }
 
     fast_next_opcode:
@@ -2840,18 +2866,7 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
         PREDICTED(JUMP_ABSOLUTE);
         TARGET(JUMP_ABSOLUTE) {
             JUMPTO(oparg);
-#if FAST_LOOPS
-            /* Enabling this path speeds-up all while and for-loops by bypassing
-               the per-loop checks for signals. By default, this should be turned-off
-               because it prevents detection of a control-break in tight loops like
-               "while 1: pass". Compile with this option turned-on when you need
-               the speed-up and do not need break checking inside tight loops (ones
-               that contain only instructions ending with FAST_DISPATCH). */
-            FAST_DISPATCH();
-#else
             DISPATCH();
-#endif
         }
 
         TARGET(GET_ITER) {
@@ -5119,8 +5134,10 @@ unicode_concatenate(PyObject *v, PyObject *w,
      * 'variable'.  We try to delete the variable now to reduce
      * the refcnt to 1.
      */
+    _Py_CODEUNIT word = *next_instr;
     int opcode, oparg;
-    NEXTOPARG();
+    opcode = _Py_OPCODE(word);
+    oparg = _Py_OPARG(word);
     switch (opcode) {
     case STORE_FAST:
     {
diff --git a/Python/compile.c b/Python/compile.c
index e547c2fd591c49..57bd0a4c51e816 100644
--- a/Python/compile.c
+++ b/Python/compile.c
@@ -4211,7 +4211,6 @@ compiler_async_with(struct compiler *c, stmt_ty s, int pos)
     ADDOP(c, GET_AWAITABLE);
     ADDOP_O(c, LOAD_CONST, Py_None, consts);
     ADDOP(c, YIELD_FROM);
-    ADDOP_JREL(c, SETUP_ASYNC_WITH, finally);
 
     /* SETUP_ASYNC_WITH pushes a finally block. */
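
Reviewer note (not part of the patch): the invariant the new tests enforce can be sketched in pure Python, which may help when reviewing the test design. The sketch below stands in for the C-level install_error_injection_hook by using a sys.settrace opcode tracer (frame.f_trace_opcodes, new in 3.7); the run_with_injection helper name is invented for this illustration. Because a Python-level tracer raises synchronously at the trace point rather than registering a pending call, it only demonstrates the "if __enter__ ran, __exit__ runs" invariant and cannot reproduce the pending-call timing the C hook exercises.

# Illustrative sketch only -- approximates the C error-injection hook with a
# pure-Python opcode tracer. The exception is raised synchronously by the
# tracer, so this demonstrates the invariant, not the signal-delivery timing.
import dis
import sys
import threading

class InjectedException(Exception):
    pass

def run_with_injection(func, cm, target_offset):
    """Run func(cm), injecting InjectedException at bytecode offset target_offset."""
    target_code = func.__code__
    def tracer(frame, event, arg):
        if frame.f_code is target_code:
            # Request per-opcode trace events for this frame (3.7+)
            frame.f_trace_opcodes = True
            if event == "opcode" and frame.f_lasti >= target_offset:
                sys.settrace(None)
                raise InjectedException(f"injected at offset {frame.f_lasti}")
        return tracer
    sys.settrace(tracer)
    try:
        func(cm)
    except InjectedException:
        pass
    finally:
        sys.settrace(None)

def traced_function(test_cm):
    with test_cm:
        1 + 1
    return  # Make implicit final return explicit

# Sweep every offset up to the final RETURN_VALUE, as the tests do
max_offset = 0
for instruction in dis.Bytecode(traced_function):
    if instruction.opname == "RETURN_VALUE":
        break
    max_offset = instruction.offset

for offset in range(max_offset + 1):
    lock = threading.Lock()
    run_with_injection(traced_function, lock, offset)
    # Key invariant: if __enter__ acquired the lock, __exit__ released it
    if not lock.acquire(blocking=False):
        raise AssertionError(f"entered without exit at offset {offset}")
    lock.release()

Note that since SETUP_WITH calls __enter__ and pushes its block within a single opcode, a synchronous tracer cannot fire between the two, so this sketch is expected to pass even on unpatched interpreters; the asynchronous pending-call injection in the C hook is what reproduces the original bug, and the tests above remain authoritative.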