Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 200 additions & 0 deletions Lib/test/test_with_signal_safety.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
"""Additional signal safety tests for "with" and "async with"
"""

from test.support import cpython_only, verbose
from _testcapi import install_error_injection_hook
import asyncio
import dis
import sys
import threading
import unittest

class InjectedException(Exception):
"""Exception injected into a running frame via a trace function"""
pass

def raise_after_offset(target_function, target_offset):
"""Sets a trace function to inject an exception into given function

Relies on the ability to request that a trace function be called for
every executed opcode, not just every line
"""
target_code = target_function.__code__
def inject_exception():
exc = InjectedException(f"Failing after {target_offset}")
raise exc
# This installs a trace hook that's implemented in C, and hence won't
# trigger any of the per-bytecode processing in the eval loop
# This means it can register the pending call that raises the exception and
# the pending call won't be processed until after the trace hook returns
install_error_injection_hook(target_code, target_offset, inject_exception)

# TODO: Add a test case that ensures raise_after_offset is working
# properly (otherwise there's a risk the tests will pass due to the
# exception not being injected properly)

@cpython_only
class CheckFunctionSignalSafety(unittest.TestCase):
"""Ensure with statements are signal-safe.

Signal safety means that, regardless of when external signals (e.g.
KeyboardInterrupt) are received, if __enter__ succeeds, __exit__ will
be called.

See https://bugs.python.org/issue29988 for more details
"""

def setUp(self):
old_trace = sys.gettrace()
self.addCleanup(sys.settrace, old_trace)
sys.settrace(None)

def assert_lock_released(self, test_lock, target_offset, traced_code):
just_acquired = test_lock.acquire(blocking=False)
# Either we just acquired the lock, or the test didn't release it
test_lock.release()
if not just_acquired:
msg = ("Context manager entered without exit due to "
f"exception injected at offset {target_offset} in:\n"
f"{dis.Bytecode(traced_code).dis()}")
self.fail(msg)

def _check_CM_exits_correctly(self, traced_function):
# Must use a signal-safe CM, otherwise __exit__ will start
# but then fail to actually run as the pending call gets processed
test_lock = threading.Lock()
target_offset = -1
traced_code = dis.Bytecode(traced_function)
for instruction in traced_code:
if instruction.opname == "RETURN_VALUE":
break
max_offset = instruction.offset
while target_offset < max_offset:
target_offset += 1
raise_after_offset(traced_function, target_offset)
try:
traced_function(test_lock)
except InjectedException:
# key invariant: if we entered the CM, we exited it
self.assert_lock_released(test_lock, target_offset, traced_code)
else:
try:
msg = (f"Exception wasn't raised @{target_offset} in:\n"
f"{traced_code.dis()}")
self.fail(msg)
except InjectedException:
# The pending call was still active when we tried to report
# the fact the exception wasn't raised by the traced function
msg = (f"Pending calls weren't processed after @{target_offset} in:\n"
f"{traced_code.dis()}")
self.fail(msg)

def test_with_statement_completed(self):
def traced_function(test_cm):
with test_cm:
1 + 1
return # Make implicit final return explicit
self._check_CM_exits_correctly(traced_function)

def test_with_statement_exited_via_return(self):
def traced_function(test_cm):
with test_cm:
1 + 1
return
return # Make implicit final return explicit
self._check_CM_exits_correctly(traced_function)

def test_with_statement_exited_via_continue(self):
def traced_function(test_cm):
for i in range(1):
with test_cm:
1 + 1
continue
return # Make implicit final return explicit
self._check_CM_exits_correctly(traced_function)

def test_with_statement_exited_via_break(self):
def traced_function(test_cm):
while True:
with test_cm:
1 + 1
break
return # Make implicit final return explicit
self._check_CM_exits_correctly(traced_function)

def test_with_statement_exited_via_raise(self):
def traced_function(test_cm):
try:
with test_cm:
1 + 1
1/0
except ZeroDivisionError:
pass
return # Make implicit final return explicit
self._check_CM_exits_correctly(traced_function)

@cpython_only
class CheckCoroutineSignalSafety(unittest.TestCase):
"""Ensure async with statements are signal-safe.

Signal safety means that, regardless of when external signals (e.g.
KeyboardInterrupt) are received, if __aenter__ succeeeds, __aexit__ will
be called *and* the resulting awaitable will be awaited.

See https://bugs.python.org/issue29988 for more details
"""

def setUp(self):
old_trace = sys.gettrace()
self.addCleanup(sys.settrace, old_trace)
sys.settrace(None)

