1818#include < pybind11/conduit/pybind11_platform_abi_id.h>
1919#include < pybind11/pytypes.h>
2020
21+ #include < atomic>
2122#include < exception>
2223#include < mutex>
2324#include < thread>
@@ -257,29 +258,47 @@ struct type_info {
257258 " __pybind11_module_local_v" PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) \
258259 PYBIND11_COMPILER_TYPE_LEADING_UNDERSCORE PYBIND11_PLATFORM_ABI_ID " __"
259260
261+ inline PyThreadState *get_thread_state_unchecked () {
262+ #if defined(PYPY_VERSION) || defined(GRAALVM_PYTHON)
263+ return PyThreadState_GET ();
264+ #elif PY_VERSION_HEX < 0x030D0000
265+ return _PyThreadState_UncheckedGet ();
266+ #else
267+ return PyThreadState_GetUnchecked ();
268+ #endif
269+ }
270+
271+ // / We use this count to figure out if there are multiple sub-interpreters currently present.
272+ // / Might be read/written from multiple interpreters in multiple threads at the same time with no
273+ // / syncronization. This must never decrease while any interpreter may be running in any thread!
274+ inline std::atomic<int64_t > &get_interpreter_count () {
275+ static std::atomic<int64_t > counter (0 );
276+ return counter;
277+ }
278+
260279// / Each module locally stores a pointer to the `internals` data. The data
261280// / itself is shared among modules with the same `PYBIND11_INTERNALS_ID`.
262281inline internals **&get_internals_pp () {
263- #if defined(PYPY_VERSION) || defined(GRAALVM_PYTHON) || PY_VERSION_HEX < 0x030C0000 \
264- || !defined (PYBIND11_SUBINTERPRETER_SUPPORT)
265- static internals **internals_pp = nullptr ;
266- #else
267- static thread_local internals **internals_pp = nullptr ;
268- // This is one per interpreter, we cache it but if the thread changed
269- // then we need to invalidate our cache
270- // the caller will find the right value and set it if its null
271- static thread_local PyThreadState *tstate_cached = nullptr ;
272- # if PY_VERSION_HEX < 0x030D0000
273- PyThreadState *tstate = _PyThreadState_UncheckedGet ();
274- # else
275- PyThreadState *tstate = PyThreadState_GetUnchecked ();
276- # endif
277- if (tstate != tstate_cached) {
278- tstate_cached = tstate;
279- internals_pp = nullptr ;
282+ #ifdef PYBIND11_SUBINTERPRETER_SUPPORT
283+ if (get_interpreter_count () > 1 ) {
284+ // Internals is one per interpreter. When multiple interpreters are alive in different
285+ // threads we have to allow them to have different internals, so we need a thread_local.
286+ static thread_local internals **t_internals_pp = nullptr ;
287+ // Whenever the interpreter changes we need to invalidate the internals_pp. That is slow,
288+ // so we only do it when the PyThreadState has changed, which indicates the interpreter
289+ // might have changed as well.
290+ static thread_local PyThreadState *tstate_cached = nullptr ;
291+ auto *tstate = get_thread_state_unchecked ();
292+ if (tstate != tstate_cached) {
293+ tstate_cached = tstate;
294+ // the caller will fetch the instance from the state dict or create a new one
295+ t_internals_pp = nullptr ;
296+ }
297+ return t_internals_pp;
280298 }
281299#endif
282- return internals_pp;
300+ static internals **s_internals_pp = nullptr ;
301+ return s_internals_pp;
283302}
284303
285304// forward decl
@@ -410,20 +429,6 @@ inline object get_python_state_dict() {
410429 return state_dict;
411430}
412431
413- inline object get_internals_obj_from_state_dict (handle state_dict) {
414- return reinterpret_steal<object>(
415- dict_getitemstringref (state_dict.ptr (), PYBIND11_INTERNALS_ID));
416- }
417-
418- inline internals **get_internals_pp_from_capsule (handle obj) {
419- void *raw_ptr = PyCapsule_GetPointer (obj.ptr (), /* name=*/ nullptr );
420- if (raw_ptr == nullptr ) {
421- raise_from (PyExc_SystemError, " pybind11::detail::get_internals_pp_from_capsule() FAILED" );
422- throw error_already_set ();
423- }
424- return static_cast <internals **>(raw_ptr);
425- }
426-
427432inline uint64_t round_up_to_next_pow2 (uint64_t x) {
428433 // Round-up to the next power of two.
429434 // See https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2
@@ -438,13 +443,8 @@ inline uint64_t round_up_to_next_pow2(uint64_t x) {
438443 return x;
439444}
440445
441- // / Return a reference to the current `internals` data
442- PYBIND11_NOINLINE internals &get_internals () {
443- auto **&internals_pp = get_internals_pp ();
444- if (internals_pp && *internals_pp) {
445- return **internals_pp;
446- }
447-
446+ template <typename InternalsType>
447+ inline InternalsType **find_or_create_internals_pp (char const *state_dict_key) {
448448 // Ensure that the GIL is held since we will need to make Python calls.
449449 // Cannot use py::gil_scoped_acquire here since that constructor calls get_internals.
450450 struct gil_scoped_acquire_local {
@@ -458,10 +458,31 @@ PYBIND11_NOINLINE internals &get_internals() {
458458 error_scope err_scope;
459459
460460 dict state_dict = get_python_state_dict ();
461- if (object internals_obj = get_internals_obj_from_state_dict (state_dict)) {
462- internals_pp = get_internals_pp_from_capsule (internals_obj);
461+ object internals_obj
462+ = reinterpret_steal<object>(dict_getitemstringref (state_dict.ptr (), state_dict_key));
463+ if (internals_obj) {
464+ void *raw_ptr = PyCapsule_GetPointer (internals_obj.ptr (), /* name=*/ nullptr );
465+ if (!raw_ptr) {
466+ pybind11_fail (" find_or_create_internals_pp: broken capsule!" );
467+ } else {
468+ return reinterpret_cast <InternalsType **>(raw_ptr);
469+ }
463470 }
471+
472+ auto pp = new InternalsType *(nullptr );
473+ state_dict[state_dict_key] = capsule (reinterpret_cast <void *>(pp));
474+ return pp;
475+ }
476+
477+ // / Return a reference to the current `internals` data
478+ PYBIND11_NOINLINE internals &get_internals () {
479+ auto **&internals_pp = get_internals_pp ();
464480 if (internals_pp && *internals_pp) {
481+ return **internals_pp;
482+ }
483+
484+ internals_pp = find_or_create_internals_pp<internals>(PYBIND11_INTERNALS_ID);
485+ if (*internals_pp) {
465486 // We loaded the internals through `state_dict`, which means that our `error_already_set`
466487 // and `builtin_exception` may be different local classes than the ones set up in the
467488 // initial exception translator, below, so add another for our local exception classes.
@@ -478,13 +499,10 @@ PYBIND11_NOINLINE internals &get_internals() {
478499 }
479500#endif
480501 } else {
481- if (!internals_pp) {
482- internals_pp = new internals *();
483- }
484502 auto *&internals_ptr = *internals_pp;
485503 internals_ptr = new internals ();
486504
487- PyThreadState *tstate = PyThreadState_Get ();
505+ PyThreadState *tstate = get_thread_state_unchecked ();
488506 // NOLINTNEXTLINE(bugprone-assignment-in-if-condition)
489507 if (!PYBIND11_TLS_KEY_CREATE (internals_ptr->tstate )) {
490508 pybind11_fail (" get_internals: could not successfully initialize the tstate TSS key!" );
@@ -498,7 +516,6 @@ PYBIND11_NOINLINE internals &get_internals() {
498516 }
499517
500518 internals_ptr->istate = tstate->interp ;
501- state_dict[PYBIND11_INTERNALS_ID] = capsule (reinterpret_cast <void *>(internals_pp));
502519 internals_ptr->registered_exception_translators .push_front (&translate_exception);
503520 internals_ptr->static_property_type = make_static_property_type ();
504521 internals_ptr->default_metaclass = make_default_metaclass ();
@@ -528,69 +545,48 @@ struct local_internals {
528545 std::forward_list<ExceptionTranslator> registered_exception_translators;
529546};
530547
531- // / Works like `get_internals`, but for things which are locally registered.
532- inline local_internals &get_local_internals () {
533- // Current static can be created in the interpreter finalization routine. If the later will be
534- // destroyed in another static variable destructor, creation of this static there will cause
535- // static deinitialization fiasco. In order to avoid it we avoid destruction of the
536- // local_internals static. One can read more about the problem and current solution here:
537- // https://google.github.io/styleguide/cppguide.html#Static_and_Global_Variables
538-
539- #if defined(PYPY_VERSION) || defined(GRAALVM_PYTHON) || PY_VERSION_HEX < 0x030C0000 \
540- || !defined (PYBIND11_SUBINTERPRETER_SUPPORT)
541- static auto *locals = new local_internals ();
542- #else
543- static thread_local local_internals *locals = nullptr ;
544- // This is one per interpreter, we cache it but if the interpreter changed
545- // then we need to invalidate our cache and re-fetch from the state dict
546- static thread_local PyThreadState *tstate_cached = nullptr ;
547- # if PY_VERSION_HEX < 0x030D0000
548- PyThreadState *tstate = _PyThreadState_UncheckedGet ();
549- # else
550- PyThreadState *tstate = PyThreadState_GetUnchecked ();
551- # endif
552- if (!tstate) {
553- pybind11_fail (
554- " pybind11::detail::get_local_internals() called without a current python thread" );
555- }
556- if (tstate != tstate_cached) {
557- // we create a unique value at first run which is based on a pointer to
558- // a (non-thread_local) static value in this function, then multiple
559- // loaded modules using this code will still each have a unique key.
560- static const std::string this_module_idstr
561- = PYBIND11_MODULE_LOCAL_ID
562- + std::to_string (reinterpret_cast <uintptr_t >(&this_module_idstr));
563-
564- // Ensure that the GIL is held since we will need to make Python calls.
565- // Cannot use py::gil_scoped_acquire here since that constructor calls get_internals.
566- struct gil_scoped_acquire_local {
567- gil_scoped_acquire_local () : state(PyGILState_Ensure()) {}
568- gil_scoped_acquire_local (const gil_scoped_acquire_local &) = delete ;
569- gil_scoped_acquire_local &operator =(const gil_scoped_acquire_local &) = delete ;
570- ~gil_scoped_acquire_local () { PyGILState_Release (state); }
571- const PyGILState_STATE state;
572- } gil;
573-
574- error_scope err_scope;
575- dict state_dict = get_python_state_dict ();
576- object local_capsule = reinterpret_steal<object>(
577- dict_getitemstringref (state_dict.ptr (), this_module_idstr.c_str ()));
578- if (!local_capsule) {
579- locals = new local_internals ();
580- state_dict[this_module_idstr.c_str ()] = capsule (reinterpret_cast <void *>(locals));
581- } else {
582- void *ptr = PyCapsule_GetPointer (local_capsule.ptr (), nullptr );
583- if (!ptr) {
584- raise_from (PyExc_SystemError, " pybind11::detail::get_local_internals() FAILED" );
585- throw error_already_set ();
586- }
587- locals = reinterpret_cast <local_internals *>(ptr);
548+ inline local_internals **&get_local_internals_pp () {
549+ #ifdef PYBIND11_SUBINTERPRETER_SUPPORT
550+ if (get_interpreter_count () > 1 ) {
551+ // Internals is one per interpreter. When multiple interpreters are alive in different
552+ // threads we have to allow them to have different internals, so we need a thread_local.
553+ static thread_local local_internals **t_internals_pp = nullptr ;
554+ // Whenever the interpreter changes we need to invalidate the internals_pp. That is slow,
555+ // so we only do it when the PyThreadState has changed, which indicates the interpreter
556+ // might have changed as well.
557+ static thread_local PyThreadState *tstate_cached = nullptr ;
558+ auto *tstate = get_thread_state_unchecked ();
559+ if (tstate != tstate_cached) {
560+ tstate_cached = tstate;
561+ // the caller will fetch the instance from the state dict or create a new one
562+ t_internals_pp = nullptr ;
588563 }
589- tstate_cached = tstate ;
564+ return t_internals_pp ;
590565 }
591566#endif
567+ static local_internals **s_internals_pp = nullptr ;
568+ return s_internals_pp;
569+ }
570+
571+ // / Works like `get_internals`, but for things which are locally registered.
572+ PYBIND11_NOINLINE local_internals &get_local_internals () {
573+ auto **&local_internals_pp = get_local_internals_pp ();
574+ if (local_internals_pp && *local_internals_pp) {
575+ return **local_internals_pp;
576+ }
577+
578+ // we create a unique value at first run which is based on a pointer to a (non-thread_local)
579+ // static value in this function, then multiple loaded modules using this code will still each
580+ // have a unique key.
581+ static const std::string this_module_idstr
582+ = PYBIND11_MODULE_LOCAL_ID
583+ + std::to_string (reinterpret_cast <uintptr_t >(&this_module_idstr));
592584
593- return *locals;
585+ local_internals_pp = find_or_create_internals_pp<local_internals>(this_module_idstr.c_str ());
586+ if (!*local_internals_pp) {
587+ *local_internals_pp = new local_internals ();
588+ }
589+ return **local_internals_pp;
594590}
595591
596592#ifdef Py_GIL_DISABLED
0 commit comments