diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 4942141c..bed7adea 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -13,6 +13,12 @@ on: - reopened # Allow to trigger the workflow manually workflow_dispatch: + inputs: + nightly-pybind11: + description: "Use nightly pybind11" + type: boolean + required: false + default: false permissions: contents: read @@ -72,8 +78,11 @@ jobs: - name: Install nightly pybind11 shell: bash if: | - github.event_name == 'pull_request' && - contains(github.event.pull_request.labels.*.name, 'test-with-nightly-pybind11') + (github.event_name == 'workflow_dispatch' && github.event.inputs.nightly-pybind11 == 'true') || + ( + github.event_name == 'pull_request' && + contains(github.event.pull_request.labels.*.name, 'test-with-nightly-pybind11') + ) run: | python -m pip install --force-reinstall "$(python .github/workflows/set_setup_requires.py)" echo "::group::pyproject.toml" diff --git a/.github/workflows/tests-with-pydebug.yml b/.github/workflows/tests-with-pydebug.yml index 081c773d..42fa13e1 100644 --- a/.github/workflows/tests-with-pydebug.yml +++ b/.github/workflows/tests-with-pydebug.yml @@ -4,6 +4,9 @@ on: push: branches: - main + schedule: + # Run at 12:00 Asia/Shanghai (04:00 UTC) every three days with nightly pybind11 + - cron: "0 4 */3 * *" pull_request: types: - labeled @@ -24,6 +27,12 @@ on: - .github/workflows/tests-with-pydebug.yml # Allow to trigger the workflow manually workflow_dispatch: + inputs: + nightly-pybind11: + description: "Use nightly pybind11" + type: boolean + required: false + default: false permissions: contents: read @@ -312,11 +321,14 @@ jobs: - name: Use nightly pybind11 shell: bash if: | - github.event_name == 'pull_request' && - contains(github.event.pull_request.labels.*.name, 'test-with-nightly-pybind11') + github.event_name == 'schedule' || + (github.event_name == 'workflow_dispatch' && github.event.inputs.nightly-pybind11 == 'true') || + ( + github.event_name == 'pull_request' && + contains(github.event.pull_request.labels.*.name, 'test-with-nightly-pybind11') + ) run: | source "venv/${VENV_BIN_NAME}/activate" - ${{ env.PYTHON }} .github/workflows/set_setup_requires.py ${{ env.PYTHON }} -m pip install --force-reinstall "$(${{ env.PYTHON }} .github/workflows/set_setup_requires.py)" echo "::group::pyproject.toml" cat pyproject.toml @@ -348,7 +360,13 @@ jobs: "--cov-report=xml:coverage-${{ env.PYTHON_TAG }}-${{ runner.os }}.xml" "--junit-xml=junit-${{ env.PYTHON_TAG }}-${{ runner.os }}.xml" ) - make test PYTESTOPTS="${PYTESTOPTS[*]}" + + if ${{ env.PYTHON }} -c 'import sys, optree; sys.exit(not optree._C.OPTREE_HAS_SUBINTERPRETER_SUPPORT)'; then + make test PYTESTOPTS="${PYTESTOPTS[*]} -k 'concurrent' --no-cov" + make test PYTESTOPTS="${PYTESTOPTS[*]} -k 'not subinterpreter'" + else + make test PYTESTOPTS="${PYTESTOPTS[*]}" + fi CORE_DUMP_FILES="$( find . -type d -path "./venv" -prune \ diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index df6ea824..399c6870 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -25,6 +25,12 @@ on: - .github/workflows/tests.yml # Allow to trigger the workflow manually workflow_dispatch: + inputs: + nightly-pybind11: + description: "Use nightly pybind11" + type: boolean + required: false + default: false permissions: contents: read @@ -186,10 +192,12 @@ jobs: - name: Use nightly pybind11 shell: bash if: | - github.event_name == 'pull_request' && - contains(github.event.pull_request.labels.*.name, 'test-with-nightly-pybind11') + (github.event_name == 'workflow_dispatch' && github.event.inputs.nightly-pybind11 == 'true') || + ( + github.event_name == 'pull_request' && + contains(github.event.pull_request.labels.*.name, 'test-with-nightly-pybind11') + ) run: | - ${{ env.PYTHON }} .github/workflows/set_setup_requires.py ${{ env.PYTHON }} -m pip install --force-reinstall "$(${{ env.PYTHON }} .github/workflows/set_setup_requires.py)" echo "::group::pyproject.toml" cat pyproject.toml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6dc1f76f..3f0635ee 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -38,7 +38,7 @@ repos: hooks: - id: cpplint - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.14.14 + rev: v0.15.1 hooks: - id: ruff-check args: [--fix, --exit-non-zero-on-fix] diff --git a/CHANGELOG.md b/CHANGELOG.md index 7be653bd..27ff93e0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- +- Add subinterpreters support for Python 3.14+ by [@XuehaiPan](https://github.com/XuehaiPan) in [#245](https://github.com/metaopt/optree/pull/245). ### Changed diff --git a/include/optree/pymacros.h b/include/optree/pymacros.h index 56e5cc01..f2fb251f 100644 --- a/include/optree/pymacros.h +++ b/include/optree/pymacros.h @@ -17,6 +17,8 @@ limitations under the License. #pragma once +#include // std::runtime_error + #include #include @@ -32,6 +34,15 @@ limitations under the License. // NOLINTNEXTLINE[bugprone-macro-parentheses] #define NONZERO_OR_EMPTY(MACRO) ((MACRO + 0 != 0) || (0 - MACRO - 1 >= 0)) +#if !defined(PYPY_VERSION) && (PY_VERSION_HEX >= 0x030E0000 /* Python 3.14 */) && \ + (PYBIND11_VERSION_HEX >= 0x030002A0 /* pybind11 3.0.2.a0 */) && \ + (defined(PYBIND11_HAS_SUBINTERPRETER_SUPPORT) && \ + NONZERO_OR_EMPTY(PYBIND11_HAS_SUBINTERPRETER_SUPPORT)) +# define OPTREE_HAS_SUBINTERPRETER_SUPPORT 1 +#else +# undef OPTREE_HAS_SUBINTERPRETER_SUPPORT +#endif + namespace py = pybind11; #if !defined(Py_ALWAYS_INLINE) @@ -59,3 +70,50 @@ inline constexpr Py_ALWAYS_INLINE bool Py_IsConstant(PyObject *x) noexcept { return Py_IsNone(x) || Py_IsTrue(x) || Py_IsFalse(x); } #define Py_IsConstant(x) Py_IsConstant(x) + +using interpid_t = decltype(PyInterpreterState_GetID(nullptr)); + +#if defined(PYBIND11_HAS_SUBINTERPRETER_SUPPORT) && \ + NONZERO_OR_EMPTY(PYBIND11_HAS_SUBINTERPRETER_SUPPORT) + +[[nodiscard]] inline bool IsCurrentPyInterpreterMain() { + return PyInterpreterState_Get() == PyInterpreterState_Main(); +} + +[[nodiscard]] inline interpid_t GetCurrentPyInterpreterID() { + PyInterpreterState *interp = PyInterpreterState_Get(); + if (PyErr_Occurred() != nullptr) [[unlikely]] { + throw py::error_already_set(); + } + if (interp == nullptr) [[unlikely]] { + throw std::runtime_error("Failed to get the current Python interpreter state."); + } + const interpid_t interpid = PyInterpreterState_GetID(interp); + if (PyErr_Occurred() != nullptr) [[unlikely]] { + throw py::error_already_set(); + } + return interpid; +} + +[[nodiscard]] inline interpid_t GetMainPyInterpreterID() { + PyInterpreterState *interp = PyInterpreterState_Main(); + if (PyErr_Occurred() != nullptr) [[unlikely]] { + throw py::error_already_set(); + } + if (interp == nullptr) [[unlikely]] { + throw std::runtime_error("Failed to get the main Python interpreter state."); + } + const interpid_t interpid = PyInterpreterState_GetID(interp); + if (PyErr_Occurred() != nullptr) [[unlikely]] { + throw py::error_already_set(); + } + return interpid; +} + +#else + +[[nodiscard]] inline bool IsCurrentPyInterpreterMain() noexcept { return true; } +[[nodiscard]] inline interpid_t GetCurrentPyInterpreterID() noexcept { return 0; } +[[nodiscard]] inline interpid_t GetMainPyInterpreterID() noexcept { return 0; } + +#endif diff --git a/include/optree/registry.h b/include/optree/registry.h index ac91146d..fc3cc624 100644 --- a/include/optree/registry.h +++ b/include/optree/registry.h @@ -23,12 +23,13 @@ limitations under the License. #include // std::string #include // std::unordered_map #include // std::unordered_set -#include // std::pair +#include // std::pair, std::make_pair #include #include "optree/exceptions.h" #include "optree/hashing.h" +#include "optree/pymacros.h" #include "optree/synchronization.h" namespace optree { @@ -141,6 +142,52 @@ class PyTreeTypeRegistry { return count1; } + // Get the number of alive interpreters that have seen the registry. + [[nodiscard]] static inline Py_ALWAYS_INLINE ssize_t GetNumInterpretersAlive() { + const scoped_read_lock lock{sm_mutex}; + return py::ssize_t_cast(sm_alive_interpids.size()); + } + + // Get the number of interpreters that have seen the registry. + [[nodiscard]] static inline Py_ALWAYS_INLINE ssize_t GetNumInterpretersSeen() { + const scoped_read_lock lock{sm_mutex}; + return sm_num_interpreters_seen; + } + + // Get the IDs of alive interpreters that have seen the registry. + [[nodiscard]] static inline Py_ALWAYS_INLINE std::unordered_set + GetAliveInterpreterIDs() { + const scoped_read_lock lock{sm_mutex}; + return sm_alive_interpids; + } + + // Check if should preserve the insertion order of the dictionary keys during flattening. + [[nodiscard]] static inline Py_ALWAYS_INLINE bool IsDictInsertionOrdered( + const std::string ®istry_namespace, + const bool &inherit_global_namespace = true) { + const scoped_read_lock lock{sm_dict_order_mutex}; + + const auto interpid = GetCurrentPyInterpreterID(); + const auto &namespaces = sm_dict_insertion_ordered_namespaces; + return (namespaces.find({interpid, registry_namespace}) != namespaces.end()) || + (inherit_global_namespace && namespaces.find({interpid, ""}) != namespaces.end()); + } + + // Set the namespace to preserve the insertion order of the dictionary keys during flattening. + static inline Py_ALWAYS_INLINE void SetDictInsertionOrdered( + const bool &mode, + const std::string ®istry_namespace) { + const scoped_write_lock lock{sm_dict_order_mutex}; + + const auto interpid = GetCurrentPyInterpreterID(); + const auto key = std::make_pair(interpid, registry_namespace); + if (mode) [[likely]] { + sm_dict_insertion_ordered_namespaces.insert(key); + } else [[unlikely]] { + sm_dict_insertion_ordered_namespaces.erase(key); + } + } + friend void BuildModule(py::module_ &mod); // NOLINT[runtime/references] private: @@ -173,7 +220,16 @@ class PyTreeTypeRegistry { NamedRegistrationsMap m_named_registrations{}; BuiltinsTypesSet m_builtins_types{}; + // A set of namespaces that preserve the insertion order of the dictionary keys during + // flattening. + static inline std::unordered_set> + sm_dict_insertion_ordered_namespaces{}; + static inline read_write_mutex sm_dict_order_mutex{}; + friend class PyTreeSpec; + + static inline std::unordered_set sm_alive_interpids{}; static inline read_write_mutex sm_mutex{}; + static inline ssize_t sm_num_interpreters_seen = 0; }; } // namespace optree diff --git a/include/optree/synchronization.h b/include/optree/synchronization.h index 83203f48..a5f84c3c 100644 --- a/include/optree/synchronization.h +++ b/include/optree/synchronization.h @@ -62,7 +62,7 @@ using scoped_recursive_lock = std::scoped_lock; #if (defined(__APPLE__) /* header is not available on macOS build target */ && \ PY_VERSION_HEX < 0x030C00F0 /* Python 3.12.0 */) -# undef HAVE_READ_WRITE_LOCK +# undef OPTREE_HAS_READ_WRITE_LOCK using read_write_mutex = mutex; using scoped_read_lock = scoped_lock; @@ -70,7 +70,7 @@ using scoped_write_lock = scoped_lock; #else -# define HAVE_READ_WRITE_LOCK +# define OPTREE_HAS_READ_WRITE_LOCK 1 # include // std::shared_mutex, std::shared_lock diff --git a/include/optree/treespec.h b/include/optree/treespec.h index ecc6079e..c73684f6 100644 --- a/include/optree/treespec.h +++ b/include/optree/treespec.h @@ -17,14 +17,13 @@ limitations under the License. #pragma once -#include // std::unique_ptr -#include // std::optional, std::nullopt -#include // std::string -#include // std::thread::id -#include // std::tuple -#include // std::unordered_set -#include // std::pair -#include // std::vector +#include // std::unique_ptr +#include // std::optional, std::nullopt +#include // std::string +#include // std::thread::id +#include // std::tuple +#include // std::pair +#include // std::vector #include @@ -259,31 +258,6 @@ class PyTreeSpec { const bool &none_is_leaf = false, const std::string ®istry_namespace = ""); - // Check if should preserve the insertion order of the dictionary keys during flattening. - [[nodiscard]] static inline Py_ALWAYS_INLINE bool IsDictInsertionOrdered( - const std::string ®istry_namespace, - const bool &inherit_global_namespace = true) { - const scoped_read_lock lock{sm_is_dict_insertion_ordered_mutex}; - - return (sm_is_dict_insertion_ordered.find(registry_namespace) != - sm_is_dict_insertion_ordered.end()) || - (inherit_global_namespace && - sm_is_dict_insertion_ordered.find("") != sm_is_dict_insertion_ordered.end()); - } - - // Set the namespace to preserve the insertion order of the dictionary keys during flattening. - static inline Py_ALWAYS_INLINE void SetDictInsertionOrdered( - const bool &mode, - const std::string ®istry_namespace) { - const scoped_write_lock lock{sm_is_dict_insertion_ordered_mutex}; - - if (mode) [[likely]] { - sm_is_dict_insertion_ordered.insert(registry_namespace); - } else [[unlikely]] { - sm_is_dict_insertion_ordered.erase(registry_namespace); - } - } - friend void BuildModule(py::module_ &mod); // NOLINT[runtime/references] private: @@ -423,11 +397,6 @@ class PyTreeSpec { // Used in tp_clear for GC support. static int PyTpClear(PyObject *self_base); - - // A set of namespaces that preserve the insertion order of the dictionary keys during - // flattening. - static inline std::unordered_set sm_is_dict_insertion_ordered{}; - static inline read_write_mutex sm_is_dict_insertion_ordered_mutex{}; }; class PyTreeIter { @@ -441,7 +410,8 @@ class PyTreeIter { m_leaf_predicate{leaf_predicate}, m_none_is_leaf{none_is_leaf}, m_namespace{registry_namespace}, - m_is_dict_insertion_ordered{PyTreeSpec::IsDictInsertionOrdered(registry_namespace)} {} + m_is_dict_insertion_ordered{ + PyTreeTypeRegistry::IsDictInsertionOrdered(registry_namespace)} {} PyTreeIter() = delete; ~PyTreeIter() = default; diff --git a/optree/_C.pyi b/optree/_C.pyi index e7b8217b..81de1d63 100644 --- a/optree/_C.pyi +++ b/optree/_C.pyi @@ -49,10 +49,14 @@ Py_DEBUG: Final[bool] Py_GIL_DISABLED: Final[bool] PYBIND11_VERSION_HEX: Final[int] PYBIND11_INTERNALS_VERSION: Final[int] +PYBIND11_INTERNALS_ID: Final[str] +PYBIND11_MODULE_LOCAL_ID: Final[str] PYBIND11_HAS_NATIVE_ENUM: Final[bool] PYBIND11_HAS_INTERNALS_WITH_SMART_HOLDER_SUPPORT: Final[bool] PYBIND11_HAS_SUBINTERPRETER_SUPPORT: Final[bool] GLIBCXX_USE_CXX11_ABI: Final[bool] +OPTREE_HAS_SUBINTERPRETER_SUPPORT: Final[bool] +OPTREE_HAS_READ_WRITE_LOCK: Final[bool] @final class InternalError(SystemError): ... @@ -214,3 +218,9 @@ def set_dict_insertion_ordered( namespace: str = '', ) -> None: ... def get_registry_size(namespace: str | None = None) -> int: ... +def get_num_interpreters_seen() -> int: ... +def get_num_interpreters_alive() -> int: ... +def get_alive_interpreter_ids() -> set[int]: ... +def is_current_interpreter_main() -> bool: ... +def get_current_interpreter_id() -> int: ... +def get_main_interpreter_id() -> int: ... diff --git a/optree/version.py b/optree/version.py index 384d68bb..40f621e2 100644 --- a/optree/version.py +++ b/optree/version.py @@ -22,16 +22,16 @@ __release__ = False if not __release__: - import os import subprocess + from pathlib import Path - root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + root_dir = Path(__file__).absolute().parent.parent try: prefix, sep, suffix = ( subprocess.check_output( # noqa: S603 [ # noqa: S607 'git', - f'--git-dir={os.path.join(root_dir, ".git")}', + f'--git-dir={root_dir / ".git"}', 'describe', '--abbrev=7', ], @@ -39,6 +39,7 @@ stderr=subprocess.DEVNULL, text=True, encoding='utf-8', + timeout=120.0, ) .strip() .lstrip('v') @@ -54,7 +55,7 @@ else: __version__ = prefix del prefix, sep, suffix - except (OSError, subprocess.CalledProcessError): + except (OSError, RuntimeError, subprocess.SubprocessError): pass - del os, subprocess, root_dir + del Path, subprocess, root_dir diff --git a/setup.py b/setup.py index bf591d70..4cabbc64 100644 --- a/setup.py +++ b/setup.py @@ -101,8 +101,9 @@ def cmake_context( stderr=subprocess.STDOUT, text=True, encoding='utf-8', + timeout=120.0, ).strip() - except (OSError, subprocess.CalledProcessError): + except (OSError, subprocess.SubprocessError): eprint( f'Could not run `{cmake}` directly. ' 'Unset the `PYTHONPATH` environment variable in the build environment.', @@ -115,6 +116,7 @@ def cmake_context( stderr=subprocess.STDOUT, text=True, encoding='utf-8', + timeout=120.0, ).strip() if verbose and output: @@ -160,9 +162,10 @@ def cmake_executable( stderr=subprocess.DEVNULL, text=True, encoding='utf-8', + timeout=120.0, ), ) - except (OSError, subprocess.CalledProcessError, json.JSONDecodeError): + except (OSError, subprocess.SubprocessError, json.JSONDecodeError): cmake_capabilities = {} cmake_version = cmake_capabilities.get('version', {}).get('string', '0.0.0') if Version(cmake_version) < Version(minimum_version): diff --git a/src/optree.cpp b/src/optree.cpp index 3adef798..5adb5f06 100644 --- a/src/optree.cpp +++ b/src/optree.cpp @@ -73,6 +73,8 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] #endif BUILDTIME_METADATA["PYBIND11_VERSION_HEX"] = py::int_(PYBIND11_VERSION_HEX); BUILDTIME_METADATA["PYBIND11_INTERNALS_VERSION"] = py::int_(PYBIND11_INTERNALS_VERSION); + BUILDTIME_METADATA["PYBIND11_INTERNALS_ID"] = py::str(PYBIND11_INTERNALS_ID); + BUILDTIME_METADATA["PYBIND11_MODULE_LOCAL_ID"] = py::str(PYBIND11_MODULE_LOCAL_ID); #if defined(PYBIND11_HAS_NATIVE_ENUM) && NONZERO_OR_EMPTY(PYBIND11_HAS_NATIVE_ENUM) BUILDTIME_METADATA["PYBIND11_HAS_NATIVE_ENUM"] = py::bool_(true); #else @@ -95,6 +97,16 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] #else BUILDTIME_METADATA["GLIBCXX_USE_CXX11_ABI"] = py::bool_(false); #endif +#if defined(OPTREE_HAS_SUBINTERPRETER_SUPPORT) + BUILDTIME_METADATA["OPTREE_HAS_SUBINTERPRETER_SUPPORT"] = py::bool_(true); +#else + BUILDTIME_METADATA["OPTREE_HAS_SUBINTERPRETER_SUPPORT"] = py::bool_(false); +#endif +#if defined(OPTREE_HAS_READ_WRITE_LOCK) + BUILDTIME_METADATA["OPTREE_HAS_READ_WRITE_LOCK"] = py::bool_(true); +#else + BUILDTIME_METADATA["OPTREE_HAS_READ_WRITE_LOCK"] = py::bool_(false); +#endif mod.attr("BUILDTIME_METADATA") = std::move(BUILDTIME_METADATA); py::exec( @@ -139,12 +151,12 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] py::pos_only(), py::arg("namespace") = "") .def("is_dict_insertion_ordered", - &PyTreeSpec::IsDictInsertionOrdered, + &PyTreeTypeRegistry::IsDictInsertionOrdered, "Return whether need to preserve the dict insertion order during flattening.", py::arg("namespace") = "", py::arg("inherit_global_namespace") = true) .def("set_dict_insertion_ordered", - &PyTreeSpec::SetDictInsertionOrdered, + &PyTreeTypeRegistry::SetDictInsertionOrdered, "Set whether need to preserve the dict insertion order during flattening.", py::arg("mode"), py::pos_only(), @@ -153,6 +165,24 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] &PyTreeTypeRegistry::GetRegistrySize, "Get the number of registered types.", py::arg("namespace") = std::nullopt) + .def("get_num_interpreters_seen", + &PyTreeTypeRegistry::GetNumInterpretersSeen, + "Get the number of interpreters that have seen the registry.") + .def("get_num_interpreters_alive", + &PyTreeTypeRegistry::GetNumInterpretersAlive, + "Get the number of alive interpreters that have seen the registry.") + .def("get_alive_interpreter_ids", + &PyTreeTypeRegistry::GetAliveInterpreterIDs, + "Get the IDs of alive interpreters that have seen the registry.") + .def("is_current_interpreter_main", + &IsCurrentPyInterpreterMain, + "Check whether the current interpreter is the main interpreter.") + .def("get_current_interpreter_id", + &GetCurrentPyInterpreterID, + "Get the ID of the current interpreter.") + .def("get_main_interpreter_id", + &GetMainPyInterpreterID, + "Get the ID of the main interpreter.") .def("flatten", &PyTreeSpec::Flatten, "Flatten a pytree.", @@ -528,7 +558,11 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] // NOLINTBEGIN[cppcoreguidelines-pro-bounds-pointer-arithmetic,cppcoreguidelines-pro-type-vararg] #if PYBIND11_VERSION_HEX >= 0x020D00F0 // pybind11 2.13.0 +# if defined(OPTREE_HAS_SUBINTERPRETER_SUPPORT) +PYBIND11_MODULE(_C, mod, py::mod_gil_not_used(), py::multiple_interpreters::per_interpreter_gil()) +# else PYBIND11_MODULE(_C, mod, py::mod_gil_not_used()) +# endif #else PYBIND11_MODULE(_C, mod) #endif diff --git a/src/registry.cpp b/src/registry.cpp index faa6c662..4e1a97ff 100644 --- a/src/registry.cpp +++ b/src/registry.cpp @@ -310,6 +310,14 @@ template PyTreeKind PyTreeTypeRegistry::GetKind( /*static*/ void PyTreeTypeRegistry::Init() { const scoped_write_lock lock{sm_mutex}; + const auto interpid = GetCurrentPyInterpreterID(); + + ++sm_num_interpreters_seen; + EXPECT_TRUE( + sm_alive_interpids.insert(interpid).second, + "The current interpreter ID should not be already present in the alive interpreters " + "set."); + auto ®istry1 = GetSingleton(); auto ®istry2 = GetSingleton(); @@ -325,6 +333,32 @@ template PyTreeKind PyTreeTypeRegistry::GetKind( /*static*/ void PyTreeTypeRegistry::Clear() { const scoped_write_lock lock{sm_mutex}; + const auto interpid = GetCurrentPyInterpreterID(); + + EXPECT_NE(sm_alive_interpids.find(interpid), + sm_alive_interpids.end(), + "The current interpreter ID should be present in the alive interpreters set."); + sm_alive_interpids.erase(interpid); + + { + const scoped_write_lock namespace_lock{sm_dict_order_mutex}; + auto entries = reserved_vector(4); + for (const auto &entry : sm_dict_insertion_ordered_namespaces) { + if (entry.first == interpid) [[likely]] { + entries.emplace_back(entry); + } + } + for (const auto &entry : entries) { + sm_dict_insertion_ordered_namespaces.erase(entry); + } + if (sm_alive_interpids.empty()) [[likely]] { + EXPECT_TRUE( + sm_dict_insertion_ordered_namespaces.empty(), + "The dict insertion ordered namespaces map should be empty when there is no " + "alive Python interpreter."); + } + } + auto ®istry1 = GetSingleton(); auto ®istry2 = GetSingleton(); diff --git a/src/treespec/constructors.cpp b/src/treespec/constructors.cpp index cc4ec090..49c861ca 100644 --- a/src/treespec/constructors.cpp +++ b/src/treespec/constructors.cpp @@ -170,7 +170,8 @@ template keys = DictKeys(dict); if (node.kind != PyTreeKind::OrderedDict) [[likely]] { node.original_keys = py::getattr(keys, "copy")(); - if (!IsDictInsertionOrdered(registry_namespace)) [[likely]] { + if (!PyTreeTypeRegistry::IsDictInsertionOrdered(registry_namespace)) + [[likely]] { TotalOrderSort(keys); } } diff --git a/src/treespec/flatten.cpp b/src/treespec/flatten.cpp index 8c630224..16733eb1 100644 --- a/src/treespec/flatten.cpp +++ b/src/treespec/flatten.cpp @@ -207,12 +207,13 @@ bool PyTreeSpec::FlattenInto(const py::handle &handle, bool is_dict_insertion_ordered = false; bool is_dict_insertion_ordered_in_current_namespace = false; { -#if defined(HAVE_READ_WRITE_LOCK) - const scoped_read_lock lock{sm_is_dict_insertion_ordered_mutex}; +#if defined(OPTREE_HAS_READ_WRITE_LOCK) + const scoped_read_lock lock{PyTreeTypeRegistry::sm_dict_order_mutex}; #endif - is_dict_insertion_ordered = IsDictInsertionOrdered(registry_namespace); + is_dict_insertion_ordered = PyTreeTypeRegistry::IsDictInsertionOrdered(registry_namespace); is_dict_insertion_ordered_in_current_namespace = - IsDictInsertionOrdered(registry_namespace, /*inherit_global_namespace=*/false); + PyTreeTypeRegistry::IsDictInsertionOrdered(registry_namespace, + /*inherit_global_namespace=*/false); } if (none_is_leaf) [[unlikely]] { @@ -483,12 +484,13 @@ bool PyTreeSpec::FlattenIntoWithPath(const py::handle &handle, bool is_dict_insertion_ordered = false; bool is_dict_insertion_ordered_in_current_namespace = false; { -#if defined(HAVE_READ_WRITE_LOCK) - const scoped_read_lock lock{sm_is_dict_insertion_ordered_mutex}; +#if defined(OPTREE_HAS_READ_WRITE_LOCK) + const scoped_read_lock lock{PyTreeTypeRegistry::sm_dict_order_mutex}; #endif - is_dict_insertion_ordered = IsDictInsertionOrdered(registry_namespace); + is_dict_insertion_ordered = PyTreeTypeRegistry::IsDictInsertionOrdered(registry_namespace); is_dict_insertion_ordered_in_current_namespace = - IsDictInsertionOrdered(registry_namespace, /*inherit_global_namespace=*/false); + PyTreeTypeRegistry::IsDictInsertionOrdered(registry_namespace, + /*inherit_global_namespace=*/false); } auto stack = reserved_vector(4); diff --git a/tests/concurrent/test_subinterpreters.py b/tests/concurrent/test_subinterpreters.py new file mode 100644 index 00000000..16381466 --- /dev/null +++ b/tests/concurrent/test_subinterpreters.py @@ -0,0 +1,347 @@ +# Copyright 2022-2025 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import atexit +import contextlib +import platform +import random +import sys + +import pytest + +from helpers import ( + ANDROID, + IOS, + OPTREE_HAS_SUBINTERPRETER_SUPPORT, + PYPY, + WASM, + Py_DEBUG, + check_script_in_subprocess, +) + + +if ( + PYPY + or WASM + or IOS + or ANDROID + or sys.version_info < (3, 14) + or not getattr(sys.implementation, 'supports_isolated_interpreters', False) + or not OPTREE_HAS_SUBINTERPRETER_SUPPORT +): + pytest.skip('Test for CPython 3.14+ only', allow_module_level=True) + + +from concurrent import interpreters +from concurrent.futures import InterpreterPoolExecutor, as_completed + + +if not Py_DEBUG: + NUM_WORKERS = 8 + NUM_FUTURES = 32 + NUM_FLAKY_RERUNS = 16 +else: + NUM_WORKERS = 4 + NUM_FUTURES = 16 + NUM_FLAKY_RERUNS = 8 + + +EXECUTOR = InterpreterPoolExecutor(max_workers=NUM_WORKERS) +atexit.register(EXECUTOR.shutdown) + + +def run(func, /, *args, **kwargs): + future = EXECUTOR.submit(func, *args, **kwargs) + exception = future.exception() + if exception is not None: + raise exception + return future.result() + + +def concurrent_run(func, /, *args, **kwargs): + futures = [EXECUTOR.submit(func, *args, **kwargs) for _ in range(NUM_FUTURES)] + future2index = {future: i for i, future in enumerate(futures)} + completed_futures = sorted(as_completed(futures), key=future2index.get) + first_exception = next(filter(None, (future.exception() for future in completed_futures)), None) + if first_exception is not None: + raise first_exception + return [future.result() for future in completed_futures] + + +def check_module_importable(): + import collections + import time + + import optree + import optree._C + + is_current_interpreter_main = optree._C.is_current_interpreter_main() + main_interpreter_id = optree._C.get_main_interpreter_id() + current_interpreter_id = optree._C.get_current_interpreter_id() + + if is_current_interpreter_main != (main_interpreter_id == current_interpreter_id): + raise RuntimeError('interpreter identity mismatch') + + if not is_current_interpreter_main and optree._C.get_registry_size() != 8: + raise RuntimeError('registry size mismatch') + + tree = { + 'b': [2, (3, 4)], + 'a': 1, + 'c': collections.OrderedDict( + f=None, + d=5, + e=time.struct_time([6] + [None] * (time.struct_time.n_sequence_fields - 1)), + ), + 'g': collections.defaultdict(list, h=collections.deque([7, 8, 9], maxlen=10)), + } + + leaves1, treespec1 = optree.tree_flatten(tree, none_is_leaf=False) + reconstructed1 = optree.tree_unflatten(treespec1, leaves1) + if reconstructed1 != tree: + raise RuntimeError('unflatten/flatten mismatch') + if treespec1.num_leaves != len(leaves1): + raise RuntimeError(f'num_leaves mismatch: ({leaves1}, {treespec1})') + if leaves1 != [1, 2, 3, 4, 5, 6, 7, 8, 9]: + raise RuntimeError(f'flattened leaves mismatch: ({leaves1}, {treespec1})') + + leaves2, treespec2 = optree.tree_flatten(tree, none_is_leaf=True) + reconstructed2 = optree.tree_unflatten(treespec2, leaves2) + if reconstructed2 != tree: + raise RuntimeError('unflatten/flatten mismatch') + if treespec2.num_leaves != len(leaves2): + raise RuntimeError(f'num_leaves mismatch: ({leaves2}, {treespec2})') + if leaves2 != [ + 1, + 2, + 3, + 4, + None, + 5, + 6, + *([None] * (time.struct_time.n_sequence_fields - 1)), + 7, + 8, + 9, + ]: + raise RuntimeError(f'flattened leaves mismatch: ({leaves2}, {treespec2})') + + _ = optree.tree_flatten_with_path(tree, none_is_leaf=False) + _ = optree.tree_flatten_with_path(tree, none_is_leaf=True) + _ = optree.tree_flatten_with_accessor(tree, none_is_leaf=False) + _ = optree.tree_flatten_with_accessor(tree, none_is_leaf=True) + + return ( + is_current_interpreter_main, + main_interpreter_id, + id(type(None)), + id(tuple), + id(list), + id(dict), + id(collections.OrderedDict), + ) + + +def test_import(): + import collections + + expected = ( + False, + 0, + id(type(None)), + id(tuple), + id(list), + id(dict), + id(collections.OrderedDict), + ) + + assert check_module_importable() == (True, *expected[1:]) + assert run(check_module_importable) == expected + + for _ in range(random.randint(5, 10)): + with contextlib.closing(interpreters.create()) as subinterpreter: + subinterpreter.exec('import optree') + with contextlib.closing(interpreters.create()) as subinterpreter: + assert subinterpreter.call(check_module_importable) == expected + + for actual in concurrent_run(check_module_importable): + assert actual == expected + + with contextlib.ExitStack() as stack: + subinterpreters = [ + stack.enter_context(contextlib.closing(interpreters.create())) + for _ in range(random.randint(5, 10)) + ] + random.shuffle(subinterpreters) + for subinterpreter in subinterpreters: + subinterpreter.exec('import optree') + + with contextlib.ExitStack() as stack: + subinterpreters = [ + stack.enter_context(contextlib.closing(interpreters.create())) + for _ in range(random.randint(5, 10)) + ] + random.shuffle(subinterpreters) + for subinterpreter in subinterpreters: + assert subinterpreter.call(check_module_importable) == expected + + +def test_import_in_subinterpreter_after_main(): + check_script_in_subprocess( + """ + import contextlib + import gc + from concurrent import interpreters + + import optree + + subinterpreter = None + with contextlib.closing(interpreters.create()) as subinterpreter: + subinterpreter.exec('import optree') + + del optree, subinterpreter + for _ in range(10): + gc.collect() + """, + output='', + rerun=NUM_FLAKY_RERUNS, + ) + + check_script_in_subprocess( + f""" + import contextlib + import gc + import random + from concurrent import interpreters + + import optree + + subinterpreter = subinterpreters = stack = None + with contextlib.ExitStack() as stack: + subinterpreters = [ + stack.enter_context(contextlib.closing(interpreters.create())) + for _ in range({NUM_FUTURES}) + ] + random.shuffle(subinterpreters) + for subinterpreter in subinterpreters: + subinterpreter.exec('import optree') + + del optree, subinterpreter, subinterpreters, stack + for _ in range(10): + gc.collect() + """, + output='', + rerun=NUM_FLAKY_RERUNS, + ) + + +def test_import_in_subinterpreter_before_main(): + check_script_in_subprocess( + """ + import contextlib + import gc + from concurrent import interpreters + + subinterpreter = None + with contextlib.closing(interpreters.create()) as subinterpreter: + subinterpreter.exec('import optree') + + import optree + + del optree, subinterpreter + for _ in range(10): + gc.collect() + """, + output='', + rerun=NUM_FLAKY_RERUNS, + ) + + check_script_in_subprocess( + f""" + import contextlib + import gc + import random + from concurrent import interpreters + + subinterpreter = subinterpreters = stack = None + with contextlib.ExitStack() as stack: + subinterpreters = [ + stack.enter_context(contextlib.closing(interpreters.create())) + for _ in range({NUM_FUTURES}) + ] + random.shuffle(subinterpreters) + for subinterpreter in subinterpreters: + subinterpreter.exec('import optree') + + import optree + + del optree, subinterpreter, subinterpreters, stack + for _ in range(10): + gc.collect() + """, + output='', + rerun=NUM_FLAKY_RERUNS, + ) + + check_script_in_subprocess( + f""" + import contextlib + import gc + import random + from concurrent import interpreters + + subinterpreter = subinterpreters = stack = None + with contextlib.ExitStack() as stack: + subinterpreters = [ + stack.enter_context(contextlib.closing(interpreters.create())) + for _ in range({NUM_FUTURES}) + ] + random.shuffle(subinterpreters) + for subinterpreter in subinterpreters: + subinterpreter.exec('import optree') + + import optree + + del optree, subinterpreter, subinterpreters, stack + for _ in range(10): + gc.collect() + """, + output='', + rerun=NUM_FLAKY_RERUNS, + ) + + +@pytest.mark.flaky(condition=platform.system() == 'Windows', reruns=3, only_rerun='TimeoutExpired') +def test_import_in_subinterpreters_concurrently(): + check_script_in_subprocess( + f""" + from concurrent.futures import InterpreterPoolExecutor, as_completed + + def check_import(): + import optree + + if optree._C.get_registry_size() != 8: + raise RuntimeError('registry size mismatch') + if optree._C.is_current_interpreter_main(): + raise RuntimeError('expected subinterpreter') + + with InterpreterPoolExecutor(max_workers={NUM_WORKERS}) as executor: + futures = [executor.submit(check_import) for _ in range({NUM_FUTURES})] + for future in as_completed(futures): + future.result() + """, + output='', + rerun=NUM_FLAKY_RERUNS, + ) diff --git a/tests/confcoverage.py b/tests/confcoverage.py index 5635dee2..78ca9d9a 100644 --- a/tests/confcoverage.py +++ b/tests/confcoverage.py @@ -60,8 +60,9 @@ def is_importable(mod: str) -> bool: ], cwd=TEST_ROOT, env=env, + timeout=120.0, ) - except subprocess.CalledProcessError: + except subprocess.SubprocessError: return False return True diff --git a/tests/helpers.py b/tests/helpers.py index c068ab2e..a2fc8200 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -36,6 +36,7 @@ import optree from optree._C import ( + OPTREE_HAS_SUBINTERPRETER_SUPPORT, PYBIND11_HAS_NATIVE_ENUM, PYBIND11_HAS_SUBINTERPRETER_SUPPORT, Py_DEBUG, @@ -55,6 +56,7 @@ _ = PYBIND11_HAS_NATIVE_ENUM _ = PYBIND11_HAS_SUBINTERPRETER_SUPPORT +_ = OPTREE_HAS_SUBINTERPRETER_SUPPORT if sysconfig.get_config_var('Py_DEBUG') is None: assert Py_DEBUG == hasattr(sys, 'gettotalrefcount') @@ -168,7 +170,16 @@ def __str__(self): ) -def check_script_in_subprocess(script, /, *, output, env=None, cwd=TEST_ROOT, rerun=1): +def check_script_in_subprocess( + script, + /, + *, + output, + timeout=120.0, + cwd=TEST_ROOT, + env=None, + rerun=1, +): script = textwrap.dedent(script).strip() result = '' for _ in range(rerun): @@ -178,6 +189,7 @@ def check_script_in_subprocess(script, /, *, output, env=None, cwd=TEST_ROOT, re stderr=subprocess.STDOUT, text=True, encoding='utf-8', + timeout=timeout, cwd=cwd, env={ key: value diff --git a/tests/test_treespec.py b/tests/test_treespec.py index e6aeb165..35928b36 100644 --- a/tests/test_treespec.py +++ b/tests/test_treespec.py @@ -575,6 +575,7 @@ def test_treespec_pickle_missing_registration(): or key in ('PYTHON_GIL', 'PYTHONDEVMODE', 'PYTHONHASHSEED') ) }, + timeout=120.0, ) message = output.stdout.strip() except subprocess.CalledProcessError as ex: