Skip to content

Commit 975ccf5

Browse files
authored
refactor: remove Py_Get_ID and cached string objects (#263)
1 parent 05691e7 commit 975ccf5

File tree

7 files changed

+37
-80
lines changed

7 files changed

+37
-80
lines changed

include/optree/pymacros.h

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -59,41 +59,3 @@ inline constexpr Py_ALWAYS_INLINE bool Py_IsConstant(PyObject *x) noexcept {
5959
return Py_IsNone(x) || Py_IsTrue(x) || Py_IsFalse(x);
6060
}
6161
#define Py_IsConstant(x) Py_IsConstant(x)
62-
63-
#define Py_Declare_ID(name) \
64-
inline namespace { \
65-
[[nodiscard]] inline PyObject *Py_ID_##name() { \
66-
PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<PyObject *> storage; \
67-
return storage \
68-
.call_once_and_store_result([]() -> PyObject * { \
69-
PyObject * const ptr = PyUnicode_InternFromString(#name); \
70-
if (ptr == nullptr) [[unlikely]] { \
71-
throw py::error_already_set(); \
72-
} \
73-
Py_INCREF(ptr); /* leak a reference on purpose */ \
74-
return ptr; \
75-
}) \
76-
.get_stored(); \
77-
} \
78-
} // namespace
79-
80-
#define Py_Get_ID(name) (::Py_ID_##name())
81-
82-
Py_Declare_ID(optree);
83-
Py_Declare_ID(__main__); // __main__
84-
Py_Declare_ID(__module__); // type.__module__
85-
Py_Declare_ID(__qualname__); // type.__qualname__
86-
Py_Declare_ID(__name__); // type.__name__
87-
Py_Declare_ID(sort); // list.sort
88-
Py_Declare_ID(copy); // dict.copy
89-
Py_Declare_ID(OrderedDict); // OrderedDict
90-
Py_Declare_ID(defaultdict); // defaultdict
91-
Py_Declare_ID(deque); // deque
92-
Py_Declare_ID(default_factory); // defaultdict.default_factory
93-
Py_Declare_ID(maxlen); // deque.maxlen
94-
Py_Declare_ID(_fields); // namedtuple._fields
95-
Py_Declare_ID(_make); // namedtuple._make
96-
Py_Declare_ID(_asdict); // namedtuple._asdict
97-
Py_Declare_ID(n_fields); // structseq.n_fields
98-
Py_Declare_ID(n_sequence_fields); // structseq.n_sequence_fields
99-
Py_Declare_ID(n_unnamed_fields); // structseq.n_unnamed_fields

include/optree/pytypes.h

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -218,8 +218,7 @@ inline bool IsNamedTupleClassImpl(const py::handle &type) {
218218
// We can only identify namedtuples heuristically, here by the presence of a _fields attribute.
219219
if (PyType_FastSubclass(reinterpret_cast<PyTypeObject *>(type.ptr()),
220220
Py_TPFLAGS_TUPLE_SUBCLASS)) [[unlikely]] {
221-
if (PyObject * const _fields = PyObject_GetAttr(type.ptr(), Py_Get_ID(_fields)))
222-
[[unlikely]] {
221+
if (PyObject * const _fields = PyObject_GetAttrString(type.ptr(), "_fields")) [[unlikely]] {
223222
bool fields_ok = static_cast<bool>(PyTuple_CheckExact(_fields));
224223
if (fields_ok) [[likely]] {
225224
for (const auto &field : py::reinterpret_borrow<py::tuple>(_fields)) {
@@ -232,8 +231,9 @@ inline bool IsNamedTupleClassImpl(const py::handle &type) {
232231
Py_DECREF(_fields);
233232
if (fields_ok) [[likely]] {
234233
// NOLINTNEXTLINE[readability-use-anyofallof]
235-
for (PyObject * const name : {Py_Get_ID(_make), Py_Get_ID(_asdict)}) {
236-
if (PyObject * const attr = PyObject_GetAttr(type.ptr(), name)) [[likely]] {
234+
for (const char * const name : {"_make", "_asdict"}) {
235+
if (PyObject * const attr = PyObject_GetAttrString(type.ptr(), name))
236+
[[likely]] {
237237
const bool result = static_cast<bool>(PyCallable_Check(attr));
238238
Py_DECREF(attr);
239239
if (!result) [[unlikely]] {
@@ -311,7 +311,7 @@ inline py::tuple NamedTupleGetFields(const py::handle &object) {
311311
PyRepr(object) + ".");
312312
}
313313
}
314-
return EVALUATE_WITH_LOCK_HELD(py::getattr(type, Py_Get_ID(_fields)), type);
314+
return EVALUATE_WITH_LOCK_HELD(py::getattr(type, "_fields"), type);
315315
}
316316

317317
inline bool IsStructSequenceClassImpl(const py::handle &type) {
@@ -325,9 +325,8 @@ inline bool IsStructSequenceClassImpl(const py::handle &type) {
325325
PyTuple_GET_ITEM(type_object->tp_bases, 0) == reinterpret_cast<PyObject *>(&PyTuple_Type))
326326
[[unlikely]] {
327327
// NOLINTNEXTLINE[readability-use-anyofallof]
328-
for (PyObject * const name :
329-
{Py_Get_ID(n_fields), Py_Get_ID(n_sequence_fields), Py_Get_ID(n_unnamed_fields)}) {
330-
if (PyObject * const attr = PyObject_GetAttr(type.ptr(), name)) [[unlikely]] {
328+
for (const char * const name : {"n_fields", "n_sequence_fields", "n_unnamed_fields"}) {
329+
if (PyObject * const attr = PyObject_GetAttrString(type.ptr(), name)) [[unlikely]] {
331330
const bool result = static_cast<bool>(PyLong_CheckExact(attr));
332331
Py_DECREF(attr);
333332
if (!result) [[unlikely]] {
@@ -418,7 +417,7 @@ inline py::tuple StructSequenceGetFieldsImpl(const py::handle &type) {
418417
return py::tuple{fields};
419418
#else
420419
const auto n_sequence_fields = thread_safe_cast<py::ssize_t>(
421-
EVALUATE_WITH_LOCK_HELD(py::getattr(type, Py_Get_ID(n_sequence_fields)), type));
420+
EVALUATE_WITH_LOCK_HELD(py::getattr(type, "n_sequence_fields"), type));
422421
const auto * const members = reinterpret_cast<PyTypeObject *>(type.ptr())->tp_members;
423422
py::tuple fields{n_sequence_fields};
424423
for (py::ssize_t i = 0; i < n_sequence_fields; ++i) {
@@ -489,15 +488,15 @@ inline void TotalOrderSort(py::list &list) { // NOLINT[runtime/references]
489488
// Sort with `(f'{obj.__class__.__module__}.{obj.__class__.__qualname__}', obj)`
490489
const auto sort_key_fn = py::cpp_function([](const py::object &obj) -> py::tuple {
491490
const py::handle cls = py::type::handle_of(obj);
492-
const py::str qualname{EVALUATE_WITH_LOCK_HELD(
493-
PyStr(py::getattr(cls, Py_Get_ID(__module__))) + "." +
494-
PyStr(py::getattr(cls, Py_Get_ID(__qualname__))),
495-
cls)};
491+
const py::str qualname{
492+
EVALUATE_WITH_LOCK_HELD(PyStr(py::getattr(cls, "__module__")) + "." +
493+
PyStr(py::getattr(cls, "__qualname__")),
494+
cls)};
496495
return py::make_tuple(qualname, obj);
497496
});
498497
{
499498
const scoped_critical_section cs{list};
500-
py::getattr(list, Py_Get_ID(sort))(py::arg("key") = sort_key_fn);
499+
py::getattr(list, "sort")(py::arg("key") = sort_key_fn);
501500
}
502501
} catch (py::error_already_set &ex2) {
503502
if (ex2.matches(PyExc_TypeError)) [[likely]] {

src/optree.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references]
274274
#endif
275275
auto * const PyTreeKind_Type = reinterpret_cast<PyTypeObject *>(PyTreeKindTypeObject.ptr());
276276
PyTreeKind_Type->tp_name = "optree.PyTreeKind";
277-
py::setattr(PyTreeKindTypeObject, Py_Get_ID(__module__), Py_Get_ID(optree));
277+
py::setattr(PyTreeKindTypeObject, "__module__", py::str("optree"));
278278
py::setattr(PyTreeKindTypeObject, "NUM_KINDS", py::int_(py::ssize_t(PyTreeKind::NumKinds)));
279279

280280
auto PyTreeSpecTypeObject =
@@ -298,7 +298,7 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references]
298298
py::module_local());
299299
auto * const PyTreeSpec_Type = reinterpret_cast<PyTypeObject *>(PyTreeSpecTypeObject.ptr());
300300
PyTreeSpec_Type->tp_name = "optree.PyTreeSpec";
301-
py::setattr(PyTreeSpecTypeObject, Py_Get_ID(__module__), Py_Get_ID(optree));
301+
py::setattr(PyTreeSpecTypeObject, "__module__", py::str("optree"));
302302

303303
PyTreeSpecTypeObject
304304
.def("unflatten",
@@ -496,7 +496,7 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references]
496496
py::module_local());
497497
auto * const PyTreeIter_Type = reinterpret_cast<PyTypeObject *>(PyTreeIterTypeObject.ptr());
498498
PyTreeIter_Type->tp_name = "optree.PyTreeIter";
499-
py::setattr(PyTreeIterTypeObject, Py_Get_ID(__module__), Py_Get_ID(optree));
499+
py::setattr(PyTreeIterTypeObject, "__module__", py::str("optree"));
500500

501501
PyTreeIterTypeObject
502502
.def(py::init<py::object, std::optional<py::function>, bool, std::string>(),

src/treespec/constructors.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ template <bool NoneIsLeaf>
169169
node.arity = DictGetSize(dict);
170170
keys = DictKeys(dict);
171171
if (node.kind != PyTreeKind::OrderedDict) [[likely]] {
172-
node.original_keys = py::getattr(keys, Py_Get_ID(copy))();
172+
node.original_keys = py::getattr(keys, "copy")();
173173
if (!IsDictInsertionOrdered(registry_namespace)) [[likely]] {
174174
TotalOrderSort(keys);
175175
}
@@ -181,8 +181,8 @@ template <bool NoneIsLeaf>
181181
verify_children(children, treespecs);
182182
if (node.kind == PyTreeKind::DefaultDict) [[unlikely]] {
183183
const scoped_critical_section cs{handle};
184-
node.node_data = py::make_tuple(py::getattr(handle, Py_Get_ID(default_factory)),
185-
std::move(keys));
184+
node.node_data =
185+
py::make_tuple(py::getattr(handle, "default_factory"), std::move(keys));
186186
} else [[likely]] {
187187
node.node_data = std::move(keys);
188188
}
@@ -204,8 +204,7 @@ template <bool NoneIsLeaf>
204204
case PyTreeKind::Deque: {
205205
const auto list = thread_safe_cast<py::list>(handle);
206206
node.arity = ListGetSize(list);
207-
node.node_data =
208-
EVALUATE_WITH_LOCK_HELD(py::getattr(handle, Py_Get_ID(maxlen)), handle);
207+
node.node_data = EVALUATE_WITH_LOCK_HELD(py::getattr(handle, "maxlen"), handle);
209208
for (ssize_t i = 0; i < node.arity; ++i) {
210209
children.emplace_back(ListGetItem(list, i));
211210
}

src/treespec/flatten.cpp

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ bool PyTreeSpec::FlattenIntoImpl(const py::handle &handle,
106106
node.arity = DictGetSize(dict);
107107
keys = DictKeys(dict);
108108
if (node.kind != PyTreeKind::OrderedDict) [[likely]] {
109-
node.original_keys = py::getattr(keys, Py_Get_ID(copy))();
109+
node.original_keys = py::getattr(keys, "copy")();
110110
if constexpr (DictShouldBeSorted) {
111111
TotalOrderSort(keys);
112112
}
@@ -117,8 +117,8 @@ bool PyTreeSpec::FlattenIntoImpl(const py::handle &handle,
117117
}
118118
if (node.kind == PyTreeKind::DefaultDict) [[unlikely]] {
119119
const scoped_critical_section cs{handle};
120-
node.node_data = py::make_tuple(py::getattr(handle, Py_Get_ID(default_factory)),
121-
std::move(keys));
120+
node.node_data =
121+
py::make_tuple(py::getattr(handle, "default_factory"), std::move(keys));
122122
} else [[likely]] {
123123
node.node_data = std::move(keys);
124124
}
@@ -139,8 +139,7 @@ bool PyTreeSpec::FlattenIntoImpl(const py::handle &handle,
139139
case PyTreeKind::Deque: {
140140
const auto list = thread_safe_cast<py::list>(handle);
141141
node.arity = ListGetSize(list);
142-
node.node_data =
143-
EVALUATE_WITH_LOCK_HELD(py::getattr(handle, Py_Get_ID(maxlen)), handle);
142+
node.node_data = EVALUATE_WITH_LOCK_HELD(py::getattr(handle, "maxlen"), handle);
144143
for (ssize_t i = 0; i < node.arity; ++i) {
145144
recurse(ListGetItem(list, i));
146145
}
@@ -371,7 +370,7 @@ bool PyTreeSpec::FlattenIntoWithPathImpl(const py::handle &handle,
371370
node.arity = DictGetSize(dict);
372371
py::list keys = DictKeys(dict);
373372
if (node.kind != PyTreeKind::OrderedDict) [[likely]] {
374-
node.original_keys = py::getattr(keys, Py_Get_ID(copy))();
373+
node.original_keys = py::getattr(keys, "copy")();
375374
if constexpr (DictShouldBeSorted) {
376375
TotalOrderSort(keys);
377376
}
@@ -380,8 +379,8 @@ bool PyTreeSpec::FlattenIntoWithPathImpl(const py::handle &handle,
380379
recurse(DictGetItem(dict, key), key);
381380
}
382381
if (node.kind == PyTreeKind::DefaultDict) [[unlikely]] {
383-
node.node_data = py::make_tuple(py::getattr(handle, Py_Get_ID(default_factory)),
384-
std::move(keys));
382+
node.node_data =
383+
py::make_tuple(py::getattr(handle, "default_factory"), std::move(keys));
385384
} else [[likely]] {
386385
node.node_data = std::move(keys);
387386
}
@@ -402,8 +401,7 @@ bool PyTreeSpec::FlattenIntoWithPathImpl(const py::handle &handle,
402401
case PyTreeKind::Deque: {
403402
const auto list = thread_safe_cast<py::list>(handle);
404403
node.arity = ListGetSize(list);
405-
node.node_data =
406-
EVALUATE_WITH_LOCK_HELD(py::getattr(handle, Py_Get_ID(maxlen)), handle);
404+
node.node_data = EVALUATE_WITH_LOCK_HELD(py::getattr(handle, "maxlen"), handle);
407405
for (ssize_t i = 0; i < node.arity; ++i) {
408406
recurse(ListGetItem(list, i), py::int_(i));
409407
}

src/treespec/serialization.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ std::string PyTreeSpec::ToStringImpl() const {
141141
node.arity,
142142
"Number of fields and entries does not match.");
143143
const std::string kind =
144-
PyStr(EVALUATE_WITH_LOCK_HELD(py::getattr(type, Py_Get_ID(__name__)), type));
144+
PyStr(EVALUATE_WITH_LOCK_HELD(py::getattr(type, "__name__"), type));
145145
sstream << kind << "(";
146146
bool first = true;
147147
auto child_it = agenda.cend() - node.arity;
@@ -195,9 +195,8 @@ std::string PyTreeSpec::ToStringImpl() const {
195195
EXPECT_EQ(TupleGetSize(fields),
196196
node.arity,
197197
"Number of fields and entries does not match.");
198-
const py::object module_name = EVALUATE_WITH_LOCK_HELD(
199-
py::getattr(type, Py_Get_ID(__module__), Py_Get_ID(__main__)),
200-
type);
198+
const py::object module_name =
199+
EVALUATE_WITH_LOCK_HELD(py::getattr(type, "__module__", py::none()), type);
201200
if (!module_name.is_none()) [[likely]] {
202201
const std::string name = PyStr(module_name);
203202
if (!(name.empty() || name == "__main__" || name == "builtins" ||
@@ -206,7 +205,7 @@ std::string PyTreeSpec::ToStringImpl() const {
206205
}
207206
}
208207
const py::object qualname =
209-
EVALUATE_WITH_LOCK_HELD(py::getattr(type, Py_Get_ID(__qualname__)), type);
208+
EVALUATE_WITH_LOCK_HELD(py::getattr(type, "__qualname__"), type);
210209
sstream << PyStr(qualname) << "(";
211210
bool first = true;
212211
auto child_it = agenda.cend() - node.arity;
@@ -223,9 +222,9 @@ std::string PyTreeSpec::ToStringImpl() const {
223222
}
224223

225224
case PyTreeKind::Custom: {
226-
const std::string kind = PyStr(
227-
EVALUATE_WITH_LOCK_HELD(py::getattr(node.custom->type, Py_Get_ID(__name__)),
228-
node.custom->type));
225+
const std::string kind =
226+
PyStr(EVALUATE_WITH_LOCK_HELD(py::getattr(node.custom->type, "__name__"),
227+
node.custom->type));
229228
sstream << "CustomTreeNode(" << kind << "[";
230229
if (node.node_data) [[likely]] {
231230
sstream << PyRepr(node.node_data);

src/treespec/treespec.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -856,11 +856,11 @@ py::list PyTreeSpec::Entries() const {
856856
case PyTreeKind::Dict:
857857
case PyTreeKind::OrderedDict: {
858858
const scoped_critical_section cs{root.node_data};
859-
return py::getattr(root.node_data, Py_Get_ID(copy))();
859+
return py::getattr(root.node_data, "copy")();
860860
}
861861
case PyTreeKind::DefaultDict: {
862862
const scoped_critical_section cs{root.node_data};
863-
return py::getattr(TupleGetItem(root.node_data, 1), Py_Get_ID(copy))();
863+
return py::getattr(TupleGetItem(root.node_data, 1), "copy")();
864864
}
865865

866866
case PyTreeKind::NumKinds:

0 commit comments

Comments
 (0)