Skip to content
Draft
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Doc/c-api/contextvars.rst
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ Context object management functions:
- ``Py_CONTEXT_SWITCHED``: The :term:`current context` has switched to a
different context. The object passed to the watch callback is the
now-current :class:`contextvars.Context` object, or None if no context is
current.
current. The thread executing the callback is guaranteed to be the thread
that experienced the context switch.

.. versionadded:: 3.14

Expand Down
4 changes: 3 additions & 1 deletion Include/cpython/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ typedef enum {
/*
* The current context has switched to a different context. The object
* passed to the watch callback is the now-current contextvars.Context
* object, or None if no context is current.
* object, or None if no context is current. The thread executing the
* callback is guaranteed to be the thread that experienced the context
* switch.
*/
Py_CONTEXT_SWITCHED = 1,
} PyContextEvent;
Expand Down
15 changes: 14 additions & 1 deletion Include/internal/pycore_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@ extern PyTypeObject _PyContextTokenMissing_Type;

PyStatus _PyContext_Init(PyInterpreterState *);

// Exits any thread-owned contexts (see context_get) at the top of the given
// thread's context stack. The given thread state is not required to belong to
// the calling thread; if not, the thread is assumed to have exited (or not yet
// started) and no Py_CONTEXT_SWITCHED event is emitted for any context
// changes. Logs a warning via PyErr_FormatUnraisable if the thread's context
// stack is non-empty afterwards (because those contexts can never be exited or
// re-entered).
void _PyContext_ExitThreadOwned(PyThreadState *);


/* other API */

Expand All @@ -27,7 +36,11 @@ struct _pycontextobject {
PyContext *ctx_prev;
PyHamtObject *ctx_vars;
PyObject *ctx_weakreflist;
int ctx_entered;
_Bool ctx_entered:1;
// True for the thread's default context created by context_get. Used to
// safely determine whether the base context can be exited when clearing a
// PyThreadState.
_Bool ctx_owned_by_thread:1;
};


Expand Down
73 changes: 72 additions & 1 deletion Lib/test/test_capi/test_watchers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import threading
import unittest
import contextvars

from contextlib import contextmanager, ExitStack
from test.support import (
catch_unraisable_exception, import_helper,
gc_collect)
gc_collect, threading_helper)


# Skip this test if the _testcapi module isn't available.
Expand Down Expand Up @@ -674,5 +675,75 @@ def test_exit_base_context(self):
ctx.run(lambda: None)
self.assertEqual(switches, [ctx, None])

def test_reenter_default_context(self):
_testcapi.clear_context_stack()
# contextvars.copy_context() creates the thread's default context (via
# the context_get C function).
ctx = contextvars.copy_context()
with self.context_watcher(0) as switches:
ctx.run(lambda: None)
self.assertEqual(len(switches), 2)
self.assertEqual(switches[0], ctx)
base_ctx = switches[1]
self.assertIsNotNone(base_ctx)
self.assertIsNot(base_ctx, ctx)
with self.assertRaisesRegex(RuntimeError, 'already entered'):
base_ctx.run(lambda: None)

def test_default_context_enter(self):
_testcapi.clear_context_stack()
with self.context_watcher(0) as switches:
ctx = contextvars.copy_context()
ctx.run(lambda: None)
self.assertEqual(len(switches), 3)
base_ctx = switches[0]
self.assertIsNotNone(base_ctx)
self.assertEqual(switches, [base_ctx, ctx, base_ctx])

@threading_helper.requires_working_threading()
def test_default_context_exit_during_thread_cleanup(self):
# Context watchers are per-interpreter, not per-thread.
with self.context_watcher(0) as switches:
def _thread_main():
_testcapi.clear_context_stack()
# contextvars.copy_context() creates the thread's default
# context (via the context_get C function).
contextvars.copy_context()
# This test only cares about the final switch that happens when
# exiting the thread's default context during thread cleanup.
switches.clear()

thread = threading.Thread(target=_thread_main)
thread.start()
threading_helper.join_thread(thread)
self.assertEqual(switches, [None])

@threading_helper.requires_working_threading()
def test_thread_cleanup_with_entered_context(self):
unraisables = []
try:
with catch_unraisable_exception() as cm:
with self.context_watcher(0) as switches:
def _thread_main():
_testcapi.clear_context_stack()
ctx = contextvars.copy_context()
_testcapi.context_enter(ctx)
switches.clear()

thread = threading.Thread(target=_thread_main)
thread.start()
threading_helper.join_thread(thread)
unraisables.append(cm.unraisable)
self.assertEqual(switches, [])
self.assertEqual(len(unraisables), 1)
self.assertIsNotNone(unraisables[0])
self.assertRegex(unraisables[0].err_msg,
r'^Exception ignored during reset of thread state')
self.assertRegex(str(unraisables[0].exc_value), r'still entered')
finally:
# Break reference cycle
unraisables = None


if __name__ == "__main__":
unittest.main()
10 changes: 10 additions & 0 deletions Modules/_testcapi/watchers.c
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,15 @@ clear_context_stack(PyObject *Py_UNUSED(self), PyObject *Py_UNUSED(args))
Py_RETURN_NONE;
}

static PyObject *
context_enter(PyObject *self, PyObject *ctx)
{
if (PyContext_Enter(ctx)) {
return NULL;
}
Py_RETURN_NONE;
}