def assert_CM_balanced(self, test_cm, target_offset, traced_code):
if test_cm.enter_without_exit:
msg = ("Context manager entered without exit due to "
f"exception injected at offset {target_offset} in:\n"
f"{traced_code.dis()}")
self.fail(msg)

def _check_CM_exits_correctly(self, traced_coroutine):
# NOTE: to get this to work, we also needed to update ceval to ensure
# that at least one line in a frame is executed before signals are
# processed (otherwise __aexit__'s body doesn't run)
class AsyncTrackingCM():
def __init__(self):
self.enter_without_exit = None
async def __aenter__(self):
self.enter_without_exit = True
async def __aexit__(self, *args):
self.enter_without_exit = False
test_cm = AsyncTrackingCM()
target_offset = -1
traced_code = dis.Bytecode(traced_coroutine)
for instruction in traced_code:
if instruction.opname == "RETURN_VALUE":
break
max_offset = instruction.offset
loop = asyncio.get_event_loop()
while target_offset < max_offset:
target_offset += 1
raise_after_offset(traced_coroutine, target_offset)
try:
loop.run_until_complete(traced_coroutine(test_cm))
except InjectedException:
# key invariant: if we entered the CM, we exited it
self.assert_CM_balanced(test_cm, target_offset, traced_code)
else:
msg = (f"Exception wasn't raised @{target_offset} in:\n"
f"{traced_code.dis()}")
self.fail(msg)

def test_async_with_statement_completed(self):
async def traced_coroutine(test_cm):
async with test_cm:
1 + 1
return # Make implicit final return explicit
self._check_CM_exits_correctly(traced_coroutine)


if __name__ == '__main__':
unittest.main()
50 changes: 50 additions & 0 deletions Modules/_testcapimodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "structmember.h"
#include "datetime.h"
#include "marshal.h"
#include "frameobject.h"
#include <signal.h>

#ifdef MS_WINDOWS
Expand Down Expand Up @@ -2345,6 +2346,53 @@ PyObject *pending_threadfunc(PyObject *self, PyObject *arg)
Py_RETURN_TRUE;
}

/* Helper for test_with_signal_safety that injects errors into the eval loop's
* pending call handling at a designated bytecode offset
*
* The hook args indicate the code object where the pending call should be
* injected, the offset where it should be registered, and the callback itself
*/
static int
error_injection_trace(PyObject *hook_args, PyFrameObject *frame,
int what, PyObject *event_arg)
{
PyObject *target_code, *callback;
int target_offset;

if (!PyArg_ParseTuple(hook_args, "OiO:error_injection_trace",
&target_code, &target_offset, &callback)) {
PyEval_SetTrace(NULL, NULL);
return -1;
}

if (((PyObject *) frame->f_code) == target_code) {
frame->f_trace_opcodes = 1;
if (what == PyTrace_OPCODE && frame->f_lasti >= target_offset) {
Py_INCREF(callback);
PyEval_SetTrace(NULL, NULL);
if (Py_AddPendingCall(&_pending_callback, callback) < 0) {
Py_DECREF(callback);
return -1;
}
}
}
return 0;
}

PyObject *install_error_injection_hook(PyObject *self, PyObject *args)
{
PyObject *target_code, *target_offset, *callback;

/* Check the args are as expected */
if (!PyArg_UnpackTuple(args, "install_error_injection_hook", 3, 3,
&target_code, &target_offset, &callback)) {
return NULL;
}

PyEval_SetTrace(error_injection_trace, args);
Py_RETURN_NONE;
}

/* Some tests of PyUnicode_FromFormat(). This needs more tests. */
static PyObject *
test_string_from_format(PyObject *self, PyObject *args)
Expand Down Expand Up @@ -4415,6 +4463,7 @@ static PyMethodDef TestMethods[] = {
{"unicode_legacy_string", unicode_legacy_string, METH_VARARGS},
{"_test_thread_state", test_thread_state, METH_VARARGS},
{"_pending_threadfunc", pending_threadfunc, METH_VARARGS},
{"install_error_injection_hook", install_error_injection_hook, METH_VARARGS},
#ifdef HAVE_GETTIMEOFDAY
{"profile_int", profile_int, METH_NOARGS},
#endif
Expand Down Expand Up @@ -4933,5 +4982,6 @@ PyInit__testcapi(void)
TestError = PyErr_NewException("_testcapi.error", NULL, NULL);
Py_INCREF(TestError);
PyModule_AddObject(m, "error", TestError);

return m;
}
Loading