Skip to content

Commit 9b8891a

Browse files
committed
Use a signal-safe CM in the synchronous test case
- without this, the exception gets raised as soon as the CM's __exit__ method starts running - for the async case, it turns out there's nothing we can do to solve this at the byte code level. Instead, we need to somehow push *all* pending call processing back to the event loop and inject a synthetic await into the current frame
1 parent a66a72e commit 9b8891a

File tree

3 files changed

+31
-22
lines changed

3 files changed

+31
-22
lines changed

Lib/test/test_with_signal_safety.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import asyncio
77
import dis
88
import sys
9+
import threading
910
import unittest
1011

1112
class InjectedException(Exception):
@@ -20,8 +21,9 @@ def raise_after_offset(target_function, target_offset):
2021
"""
2122
target_code = target_function.__code__
2223
def inject_exception():
23-
print("Raising injected exception")
24-
raise InjectedException(f"Failing after {target_offset}")
24+
exc = InjectedException(f"Failing after {target_offset}")
25+
print(f"Raising injected exception: {exc}")
26+
raise exc
2527
# This installs a trace hook that's implemented in C, and hence won't
2628
# trigger any of the per-bytecode processing in the eval loop
2729
# This means it can register the pending call that raises the exception and
@@ -51,24 +53,22 @@ def setUp(self):
5153
self.addCleanup(sys.settrace, old_trace)
5254
sys.settrace(None)
5355

54-
def assert_cm_exited(self, tracking_cm, target_offset, traced_operation):
55-
if tracking_cm.enter_without_exit:
56+
def assert_lock_released(self, test_lock, target_offset, traced_operation):
57+
just_acquired = test_lock.acquire(blocking=False)
58+
# Either we just acquired the lock, or the test didn't release it
59+
test_lock.release()
60+
if not just_acquired:
5661
msg = ("Context manager entered without exit due to "
5762
f"exception injected at offset {target_offset} in:\n"
5863
f"{dis.Bytecode(traced_operation).dis()}")
5964
self.fail(msg)
6065

6166
def test_synchronous_cm(self):
62-
class TrackingCM():
63-
def __init__(self):
64-
self.enter_without_exit = None
65-
def __enter__(self):
66-
self.enter_without_exit = True
67-
def __exit__(self, *args):
68-
self.enter_without_exit = False
69-
tracking_cm = TrackingCM()
67+
# Must use a signal-safe CM, otherwise __exit__ will start
68+
# but then fail to actually run as the pending call gets processed
69+
test_lock = threading.Lock()
7070
def traced_function():
71-
with tracking_cm:
71+
with test_lock:
7272
1 + 1
7373
return
7474
target_offset = -1
@@ -80,12 +80,20 @@ def traced_function():
8080
traced_function()
8181
except InjectedException:
8282
# key invariant: if we entered the CM, we exited it
83-
self.assert_cm_exited(tracking_cm, target_offset, traced_function)
83+
self.assert_lock_released(test_lock, target_offset, traced_function)
8484
else:
8585
self.fail(f"Exception wasn't raised @{target_offset}")
8686

8787

88-
def test_asynchronous_cm(self):
88+
def _test_asynchronous_cm(self):
89+
# NOTE: this can't work, since asyncio is written in Python, and hence
90+
# will always process pending calls at some point during the evaluation
91+
# of __aenter__ and __aexit__
92+
#
93+
# So to handle that case, we need to some way to tell the event loop
94+
# to convert pending call processing into calls to
95+
# asyncio.get_event_loop().call_soon() instead of processing them
96+
# immediately
8997
class AsyncTrackingCM():
9098
def __init__(self):
9199
self.enter_without_exit = None
@@ -108,7 +116,7 @@ async def traced_coroutine():
108116
loop.run_until_complete(traced_coroutine())
109117
except InjectedException:
110118
# key invariant: if we entered the CM, we exited it
111-
self.assert_cm_exited(tracking_cm, target_offset, traced_coroutine)
119+
self.assertFalse(tracking_cm.enter_without_exit)
112120
else:
113121
self.fail(f"Exception wasn't raised @{target_offset}")
114122

Modules/_testcapimodule.c

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2368,16 +2368,15 @@ error_injection_trace(PyObject *hook_args, PyFrameObject *frame,
23682368
if (((PyObject *) frame->f_code) == target_code) {
23692369
printf("Tracing frame of interest\n");
23702370
frame->f_trace_opcodes = 1;
2371-
if (what == PyTrace_OPCODE && frame->f_lasti > target_offset) {
2372-
printf("Adding pending call after %d\n", frame->f_lasti);
2371+
if (what == PyTrace_OPCODE && frame->f_lasti >= target_offset) {
23732372
Py_INCREF(callback);
2373+
PyEval_SetTrace(NULL, NULL);
2374+
printf("Adding pending call after %d\n", frame->f_lasti);
23742375
if (Py_AddPendingCall(&_pending_callback, callback) < 0) {
23752376
printf("Failed to add pending call\n");
23762377
Py_DECREF(callback);
2377-
PyEval_SetTrace(NULL, NULL);
23782378
return -1;
23792379
}
2380-
PyEval_SetTrace(NULL, NULL);
23812380
}
23822381
}
23832382
return 0;

Python/ceval.c

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -419,11 +419,13 @@ Py_MakePendingCalls(void)
419419
PyThread_acquire_lock(pending_lock, WAIT_LOCK);
420420
j = pendingfirst;
421421
if (j == pendinglast) {
422+
printf(" No pending calls remaining\n");
422423
func = NULL; /* Queue empty */
423424
} else {
424425
func = pendingcalls[j].func;
425426
arg = pendingcalls[j].arg;
426427
pendingfirst = (j + 1) % NPENDINGCALLS;
428+
printf(" Calling %p(%p)\n", func, arg);
427429
}
428430
PyThread_release_lock(pending_lock);
429431
/* having released the lock, perform the callback */
@@ -982,10 +984,10 @@ _PyEval_EvalFrameDefault(PyFrameObject *f, int throwflag)
982984
goto fast_next_opcode;
983985
}
984986
if (handle_pending_after > first_instr) {
985-
printf("Pending calls deferred until %lu\n", DEFER_OFFSET());
987+
printf("Pending calls deferred until %lu (current: %lu)\n", DEFER_OFFSET(), INSTR_OFFSET());
986988
}
987989
if (next_instr >= handle_pending_after) {
988-
printf("Checking for pending calls: %lu > %lu?\n", INSTR_OFFSET(), DEFER_OFFSET());
990+
printf("Checking for pending calls: %lu >= %lu?\n", INSTR_OFFSET(), DEFER_OFFSET());
989991
/* Allow for subsequent jumps backwards in the bytecode */
990992
handle_pending_after = first_instr;
991993
if (_Py_atomic_load_relaxed(&pendingcalls_to_do)) {

0 commit comments

Comments
 (0)