|
3 | 3 | #include <pybind11/stl.h> |
4 | 4 | #include <torch/csrc/jit/api/module.h> |
5 | 5 | #include <torch/csrc/utils/pybind.h> |
| 6 | +#include <tuple> |
6 | 7 |
|
7 | 8 | namespace py = pybind11; |
8 | 9 |
|
9 | 10 | namespace torch::jit { |
10 | 11 |
|
11 | 12 | inline std::optional<Module> as_module(py::handle obj) { |
| 13 | +#if IS_PYBIND_2_13_PLUS |
| 14 | + PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object> |
| 15 | + storage; |
| 16 | + auto& ScriptModule = |
| 17 | + storage |
| 18 | + .call_once_and_store_result([]() -> py::object { |
| 19 | + return py::module_::import("torch.jit").attr("ScriptModule"); |
| 20 | + }) |
| 21 | + .get_stored(); |
| 22 | +#else |
12 | 23 | static py::handle ScriptModule = |
13 | 24 | py::module::import("torch.jit").attr("ScriptModule"); |
| 25 | +#endif |
14 | 26 | if (py::isinstance(obj, ScriptModule)) { |
15 | 27 | return py::cast<Module>(obj.attr("_c")); |
16 | 28 | } |
17 | 29 | return std::nullopt; |
18 | 30 | } |
19 | 31 |
|
20 | 32 | inline std::optional<Object> as_object(py::handle obj) { |
| 33 | +#if IS_PYBIND_2_13_PLUS |
| 34 | + PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store< |
| 35 | + std::tuple<py::object, py::object>> |
| 36 | + storage; |
| 37 | + auto& [ScriptObject, RecursiveScriptClass] = |
| 38 | + storage |
| 39 | + .call_once_and_store_result( |
| 40 | + []() -> std::tuple<py::object, py::object> { |
| 41 | + return { |
| 42 | + py::module_::import("torch").attr("ScriptObject"), |
| 43 | + py::module_::import("torch.jit") |
| 44 | + .attr("RecursiveScriptClass")}; |
| 45 | + }) |
| 46 | + .get_stored(); |
| 47 | +#else |
21 | 48 | static py::handle ScriptObject = |
22 | 49 | py::module::import("torch").attr("ScriptObject"); |
23 | | - if (py::isinstance(obj, ScriptObject)) { |
24 | | - return py::cast<Object>(obj); |
25 | | - } |
26 | 50 |
|
27 | 51 | static py::handle RecursiveScriptClass = |
28 | 52 | py::module::import("torch.jit").attr("RecursiveScriptClass"); |
| 53 | +#endif |
| 54 | + |
| 55 | + if (py::isinstance(obj, ScriptObject)) { |
| 56 | + return py::cast<Object>(obj); |
| 57 | + } |
29 | 58 | if (py::isinstance(obj, RecursiveScriptClass)) { |
30 | 59 | return py::cast<Object>(obj.attr("_c")); |
31 | 60 | } |
|
0 commit comments