static PyObject *
get_context_switches(PyObject *Py_UNUSED(self), PyObject *watcher_id)
{
Expand Down Expand Up @@ -841,6 +850,7 @@ static PyMethodDef test_methods[] = {
{"add_context_watcher", add_context_watcher, METH_O, NULL},
{"clear_context_watcher", clear_context_watcher, METH_O, NULL},
{"clear_context_stack", clear_context_stack, METH_NOARGS, NULL},
{"context_enter", context_enter, METH_O, NULL},
{"get_context_switches", get_context_switches, METH_O, NULL},
{"allocate_too_many_context_watchers",
(PyCFunction) allocate_too_many_context_watchers, METH_NOARGS, NULL},
Expand Down
86 changes: 72 additions & 14 deletions Python/context.c
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,10 @@ context_event_name(PyContextEvent event) {
static void
notify_context_watchers(PyThreadState *ts, PyContextEvent event, PyObject *ctx)
{
// The callbacks are registered on the interpreter, not on the thread, so
// the only way callbacks can know which thread changed is by calling the
// callbacks from the affected thread.
assert(ts == _PyThreadState_GET());
if (ctx == NULL) {
// This will happen after exiting the last context in the stack, which
// can occur if context_get was never called before entering a context
Expand Down Expand Up @@ -184,29 +188,30 @@ static inline void
context_switched(PyThreadState *ts)
{
ts->context_ver++;
// ts->context is used instead of context_get() because context_get() might
// throw if ts->context is NULL.
// ts->context is used instead of context_get() because if ts->context is
// NULL, context_get() will either call context_switched -- causing a
// double notification -- or throw.
notify_context_watchers(ts, Py_CONTEXT_SWITCHED, ts->context);
}


// ts is not required to belong to the calling thread.
static int
_PyContext_Enter(PyThreadState *ts, PyObject *octx)
{
ENSURE_Context(octx, -1)
PyContext *ctx = (PyContext *)octx;

if (ctx->ctx_entered) {
_PyErr_Format(ts, PyExc_RuntimeError,
"cannot enter context: %R is already entered", ctx);
PyErr_Format(PyExc_RuntimeError,
"cannot enter context: %R is already entered", ctx);
return -1;
}

ctx->ctx_prev = (PyContext *)ts->context; /* borrow */
ctx->ctx_entered = 1;

ts->context = Py_NewRef(ctx);
context_switched(ts);
return 0;
}

Expand All @@ -216,10 +221,15 @@ PyContext_Enter(PyObject *octx)
{
PyThreadState *ts = _PyThreadState_GET();
assert(ts != NULL);
return _PyContext_Enter(ts, octx);
if (_PyContext_Enter(ts, octx)) {
return -1;
}
context_switched(ts);
return 0;
}


// ts is not required to belong to the calling thread.
static int
_PyContext_Exit(PyThreadState *ts, PyObject *octx)
{
Expand All @@ -244,7 +254,7 @@ _PyContext_Exit(PyThreadState *ts, PyObject *octx)

ctx->ctx_prev = NULL;
ctx->ctx_entered = 0;
context_switched(ts);
ctx->ctx_owned_by_thread = 0;
return 0;
}

Expand All @@ -253,7 +263,49 @@ PyContext_Exit(PyObject *octx)
{
PyThreadState *ts = _PyThreadState_GET();
assert(ts != NULL);
return _PyContext_Exit(ts, octx);
if (_PyContext_Exit(ts, octx)) {
return -1;
}
context_switched(ts);
return 0;
}


void
_PyContext_ExitThreadOwned(PyThreadState *ts)
{
assert(ts != NULL);
while (ts->context != NULL
&& PyContext_CheckExact(ts->context)
&& ((PyContext *)ts->context)->ctx_owned_by_thread) {
if (_PyContext_Exit(ts, ts->context)) {
// Exiting a context that is already known to be at the top of the
// stack cannot fail.
Py_UNREACHABLE();
}
// notify_context_watchers() requires the notification to come from the
// affected thread, so context_switched() must not be called if ts
// doesn't belong to the current thread. However, it's OK to skip
// calling it in this case: this function is only called when resetting
// a PyThreadState, so if the calling thread doesn't own ts, then the
// owning thread must not be running anymore (it must have just
// finished because a thread-owned context exists here).
if (ts == _PyThreadState_GET()) {
context_switched(ts);
}
}
if (ts->context != NULL) {
// This intentionally does not use tstate variants of these functions
// (e.g., _PyErr_GetRaisedException(ts)) because ts might not belong to
// the current thread.
PyObject *exc = PyErr_GetRaisedException();
PyErr_SetString(PyExc_RuntimeError,
"contextvars.Context object(s) still entered during "
"thread state reset");
PyErr_FormatUnraisable(
"Exception ignored during reset of thread state");
PyErr_SetRaisedException(exc);
}
}


Expand Down Expand Up @@ -433,6 +485,7 @@ _context_alloc(void)
ctx->ctx_vars = NULL;
ctx->ctx_prev = NULL;
ctx->ctx_entered = 0;
ctx->ctx_owned_by_thread = 0;
ctx->ctx_weakreflist = NULL;

return ctx;
Expand Down Expand Up @@ -478,15 +531,18 @@ context_get(void)
{
PyThreadState *ts = _PyThreadState_GET();
assert(ts != NULL);
PyContext *current_ctx = (PyContext *)ts->context;
if (current_ctx == NULL) {
current_ctx = context_new_empty();
if (current_ctx == NULL) {
if (ts->context == NULL) {
PyContext *ctx = context_new_empty();
if (ctx == NULL || _PyContext_Enter(ts, (PyObject *)ctx)) {
return NULL;
}
ts->context = (PyObject *)current_ctx;
ctx->ctx_owned_by_thread = 1;
assert(ts->context == (PyObject *)ctx);
Py_CLEAR(ctx); // _PyContext_Enter created its own ref.
context_switched(ts);
}
return current_ctx;
assert(PyContext_CheckExact(ts->context));
return (PyContext *)ts->context;
}

static int
Expand Down Expand Up @@ -715,6 +771,7 @@ context_run(PyContext *self, PyObject *const *args,
if (_PyContext_Enter(ts, (PyObject *)self)) {
return NULL;
}
context_switched(ts);

PyObject *call_result = _PyObject_VectorcallTstate(
ts, args[0], args + 1, nargs - 1, kwnames);
Expand All @@ -723,6 +780,7 @@ context_run(PyContext *self, PyObject *const *args,
Py_XDECREF(call_result);
return NULL;
}
context_switched(ts);

return call_result;
}
Expand Down
4 changes: 4 additions & 0 deletions Python/pystate.c
Original file line number Diff line number Diff line change
Expand Up @@ -1654,6 +1654,10 @@ PyThreadState_Clear(PyThreadState *tstate)
"PyThreadState_Clear: warning: thread still has a frame\n");
}

// This calls callbacks registered with PyContext_AddWatcher and can call
// sys.unraisablehook.
_PyContext_ExitThreadOwned(tstate);

/* At this point tstate shouldn't be used any more,
neither to run Python code nor for other uses.

Expand Down
Loading