Skip to content

Commit 7a3421c

Browse files
committed
Fixed incompatible naming for stl map bindings
1 parent 6f614ac commit 7a3421c

File tree

2 files changed

+138
-6
lines changed

2 files changed

+138
-6
lines changed

generate_stubs.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,6 @@
5050
overloads = overloads + f'\\1@overload\\1def create(self, arg0: typing.Type[node.{node}]) -> node.{node}: ...'
5151
final_stubs = re.sub(r"([\s]*)def create\(self, arg0: object\) -> Node: ...", f'{overloads}', stubs_import)
5252

53-
# Modify "*View" naming
54-
nodes = re.findall('View\\[(\S*)\\]', final_stubs)
55-
final_stubs = re.sub(r"View\[(\S*)\]", f'View_\\1', final_stubs)
56-
5753
# Writeout changes
5854
file.seek(0)
5955
file.write(final_stubs)

src/DeviceBindings.cpp

Lines changed: 138 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,142 @@
1818
PYBIND11_MAKE_OPAQUE(std::unordered_map<std::int8_t, dai::BoardConfig::GPIO>);
1919
PYBIND11_MAKE_OPAQUE(std::unordered_map<std::int8_t, dai::BoardConfig::UART>);
2020

21+
// Patch for bind_map naming
22+
// Remove if it gets mainlined in pybind11
23+
namespace pybind11 {
24+
25+
template <typename Map, typename holder_type = std::unique_ptr<Map>, typename... Args>
26+
class_<Map, holder_type> bind_map_patched(handle scope, const std::string &name, Args &&...args) {
27+
using KeyType = typename Map::key_type;
28+
using MappedType = typename Map::mapped_type;
29+
using KeysView = detail::keys_view<Map>;
30+
using ValuesView = detail::values_view<Map>;
31+
using ItemsView = detail::items_view<Map>;
32+
using Class_ = class_<Map, holder_type>;
33+
34+
// If either type is a non-module-local bound type then make the map binding non-local as well;
35+
// otherwise (e.g. both types are either module-local or converting) the map will be
36+
// module-local.
37+
auto *tinfo = detail::get_type_info(typeid(MappedType));
38+
bool local = !tinfo || tinfo->module_local;
39+
if (local) {
40+
tinfo = detail::get_type_info(typeid(KeyType));
41+
local = !tinfo || tinfo->module_local;
42+
}
43+
44+
Class_ cl(scope, name.c_str(), pybind11::module_local(local), std::forward<Args>(args)...);
45+
class_<KeysView> keys_view(
46+
scope, ("KeysView_" + name).c_str(), pybind11::module_local(local));
47+
class_<ValuesView> values_view(
48+
scope, ("ValuesView_" + name).c_str(), pybind11::module_local(local));
49+
class_<ItemsView> items_view(
50+
scope, ("ItemsView_" + name).c_str(), pybind11::module_local(local));
51+
52+
cl.def(init<>());
53+
54+
// Register stream insertion operator (if possible)
55+
detail::map_if_insertion_operator<Map, Class_>(cl, name);
56+
57+
cl.def(
58+
"__bool__",
59+
[](const Map &m) -> bool { return !m.empty(); },
60+
"Check whether the map is nonempty");
61+
62+
cl.def(
63+
"__iter__",
64+
[](Map &m) { return make_key_iterator(m.begin(), m.end()); },
65+
keep_alive<0, 1>() /* Essential: keep map alive while iterator exists */
66+
);
67+
68+
cl.def(
69+
"keys",
70+
[](Map &m) { return KeysView{m}; },
71+
keep_alive<0, 1>() /* Essential: keep map alive while view exists */
72+
);
73+
74+
cl.def(
75+
"values",
76+
[](Map &m) { return ValuesView{m}; },
77+
keep_alive<0, 1>() /* Essential: keep map alive while view exists */
78+
);
79+
80+
cl.def(
81+
"items",
82+
[](Map &m) { return ItemsView{m}; },
83+
keep_alive<0, 1>() /* Essential: keep map alive while view exists */
84+
);
85+
86+
cl.def(
87+
"__getitem__",
88+
[](Map &m, const KeyType &k) -> MappedType & {
89+
auto it = m.find(k);
90+
if (it == m.end()) {
91+
throw key_error();
92+
}
93+
return it->second;
94+
},
95+
return_value_policy::reference_internal // ref + keepalive
96+
);
97+
98+
cl.def("__contains__", [](Map &m, const KeyType &k) -> bool {
99+
auto it = m.find(k);
100+
if (it == m.end()) {
101+
return false;
102+
}
103+
return true;
104+
});
105+
// Fallback for when the object is not of the key type
106+
cl.def("__contains__", [](Map &, const object &) -> bool { return false; });
107+
108+
// Assignment provided only if the type is copyable
109+
detail::map_assignment<Map, Class_>(cl);
110+
111+
cl.def("__delitem__", [](Map &m, const KeyType &k) {
112+
auto it = m.find(k);
113+
if (it == m.end()) {
114+
throw key_error();
115+
}
116+
m.erase(it);
117+
});
118+
119+
cl.def("__len__", &Map::size);
120+
121+
keys_view.def("__len__", [](KeysView &view) { return view.map.size(); });
122+
keys_view.def(
123+
"__iter__",
124+
[](KeysView &view) { return make_key_iterator(view.map.begin(), view.map.end()); },
125+
keep_alive<0, 1>() /* Essential: keep view alive while iterator exists */
126+
);
127+
keys_view.def("__contains__", [](KeysView &view, const KeyType &k) -> bool {
128+
auto it = view.map.find(k);
129+
if (it == view.map.end()) {
130+
return false;
131+
}
132+
return true;
133+
});
134+
// Fallback for when the object is not of the key type
135+
keys_view.def("__contains__", [](KeysView &, const object &) -> bool { return false; });
136+
137+
values_view.def("__len__", [](ValuesView &view) { return view.map.size(); });
138+
values_view.def(
139+
"__iter__",
140+
[](ValuesView &view) { return make_value_iterator(view.map.begin(), view.map.end()); },
141+
keep_alive<0, 1>() /* Essential: keep view alive while iterator exists */
142+
);
143+
144+
items_view.def("__len__", [](ItemsView &view) { return view.map.size(); });
145+
items_view.def(
146+
"__iter__",
147+
[](ItemsView &view) { return make_iterator(view.map.begin(), view.map.end()); },
148+
keep_alive<0, 1>() /* Essential: keep view alive while iterator exists */
149+
);
150+
151+
return cl;
152+
}
153+
154+
} // namespace pybind11
155+
156+
21157
// Searches for available devices (as Device constructor)
22158
// but pooling, to check for python interrupts, and releases GIL in between
23159

@@ -203,8 +339,8 @@ void DeviceBindings::bind(pybind11::module& m, void* pCallstack){
203339
py::class_<PyClock> clock(m, "Clock");
204340

205341

206-
py::bind_map<std::unordered_map<std::int8_t, dai::BoardConfig::GPIO>>(boardConfig, "GPIOMap");
207-
py::bind_map<std::unordered_map<std::int8_t, dai::BoardConfig::UART>>(boardConfig, "UARTMap");
342+
py::bind_map_patched<std::unordered_map<std::int8_t, dai::BoardConfig::GPIO>>(boardConfig, "GPIOMap");
343+
py::bind_map_patched<std::unordered_map<std::int8_t, dai::BoardConfig::UART>>(boardConfig, "UARTMap");
208344

209345

210346
///////////////////////////////////////////////////////////////////////

0 commit comments

Comments
 (0)