Skip to content

Commit 17da541

Browse files
committed
Clean up port handling.
1 parent 4ff5673 commit 17da541

File tree

1 file changed

+58
-60
lines changed

1 file changed

+58
-60
lines changed

src/python_bindings.cpp

Lines changed: 58 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -55,16 +55,14 @@ class Py_StatefulActionNode final : public StatefulActionNode
5555
}
5656
};
5757

58-
template <typename T>
59-
py::object Py_getInput(const T& node, const std::string& name)
58+
py::object Py_getInput(const TreeNode& node, const std::string& name)
6059
{
6160
py::object obj;
6261
node.getInput(name, obj);
6362
return obj;
6463
}
6564

66-
template <typename T>
67-
void Py_setOutput(T& node, const std::string& name, const py::object& value)
65+
void Py_setOutput(TreeNode& node, const std::string& name, const py::object& value)
6866
{
6967
node.setOutput(name, value);
7068
}
@@ -88,21 +86,49 @@ inline py::object convertFromString(StringView str)
8886
}
8987
}
9088

91-
PYBIND11_MODULE(btpy_cpp, m)
89+
PortsList extractPortsList(const py::type& type)
9290
{
93-
py::class_<PortInfo>(m, "PortInfo");
94-
m.def("input_port",
95-
[](const std::string& name) { return InputPort<py::object>(name); });
96-
m.def("output_port",
97-
[](const std::string& name) { return OutputPort<py::object>(name); });
98-
99-
m.def(
100-
"ports2",
101-
[](const py::list& inputs, const py::list& outputs) -> auto {
102-
return [](py::type type) -> auto { return type; };
103-
},
104-
py::kw_only(), py::arg("inputs") = py::none(), py::arg("outputs") = py::none());
91+
PortsList ports;
92+
93+
const auto input_ports = type.attr("input_ports").cast<py::list>();
94+
for (const auto& name : input_ports)
95+
{
96+
ports.insert(InputPort<py::object>(name.cast<std::string>()));
97+
}
10598

99+
const auto output_ports = type.attr("output_ports").cast<py::list>();
100+
for (const auto& name : output_ports)
101+
{
102+
ports.insert(OutputPort<py::object>(name.cast<std::string>()));
103+
}
104+
105+
return ports;
106+
}
107+
108+
NodeBuilder makeTreeNodeBuilderFn(const py::type& type)
109+
{
110+
return [type](const auto& name, const auto& config) -> auto {
111+
py::object obj = type(name, config);
112+
113+
// TODO: Increment the object's reference count or else it
114+
// will be GC'd at the end of this scope. The downside is
115+
// that, unless we can decrement the ref when the unique_ptr
116+
// is destroyed, then the object will live forever.
117+
obj.inc_ref();
118+
119+
if (py::isinstance<ActionNodeBase>(obj))
120+
{
121+
return std::unique_ptr<TreeNode>(obj.cast<ActionNodeBase*>());
122+
}
123+
else
124+
{
125+
throw std::runtime_error("invalid node type of " + name);
126+
}
127+
};
128+
}
129+
130+
PYBIND11_MODULE(btpy_cpp, m)
131+
{
106132
py::class_<BehaviorTreeFactory>(m, "BehaviorTreeFactory")
107133
.def(py::init())
108134
.def("register",
@@ -112,45 +138,10 @@ PYBIND11_MODULE(btpy_cpp, m)
112138
TreeNodeManifest manifest;
113139
manifest.type = NodeType::ACTION;
114140
manifest.registration_ID = name;
115-
manifest.ports = {};
141+
manifest.ports = extractPortsList(type);
116142
manifest.description = "";
117143

118-
const auto input_ports = type.attr("input_ports").cast<py::list>();
119-
for (const auto& name : input_ports)
120-
{
121-
manifest.ports.insert(InputPort<py::object>(name.cast<std::string>()));
122-
}
123-
124-
const auto output_ports = type.attr("output_ports").cast<py::list>();
125-
for (const auto& name : output_ports)
126-
{
127-
manifest.ports.insert(OutputPort<py::object>(name.cast<std::string>()));
128-
}
129-
130-
factory.registerBuilder(
131-
manifest,
132-
[type](const std::string& name,
133-
const NodeConfig& config) -> std::unique_ptr<TreeNode> {
134-
py::object obj = type(name, config);
135-
// TODO: Increment the object's reference count or else it
136-
// will be GC'd at the end of this scope. The downside is
137-
// that, unless we can decrement the ref when the unique_ptr
138-
// is destroyed, then the object will live forever.
139-
obj.inc_ref();
140-
141-
if (py::isinstance<Py_SyncActionNode>(obj))
142-
{
143-
return std::unique_ptr<TreeNode>(obj.cast<Py_SyncActionNode*>());
144-
}
145-
else if (py::isinstance<Py_StatefulActionNode>(obj))
146-
{
147-
return std::unique_ptr<TreeNode>(obj.cast<Py_StatefulActionNode*>());
148-
}
149-
else
150-
{
151-
throw std::runtime_error("invalid node type of " + name);
152-
}
153-
});
144+
factory.registerBuilder(manifest, makeTreeNodeBuilderFn(type));
154145
})
155146
.def("create_tree_from_text",
156147
[](BehaviorTreeFactory& factory, const std::string& text) -> Tree {
@@ -173,16 +164,23 @@ PYBIND11_MODULE(btpy_cpp, m)
173164

174165
py::class_<NodeConfig>(m, "NodeConfig");
175166

176-
py::class_<Py_SyncActionNode>(m, "SyncActionNode")
167+
// Register the C++ type hierarchy so that we can refer to Python subclasses
168+
// by their superclass ptr types in generic C++ code.
169+
py::class_<TreeNode>(m, "_TreeNode");
170+
py::class_<ActionNodeBase, TreeNode>(m, "_ActionNodeBase");
171+
py::class_<SyncActionNode, ActionNodeBase>(m, "_SyncActionNode");
172+
py::class_<StatefulActionNode, ActionNodeBase>(m, "_StatefulActionNode");
173+
174+
py::class_<Py_SyncActionNode, SyncActionNode>(m, "SyncActionNode")
177175
.def(py::init<const std::string&, const NodeConfig&>())
178-
.def("get_input", &Py_getInput<Py_SyncActionNode>)
179-
.def("set_output", &Py_setOutput<Py_SyncActionNode>)
176+
.def("get_input", &Py_getInput)
177+
.def("set_output", &Py_setOutput)
180178
.def("tick", &Py_SyncActionNode::tick);
181179

182-
py::class_<Py_StatefulActionNode>(m, "StatefulActionNode")
180+
py::class_<Py_StatefulActionNode, StatefulActionNode>(m, "StatefulActionNode")
183181
.def(py::init<const std::string&, const NodeConfig&>())
184-
.def("get_input", &Py_getInput<Py_StatefulActionNode>)
185-
.def("set_output", &Py_setOutput<Py_StatefulActionNode>)
182+
.def("get_input", &Py_getInput)
183+
.def("set_output", &Py_setOutput)
186184
.def("on_start", &Py_StatefulActionNode::onStart)
187185
.def("on_running", &Py_StatefulActionNode::onRunning)
188186
.def("on_halted", &Py_StatefulActionNode::onHalted);

0 commit comments

Comments
 (0)