
Commit 067d203

albanD authored and pytorchmergebot committed
Upgrade pybind11 API calls for 3.13t (pytorch#136370)
This is a modified version of pytorch#130341 that preserves support for older pybind11 versions.

Pull Request resolved: pytorch#136370
Approved by: https://github.com/Skylion007, https://github.com/malfet
1 parent 1a10751 commit 067d203
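
All five files apply the same pattern: a lazily initialized, intentionally leaked Python object that used to live in a function-local static py::object or py::handle is, when building against pybind11 2.13 or newer, cached through py::gil_safe_call_once_and_store instead, which is also usable on the free-threaded 3.13t interpreter; older pybind11 releases keep the previous code behind #else. A minimal stand-alone sketch of the pattern follows (the helper name and cached attribute are illustrative, and pybind11's own PYBIND11_VERSION_HEX check stands in for the PyTorch-side IS_PYBIND_2_13_PLUS macro):

// Sketch only: the general shape of the change, outside the PyTorch tree.
#include <pybind11/pybind11.h>
#if PYBIND11_VERSION_HEX >= 0x020D0000
#include <pybind11/gil_safe_call_once.h>
#endif

namespace py = pybind11;

// Illustrative helper: returns a cached, leaked handle to collections.OrderedDict.
py::handle get_ordered_dict_class() {
#if PYBIND11_VERSION_HEX >= 0x020D0000
  // pybind11 2.13+: the callable runs exactly once; the GIL is released while
  // other threads wait, avoiding deadlocks between the GIL and C++ static
  // initialization guards. The stored object is deliberately never destroyed.
  PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object>
      storage;
  return storage
      .call_once_and_store_result([]() -> py::object {
        return py::module_::import("collections").attr("OrderedDict");
      })
      .get_stored();
#else
  // Older pybind11: leak a function-local static handle, as the old code did.
  static py::handle cls =
      py::object(py::module_::import("collections").attr("OrderedDict"))
          .release();
  return cls;
#endif
}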

File tree

5 files changed: 110 additions & 3 deletions


torch/csrc/dynamo/guards.cpp

Lines changed: 22 additions & 0 deletions
@@ -26,6 +26,7 @@
 #endif
 
 #include <sstream>
+#include <tuple>
 #include <utility>
 
 // For TupleIteratorGetItemAccessor, we need a fast way to retrieve the
@@ -2461,6 +2462,26 @@ std::unique_ptr<GuardManager> make_guard_manager(
     std::string source,
     py::handle example_value,
     py::handle guard_manager_enum) {
+#if IS_PYBIND_2_13_PLUS
+  using fourobjects =
+      std::tuple<py::object, py::object, py::object, py::object>;
+  PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<fourobjects>
+      storage;
+
+  auto& [guard_manager_enum_class, base_guard_manager_enum, dict_guard_manager_enum, dict_subclass_guard_manager_enum] =
+      storage
+          .call_once_and_store_result([]() -> fourobjects {
+            py::object guard_manager_enum_class =
+                py::module_::import("torch._dynamo.guards")
+                    .attr("GuardManagerType");
+            return {
+                guard_manager_enum_class,
+                guard_manager_enum_class.attr("GUARD_MANAGER"),
+                guard_manager_enum_class.attr("DICT_GUARD_MANAGER"),
+                guard_manager_enum_class.attr("DICT_SUBCLASS_GUARD_MANAGER")};
+          })
+          .get_stored();
+#else
   static py::object guard_manager_enum_class =
       py::module_::import("torch._dynamo.guards").attr("GuardManagerType");
   static py::object base_guard_manager_enum =
@@ -2469,6 +2490,7 @@ std::unique_ptr<GuardManager> make_guard_manager(
       guard_manager_enum_class.attr("DICT_GUARD_MANAGER");
   static py::object dict_subclass_guard_manager_enum =
       guard_manager_enum_class.attr("DICT_SUBCLASS_GUARD_MANAGER");
+#endif
   if (py::isinstance<py::dict>(example_value)) {
     // The purpose of having both DictGuardManager and DictSubclassGuardManager
     // is to handle the variability in how dictionaries and their subclasses
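
When several related attributes are needed at once, as with the four GuardManagerType members above, the same storage can hold a std::tuple and be unpacked with structured bindings. A reduced sketch of that shape, assuming pybind11 2.13+ (the cached classes below are illustrative, not the ones from the patch):

// Sketch only, assuming pybind11 >= 2.13: one gil_safe_call_once_and_store
// caches several Python objects by storing a std::tuple of them.
#include <pybind11/gil_safe_call_once.h>
#include <pybind11/pybind11.h>

#include <tuple>
#include <utility>

namespace py = pybind11;

// Illustrative helper: returns leaked handles to fractions.Fraction and decimal.Decimal.
std::pair<py::handle, py::handle> get_fraction_and_decimal_classes() {
  using twoobjects = std::tuple<py::object, py::object>;
  PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<twoobjects>
      storage;
  auto& [fraction_class, decimal_class] =
      storage
          .call_once_and_store_result([]() -> twoobjects {
            // Both imports happen exactly once, on the first call.
            return {
                py::module_::import("fractions").attr("Fraction"),
                py::module_::import("decimal").attr("Decimal")};
          })
          .get_stored();
  // The stored py::objects live for the life of the process, so returning
  // non-owning handles to them is safe.
  return {fraction_class, decimal_class};
}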

torch/csrc/jit/python/module_python.h

Lines changed: 32 additions & 3 deletions
@@ -3,29 +3,58 @@
 #include <pybind11/stl.h>
 #include <torch/csrc/jit/api/module.h>
 #include <torch/csrc/utils/pybind.h>
+#include <tuple>
 
 namespace py = pybind11;
 
 namespace torch::jit {
 
 inline std::optional<Module> as_module(py::handle obj) {
+#if IS_PYBIND_2_13_PLUS
+  PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object>
+      storage;
+  auto& ScriptModule =
+      storage
+          .call_once_and_store_result([]() -> py::object {
+            return py::module_::import("torch.jit").attr("ScriptModule");
+          })
+          .get_stored();
+#else
   static py::handle ScriptModule =
       py::module::import("torch.jit").attr("ScriptModule");
+#endif
   if (py::isinstance(obj, ScriptModule)) {
     return py::cast<Module>(obj.attr("_c"));
   }
   return std::nullopt;
 }
 
 inline std::optional<Object> as_object(py::handle obj) {
+#if IS_PYBIND_2_13_PLUS
+  PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<
+      std::tuple<py::object, py::object>>
+      storage;
+  auto& [ScriptObject, RecursiveScriptClass] =
+      storage
+          .call_once_and_store_result(
+              []() -> std::tuple<py::object, py::object> {
+                return {
+                    py::module_::import("torch").attr("ScriptObject"),
+                    py::module_::import("torch.jit")
+                        .attr("RecursiveScriptClass")};
+              })
+          .get_stored();
+#else
   static py::handle ScriptObject =
       py::module::import("torch").attr("ScriptObject");
-  if (py::isinstance(obj, ScriptObject)) {
-    return py::cast<Object>(obj);
-  }
 
   static py::handle RecursiveScriptClass =
       py::module::import("torch.jit").attr("RecursiveScriptClass");
+#endif
+
+  if (py::isinstance(obj, ScriptObject)) {
+    return py::cast<Object>(obj);
+  }
   if (py::isinstance(obj, RecursiveScriptClass)) {
     return py::cast<Object>(obj.attr("_c"));
   }

torch/csrc/jit/python/python_ivalue.h

Lines changed: 14 additions & 0 deletions
@@ -49,8 +49,22 @@ struct C10_EXPORT ConcretePyObjectHolder final : PyObjectHolder {
   // when using C++. The reason is unclear.
   try {
     pybind11::gil_scoped_acquire ag;
+
+#if IS_PYBIND_2_13_PLUS
+    PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object>
+        storage;
+    auto& extractorFn =
+        storage
+            .call_once_and_store_result([]() -> py::object {
+              return py::module_::import("torch._jit_internal")
+                  .attr("_extract_tensors");
+            })
+            .get_stored();
+#else
     static py::object& extractorFn = *new py::object(
         py::module::import("torch._jit_internal").attr("_extract_tensors"));
+#endif
+
     return extractorFn(py_obj_).cast<std::vector<at::Tensor>>();
   } catch (py::error_already_set& e) {
     auto err = std::runtime_error(

torch/csrc/utils/python_arg_parser.cpp

Lines changed: 12 additions & 0 deletions
@@ -266,10 +266,22 @@ static py::object maybe_get_registered_torch_dispatch_rule(
   // This is a static object, so we must leak the Python object
   // "release()" is used here to preserve 1 refcount on the
   // object, preventing it from ever being de-allocated by CPython.
+#if IS_PYBIND_2_13_PLUS
+  PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object>
+      storage;
+  py::object find_torch_dispatch_rule =
+      storage
+          .call_once_and_store_result([]() -> py::object {
+            return py::module_::import("torch._library.simple_registry")
+                .attr("find_torch_dispatch_rule");
+          })
+          .get_stored();
+#else
   static const py::handle find_torch_dispatch_rule =
       py::object(py::module_::import("torch._library.simple_registry")
                      .attr("find_torch_dispatch_rule"))
          .release();
+#endif
   auto result = find_torch_dispatch_rule(
       py::reinterpret_borrow<py::object>(torch_api_function),
       torch_dispatch_object.get_type());

torch/csrc/utils/python_symnode.cpp

Lines changed: 30 additions & 0 deletions
@@ -4,23 +4,53 @@ namespace torch {
 
 py::handle get_symint_class() {
   // NB: leak
+#if IS_PYBIND_2_13_PLUS
+  PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object>
+      storage;
+  return storage
+      .call_once_and_store_result([]() -> py::object {
+        return py::module::import("torch").attr("SymInt");
+      })
+      .get_stored();
+#else
   static py::handle symint_class =
       py::object(py::module::import("torch").attr("SymInt")).release();
   return symint_class;
+#endif
 }
 
 py::handle get_symfloat_class() {
   // NB: leak
+#if IS_PYBIND_2_13_PLUS
+  PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object>
+      storage;
+  return storage
+      .call_once_and_store_result([]() -> py::object {
+        return py::module::import("torch").attr("SymFloat");
+      })
+      .get_stored();
+#else
   static py::handle symfloat_class =
       py::object(py::module::import("torch").attr("SymFloat")).release();
   return symfloat_class;
+#endif
 }
 
 py::handle get_symbool_class() {
   // NB: leak
+#if IS_PYBIND_2_13_PLUS
+  PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object>
+      storage;
+  return storage
+      .call_once_and_store_result([]() -> py::object {
+        return py::module::import("torch").attr("SymBool");
+      })
+      .get_stored();
+#else
  static py::handle symbool_class =
       py::object(py::module::import("torch").attr("SymBool")).release();
   return symbool_class;
+#endif
 }
 
 } // namespace torch
