Skip to content

Commit 5425990

Browse files
committed
Significant rewrite to avoid using thread_locals as much as possible.
Since we can avoid them by checking this atomic, the cmake config conditional shouldn't be necessary. The slower path (with thread_locals and extra checks) only comes in when a second interpreter is actually instantiated.
1 parent ecf287a commit 5425990

File tree

7 files changed

+125
-136
lines changed

7 files changed

+125
-136
lines changed

CMakeLists.txt

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,6 @@ option(PYBIND11_DISABLE_HANDLE_TYPE_NAME_DEFAULT_IMPLEMENTATION
9292
"To enforce that a handle_type_name<> specialization exists" OFF)
9393
option(PYBIND11_SIMPLE_GIL_MANAGEMENT
9494
"Use simpler GIL management logic that does not support disassociation" OFF)
95-
option(PYBIND11_SUBINTERPRETER_SUPPORT "Enable support for sub-interpreters" OFF)
9695
option(PYBIND11_NUMPY_1_ONLY
9796
"Disable NumPy 2 support to avoid changes to previous pybind11 versions." OFF)
9897
set(PYBIND11_INTERNALS_VERSION
@@ -106,9 +105,6 @@ endif()
106105
if(PYBIND11_SIMPLE_GIL_MANAGEMENT)
107106
add_compile_definitions(PYBIND11_SIMPLE_GIL_MANAGEMENT)
108107
endif()
109-
if(PYBIND11_SUBINTERPRETER_SUPPORT)
110-
add_compile_definitions(PYBIND11_SUBINTERPRETER_SUPPORT)
111-
endif()
112108
if(PYBIND11_NUMPY_1_ONLY)
113109
add_compile_definitions(PYBIND11_NUMPY_1_ONLY)
114110
endif()

include/pybind11/detail/common.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,13 @@
291291
# define PYBIND11_ASSERT_GIL_HELD_INCREF_DECREF
292292
#endif
293293

294+
// Slightly faster code paths are available when this is NOT defined, so undefine it for impls
295+
// that do not have subinterpreters. Nothing breaks if this is defined but the impl does not
296+
// actually support subinterpreters.
297+
#if PY_VERSION_HEX >= 0x030C0000 && !defined(PYPY_VERSION) && !defined(GRAALVM_PYTHON)
298+
# define PYBIND11_SUBINTERPRETER_SUPPORT
299+
#endif
300+
294301
// #define PYBIND11_STR_LEGACY_PERMISSIVE
295302
// If DEFINED, pybind11::str can hold PyUnicodeObject or PyBytesObject
296303
// (probably surprising and never documented, but this was the
@@ -466,6 +473,7 @@ PYBIND11_WARNING_DISABLE_CLANG("-Wgnu-zero-variadic-macro-arguments")
466473
return m.ptr(); \
467474
} \
468475
int PYBIND11_CONCAT(pybind11_exec_, name)(PyObject * pm) { \
476+
pybind11::detail::get_interpreter_count()++; \
469477
try { \
470478
auto m = pybind11::reinterpret_borrow<::pybind11::module_>(pm); \
471479
PYBIND11_CONCAT(pybind11_init_, name)(m); \

include/pybind11/detail/internals.h

Lines changed: 101 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <pybind11/conduit/pybind11_platform_abi_id.h>
1919
#include <pybind11/pytypes.h>
2020

21+
#include <atomic>
2122
#include <exception>
2223
#include <mutex>
2324
#include <thread>
@@ -257,29 +258,47 @@ struct type_info {
257258
"__pybind11_module_local_v" PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) \
258259
PYBIND11_COMPILER_TYPE_LEADING_UNDERSCORE PYBIND11_PLATFORM_ABI_ID "__"
259260

261+
inline PyThreadState *get_thread_state_unchecked() {
262+
#if defined(PYPY_VERSION) || defined(GRAALVM_PYTHON)
263+
return PyThreadState_GET();
264+
#elif PY_VERSION_HEX < 0x030D0000
265+
return _PyThreadState_UncheckedGet();
266+
#else
267+
return PyThreadState_GetUnchecked();
268+
#endif
269+
}
270+
271+
/// We use this count to figure out if there are multiple sub-interpreters currently present.
272+
/// Might be read/written from multiple interpreters in multiple threads at the same time with no
273+
/// synchronization. This must never decrease while any interpreter may be running in any thread!
274+
inline std::atomic<int64_t> &get_interpreter_count() {
275+
static std::atomic<int64_t> counter(0);
276+
return counter;
277+
}
278+
260279
/// Each module locally stores a pointer to the `internals` data. The data
261280
/// itself is shared among modules with the same `PYBIND11_INTERNALS_ID`.
262281
inline internals **&get_internals_pp() {
263-
#if defined(PYPY_VERSION) || defined(GRAALVM_PYTHON) || PY_VERSION_HEX < 0x030C0000 \
264-
|| !defined(PYBIND11_SUBINTERPRETER_SUPPORT)
265-
static internals **internals_pp = nullptr;
266-
#else
267-
static thread_local internals **internals_pp = nullptr;
268-
// This is one per interpreter, we cache it but if the thread changed
269-
// then we need to invalidate our cache
270-
// the caller will find the right value and set it if its null
271-
static thread_local PyThreadState *tstate_cached = nullptr;
272-
# if PY_VERSION_HEX < 0x030D0000
273-
PyThreadState *tstate = _PyThreadState_UncheckedGet();
274-
# else
275-
PyThreadState *tstate = PyThreadState_GetUnchecked();
276-
# endif
277-
if (tstate != tstate_cached) {
278-
tstate_cached = tstate;
279-
internals_pp = nullptr;
282+
#ifdef PYBIND11_SUBINTERPRETER_SUPPORT
283+
if (get_interpreter_count() > 1) {
284+
// Internals is one per interpreter. When multiple interpreters are alive in different
285+
// threads we have to allow them to have different internals, so we need a thread_local.
286+
static thread_local internals **t_internals_pp = nullptr;
287+
// Whenever the interpreter changes we need to invalidate the internals_pp. That is slow,
288+
// so we only do it when the PyThreadState has changed, which indicates the interpreter
289+
// might have changed as well.
290+
static thread_local PyThreadState *tstate_cached = nullptr;
291+
auto *tstate = get_thread_state_unchecked();
292+
if (tstate != tstate_cached) {
293+
tstate_cached = tstate;
294+
// the caller will fetch the instance from the state dict or create a new one
295+
t_internals_pp = nullptr;
296+
}
297+
return t_internals_pp;
280298
}
281299
#endif
282-
return internals_pp;
300+
static internals **s_internals_pp = nullptr;
301+
return s_internals_pp;
283302
}
284303

285304
// forward decl
@@ -410,20 +429,6 @@ inline object get_python_state_dict() {
410429
return state_dict;
411430
}
412431

413-
inline object get_internals_obj_from_state_dict(handle state_dict) {
414-
return reinterpret_steal<object>(
415-
dict_getitemstringref(state_dict.ptr(), PYBIND11_INTERNALS_ID));
416-
}
417-
418-
inline internals **get_internals_pp_from_capsule(handle obj) {
419-
void *raw_ptr = PyCapsule_GetPointer(obj.ptr(), /*name=*/nullptr);
420-
if (raw_ptr == nullptr) {
421-
raise_from(PyExc_SystemError, "pybind11::detail::get_internals_pp_from_capsule() FAILED");
422-
throw error_already_set();
423-
}
424-
return static_cast<internals **>(raw_ptr);
425-
}
426-
427432
inline uint64_t round_up_to_next_pow2(uint64_t x) {
428433
// Round-up to the next power of two.
429434
// See https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2
@@ -438,13 +443,8 @@ inline uint64_t round_up_to_next_pow2(uint64_t x) {
438443
return x;
439444
}
440445

441-
/// Return a reference to the current `internals` data
442-
PYBIND11_NOINLINE internals &get_internals() {
443-
auto **&internals_pp = get_internals_pp();
444-
if (internals_pp && *internals_pp) {
445-
return **internals_pp;
446-
}
447-
446+
template <typename InternalsType>
447+
inline InternalsType **find_or_create_internals_pp(char const *state_dict_key) {
448448
// Ensure that the GIL is held since we will need to make Python calls.
449449
// Cannot use py::gil_scoped_acquire here since that constructor calls get_internals.
450450
struct gil_scoped_acquire_local {
@@ -458,10 +458,31 @@ PYBIND11_NOINLINE internals &get_internals() {
458458
error_scope err_scope;
459459

460460
dict state_dict = get_python_state_dict();
461-
if (object internals_obj = get_internals_obj_from_state_dict(state_dict)) {
462-
internals_pp = get_internals_pp_from_capsule(internals_obj);
461+
object internals_obj
462+
= reinterpret_steal<object>(dict_getitemstringref(state_dict.ptr(), state_dict_key));
463+
if (internals_obj) {
464+
void *raw_ptr = PyCapsule_GetPointer(internals_obj.ptr(), /*name=*/nullptr);
465+
if (!raw_ptr) {
466+
pybind11_fail("find_or_create_internals_pp: broken capsule!");
467+
} else {
468+
return reinterpret_cast<InternalsType **>(raw_ptr);
469+
}
463470
}
471+
472+
auto pp = new InternalsType *(nullptr);
473+
state_dict[state_dict_key] = capsule(reinterpret_cast<void *>(pp));
474+
return pp;
475+
}
476+
477+
/// Return a reference to the current `internals` data
478+
PYBIND11_NOINLINE internals &get_internals() {
479+
auto **&internals_pp = get_internals_pp();
464480
if (internals_pp && *internals_pp) {
481+
return **internals_pp;
482+
}
483+
484+
internals_pp = find_or_create_internals_pp<internals>(PYBIND11_INTERNALS_ID);
485+
if (*internals_pp) {
465486
// We loaded the internals through `state_dict`, which means that our `error_already_set`
466487
// and `builtin_exception` may be different local classes than the ones set up in the
467488
// initial exception translator, below, so add another for our local exception classes.
@@ -478,13 +499,10 @@ PYBIND11_NOINLINE internals &get_internals() {
478499
}
479500
#endif
480501
} else {
481-
if (!internals_pp) {
482-
internals_pp = new internals *();
483-
}
484502
auto *&internals_ptr = *internals_pp;
485503
internals_ptr = new internals();
486504

487-
PyThreadState *tstate = PyThreadState_Get();
505+
PyThreadState *tstate = get_thread_state_unchecked();
488506
// NOLINTNEXTLINE(bugprone-assignment-in-if-condition)
489507
if (!PYBIND11_TLS_KEY_CREATE(internals_ptr->tstate)) {
490508
pybind11_fail("get_internals: could not successfully initialize the tstate TSS key!");
@@ -498,7 +516,6 @@ PYBIND11_NOINLINE internals &get_internals() {
498516
}
499517

500518
internals_ptr->istate = tstate->interp;
501-
state_dict[PYBIND11_INTERNALS_ID] = capsule(reinterpret_cast<void *>(internals_pp));
502519
internals_ptr->registered_exception_translators.push_front(&translate_exception);
503520
internals_ptr->static_property_type = make_static_property_type();
504521
internals_ptr->default_metaclass = make_default_metaclass();
@@ -528,69 +545,48 @@ struct local_internals {
528545
std::forward_list<ExceptionTranslator> registered_exception_translators;
529546
};
530547

531-
/// Works like `get_internals`, but for things which are locally registered.
532-
inline local_internals &get_local_internals() {
533-
// Current static can be created in the interpreter finalization routine. If the later will be
534-
// destroyed in another static variable destructor, creation of this static there will cause
535-
// static deinitialization fiasco. In order to avoid it we avoid destruction of the
536-
// local_internals static. One can read more about the problem and current solution here:
537-
// https://google.github.io/styleguide/cppguide.html#Static_and_Global_Variables
538-
539-
#if defined(PYPY_VERSION) || defined(GRAALVM_PYTHON) || PY_VERSION_HEX < 0x030C0000 \
540-
|| !defined(PYBIND11_SUBINTERPRETER_SUPPORT)
541-
static auto *locals = new local_internals();
542-
#else
543-
static thread_local local_internals *locals = nullptr;
544-
// This is one per interpreter, we cache it but if the interpreter changed
545-
// then we need to invalidate our cache and re-fetch from the state dict
546-
static thread_local PyThreadState *tstate_cached = nullptr;
547-
# if PY_VERSION_HEX < 0x030D0000
548-
PyThreadState *tstate = _PyThreadState_UncheckedGet();
549-
# else
550-
PyThreadState *tstate = PyThreadState_GetUnchecked();
551-
# endif
552-
if (!tstate) {
553-
pybind11_fail(
554-
"pybind11::detail::get_local_internals() called without a current python thread");
555-
}
556-
if (tstate != tstate_cached) {
557-
// we create a unique value at first run which is based on a pointer to
558-
// a (non-thread_local) static value in this function, then multiple
559-
// loaded modules using this code will still each have a unique key.
560-
static const std::string this_module_idstr
561-
= PYBIND11_MODULE_LOCAL_ID
562-
+ std::to_string(reinterpret_cast<uintptr_t>(&this_module_idstr));
563-
564-
// Ensure that the GIL is held since we will need to make Python calls.
565-
// Cannot use py::gil_scoped_acquire here since that constructor calls get_internals.
566-
struct gil_scoped_acquire_local {
567-
gil_scoped_acquire_local() : state(PyGILState_Ensure()) {}
568-
gil_scoped_acquire_local(const gil_scoped_acquire_local &) = delete;
569-
gil_scoped_acquire_local &operator=(const gil_scoped_acquire_local &) = delete;
570-
~gil_scoped_acquire_local() { PyGILState_Release(state); }
571-
const PyGILState_STATE state;
572-
} gil;
573-
574-
error_scope err_scope;
575-
dict state_dict = get_python_state_dict();
576-
object local_capsule = reinterpret_steal<object>(
577-
dict_getitemstringref(state_dict.ptr(), this_module_idstr.c_str()));
578-
if (!local_capsule) {
579-
locals = new local_internals();
580-
state_dict[this_module_idstr.c_str()] = capsule(reinterpret_cast<void *>(locals));
581-
} else {
582-
void *ptr = PyCapsule_GetPointer(local_capsule.ptr(), nullptr);
583-
if (!ptr) {
584-
raise_from(PyExc_SystemError, "pybind11::detail::get_local_internals() FAILED");
585-
throw error_already_set();
586-
}
587-
locals = reinterpret_cast<local_internals *>(ptr);
548+
inline local_internals **&get_local_internals_pp() {
549+
#ifdef PYBIND11_SUBINTERPRETER_SUPPORT
550+
if (get_interpreter_count() > 1) {
551+
// Internals is one per interpreter. When multiple interpreters are alive in different
552+
// threads we have to allow them to have different internals, so we need a thread_local.
553+
static thread_local local_internals **t_internals_pp = nullptr;
554+
// Whenever the interpreter changes we need to invalidate the internals_pp. That is slow,
555+
// so we only do it when the PyThreadState has changed, which indicates the interpreter
556+
// might have changed as well.
557+
static thread_local PyThreadState *tstate_cached = nullptr;
558+
auto *tstate = get_thread_state_unchecked();
559+
if (tstate != tstate_cached) {
560+
tstate_cached = tstate;
561+
// the caller will fetch the instance from the state dict or create a new one
562+
t_internals_pp = nullptr;
588563
}
589-
tstate_cached = tstate;
564+
return t_internals_pp;
590565
}
591566
#endif
567+
static local_internals **s_internals_pp = nullptr;
568+
return s_internals_pp;
569+
}
570+
571+
/// Works like `get_internals`, but for things which are locally registered.
572+
PYBIND11_NOINLINE local_internals &get_local_internals() {
573+
auto **&local_internals_pp = get_local_internals_pp();
574+
if (local_internals_pp && *local_internals_pp) {
575+
return **local_internals_pp;
576+
}
577+
578+
// we create a unique value at first run which is based on a pointer to a (non-thread_local)
579+
// static value in this function, then multiple loaded modules using this code will still each
580+
// have a unique key.
581+
static const std::string this_module_idstr
582+
= PYBIND11_MODULE_LOCAL_ID
583+
+ std::to_string(reinterpret_cast<uintptr_t>(&this_module_idstr));
592584

593-
return *locals;
585+
local_internals_pp = find_or_create_internals_pp<local_internals>(this_module_idstr.c_str());
586+
if (!*local_internals_pp) {
587+
*local_internals_pp = new local_internals();
588+
}
589+
return **local_internals_pp;
594590
}
595591

596592
#ifdef Py_GIL_DISABLED

include/pybind11/detail/type_caster_base.h

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -497,16 +497,6 @@ PYBIND11_NOINLINE handle get_object_handle(const void *ptr, const detail::type_i
497497
});
498498
}
499499

500-
inline PyThreadState *get_thread_state_unchecked() {
501-
#if defined(PYPY_VERSION) || defined(GRAALVM_PYTHON)
502-
return PyThreadState_GET();
503-
#elif PY_VERSION_HEX < 0x030D0000
504-
return _PyThreadState_UncheckedGet();
505-
#else
506-
return PyThreadState_GetUnchecked();
507-
#endif
508-
}
509-
510500
// Forward declarations
511501
void keep_alive_impl(handle nurse, handle patient);
512502
inline PyObject *make_new_instance(PyTypeObject *type);

include/pybind11/embed.h

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,8 @@ inline void initialize_interpreter(bool init_signal_handlers = true,
193193
config.install_signal_handlers = init_signal_handlers ? 1 : 0;
194194
initialize_interpreter(&config, argc, argv, add_program_dir_to_path);
195195
#endif
196+
197+
detail::get_interpreter_count() = 1;
196198
}
197199

198200
/** \rst
@@ -234,23 +236,24 @@ inline void finalize_interpreter() {
234236
// Get the internals pointer (without creating it if it doesn't exist). It's possible for the
235237
// internals to be created during Py_Finalize() (e.g. if a py::capsule calls `get_internals()`
236238
// during destruction), so we get the pointer-pointer here and check it after Py_Finalize().
237-
detail::internals **internals_ptr_ptr = detail::get_internals_pp();
238-
// It could also be stashed in state_dict, so look there too:
239-
if (object internals_obj
240-
= get_internals_obj_from_state_dict(detail::get_python_state_dict())) {
241-
internals_ptr_ptr = detail::get_internals_pp_from_capsule(internals_obj);
242-
}
243-
// Local internals contains data managed by the current interpreter, so we must clear them to
244-
// avoid undefined behaviors when initializing another interpreter
245-
detail::get_local_internals().registered_types_cpp.clear();
246-
detail::get_local_internals().registered_exception_translators.clear();
239+
auto **&internals_pp = detail::get_internals_pp();
240+
auto **&local_internals_pp = detail::get_local_internals_pp();
247241

248242
Py_Finalize();
249243

250-
if (internals_ptr_ptr) {
251-
delete *internals_ptr_ptr;
252-
*internals_ptr_ptr = nullptr;
244+
if (internals_pp) {
245+
delete *internals_pp;
246+
*internals_pp = nullptr;
253247
}
248+
249+
// Local internals contains data managed by the current interpreter, so we must clear them to
250+
// avoid undefined behaviors when initializing another interpreter
251+
if (local_internals_pp) {
252+
delete *local_internals_pp;
253+
local_internals_pp = nullptr;
254+
}
255+
256+
detail::get_interpreter_count() = 0;
254257
}
255258

256259
/** \rst

include/pybind11/pybind11.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1291,7 +1291,6 @@ inline bool gil_not_used_option(F &&, O &&...o) {
12911291

12921292
#ifdef Py_mod_multiple_interpreters
12931293
inline void *multi_interp_option() { return Py_MOD_MULTIPLE_INTERPRETERS_NOT_SUPPORTED; }
1294-
# ifdef PYBIND11_SUBINTERPRETER_SUPPORT
12951294
template <typename F, typename... O>
12961295
void *multi_interp_option(F &&, O &&...o);
12971296
template <typename... O>
@@ -1311,7 +1310,6 @@ inline void *multi_interp_option(mod_multi_interpreter_one_gil f, O &&...o) {
13111310
}
13121311
return Py_MOD_MULTIPLE_INTERPRETERS_SUPPORTED;
13131312
}
1314-
# endif
13151313
template <typename F, typename... O>
13161314
inline void *multi_interp_option(F &&, O &&...o) {
13171315
return multi_interp_option(o...);

0 commit comments

Comments
 (0)