@@ -55,16 +55,14 @@ class Py_StatefulActionNode final : public StatefulActionNode
55
55
}
56
56
};
57
57
58
- template <typename T>
59
- py::object Py_getInput (const T& node, const std::string& name)
58
+ py::object Py_getInput (const TreeNode& node, const std::string& name)
60
59
{
61
60
py::object obj;
62
61
node.getInput (name, obj);
63
62
return obj;
64
63
}
65
64
66
- template <typename T>
67
- void Py_setOutput (T& node, const std::string& name, const py::object& value)
65
+ void Py_setOutput (TreeNode& node, const std::string& name, const py::object& value)
68
66
{
69
67
node.setOutput (name, value);
70
68
}
@@ -88,21 +86,49 @@ inline py::object convertFromString(StringView str)
88
86
}
89
87
}
90
88
91
- PYBIND11_MODULE (btpy_cpp, m )
89
+ PortsList extractPortsList ( const py::type& type )
92
90
{
93
- py::class_<PortInfo>(m, " PortInfo" );
94
- m.def (" input_port" ,
95
- [](const std::string& name) { return InputPort<py::object>(name); });
96
- m.def (" output_port" ,
97
- [](const std::string& name) { return OutputPort<py::object>(name); });
98
-
99
- m.def (
100
- " ports2" ,
101
- [](const py::list& inputs, const py::list& outputs) -> auto {
102
- return [](py::type type) -> auto { return type; };
103
- },
104
- py::kw_only (), py::arg (" inputs" ) = py::none (), py::arg (" outputs" ) = py::none ());
91
+ PortsList ports;
92
+
93
+ const auto input_ports = type.attr (" input_ports" ).cast <py::list>();
94
+ for (const auto & name : input_ports)
95
+ {
96
+ ports.insert (InputPort<py::object>(name.cast <std::string>()));
97
+ }
105
98
99
+ const auto output_ports = type.attr (" output_ports" ).cast <py::list>();
100
+ for (const auto & name : output_ports)
101
+ {
102
+ ports.insert (OutputPort<py::object>(name.cast <std::string>()));
103
+ }
104
+
105
+ return ports;
106
+ }
107
+
108
+ NodeBuilder makeTreeNodeBuilderFn (const py::type& type)
109
+ {
110
+ return [type](const auto & name, const auto & config) -> auto {
111
+ py::object obj = type (name, config);
112
+
113
+ // TODO: Increment the object's reference count or else it
114
+ // will be GC'd at the end of this scope. The downside is
115
+ // that, unless we can decrement the ref when the unique_ptr
116
+ // is destroyed, then the object will live forever.
117
+ obj.inc_ref ();
118
+
119
+ if (py::isinstance<ActionNodeBase>(obj))
120
+ {
121
+ return std::unique_ptr<TreeNode>(obj.cast <ActionNodeBase*>());
122
+ }
123
+ else
124
+ {
125
+ throw std::runtime_error (" invalid node type of " + name);
126
+ }
127
+ };
128
+ }
129
+
130
+ PYBIND11_MODULE (btpy_cpp, m)
131
+ {
106
132
py::class_<BehaviorTreeFactory>(m, " BehaviorTreeFactory" )
107
133
.def (py::init ())
108
134
.def (" register" ,
@@ -112,45 +138,10 @@ PYBIND11_MODULE(btpy_cpp, m)
112
138
TreeNodeManifest manifest;
113
139
manifest.type = NodeType::ACTION;
114
140
manifest.registration_ID = name;
115
- manifest.ports = {} ;
141
+ manifest.ports = extractPortsList (type) ;
116
142
manifest.description = " " ;
117
143
118
- const auto input_ports = type.attr (" input_ports" ).cast <py::list>();
119
- for (const auto & name : input_ports)
120
- {
121
- manifest.ports .insert (InputPort<py::object>(name.cast <std::string>()));
122
- }
123
-
124
- const auto output_ports = type.attr (" output_ports" ).cast <py::list>();
125
- for (const auto & name : output_ports)
126
- {
127
- manifest.ports .insert (OutputPort<py::object>(name.cast <std::string>()));
128
- }
129
-
130
- factory.registerBuilder (
131
- manifest,
132
- [type](const std::string& name,
133
- const NodeConfig& config) -> std::unique_ptr<TreeNode> {
134
- py::object obj = type (name, config);
135
- // TODO: Increment the object's reference count or else it
136
- // will be GC'd at the end of this scope. The downside is
137
- // that, unless we can decrement the ref when the unique_ptr
138
- // is destroyed, then the object will live forever.
139
- obj.inc_ref ();
140
-
141
- if (py::isinstance<Py_SyncActionNode>(obj))
142
- {
143
- return std::unique_ptr<TreeNode>(obj.cast <Py_SyncActionNode*>());
144
- }
145
- else if (py::isinstance<Py_StatefulActionNode>(obj))
146
- {
147
- return std::unique_ptr<TreeNode>(obj.cast <Py_StatefulActionNode*>());
148
- }
149
- else
150
- {
151
- throw std::runtime_error (" invalid node type of " + name);
152
- }
153
- });
144
+ factory.registerBuilder (manifest, makeTreeNodeBuilderFn (type));
154
145
})
155
146
.def (" create_tree_from_text" ,
156
147
[](BehaviorTreeFactory& factory, const std::string& text) -> Tree {
@@ -173,16 +164,23 @@ PYBIND11_MODULE(btpy_cpp, m)
173
164
174
165
py::class_<NodeConfig>(m, " NodeConfig" );
175
166
176
- py::class_<Py_SyncActionNode>(m, " SyncActionNode" )
167
+ // Register the C++ type hierarchy so that we can refer to Python subclasses
168
+ // by their superclass ptr types in generic C++ code.
169
+ py::class_<TreeNode>(m, " _TreeNode" );
170
+ py::class_<ActionNodeBase, TreeNode>(m, " _ActionNodeBase" );
171
+ py::class_<SyncActionNode, ActionNodeBase>(m, " _SyncActionNode" );
172
+ py::class_<StatefulActionNode, ActionNodeBase>(m, " _StatefulActionNode" );
173
+
174
+ py::class_<Py_SyncActionNode, SyncActionNode>(m, " SyncActionNode" )
177
175
.def (py::init<const std::string&, const NodeConfig&>())
178
- .def (" get_input" , &Py_getInput<Py_SyncActionNode> )
179
- .def (" set_output" , &Py_setOutput<Py_SyncActionNode> )
176
+ .def (" get_input" , &Py_getInput)
177
+ .def (" set_output" , &Py_setOutput)
180
178
.def (" tick" , &Py_SyncActionNode::tick);
181
179
182
- py::class_<Py_StatefulActionNode>(m, " StatefulActionNode" )
180
+ py::class_<Py_StatefulActionNode, StatefulActionNode >(m, " StatefulActionNode" )
183
181
.def (py::init<const std::string&, const NodeConfig&>())
184
- .def (" get_input" , &Py_getInput<Py_StatefulActionNode> )
185
- .def (" set_output" , &Py_setOutput<Py_StatefulActionNode> )
182
+ .def (" get_input" , &Py_getInput)
183
+ .def (" set_output" , &Py_setOutput)
186
184
.def (" on_start" , &Py_StatefulActionNode::onStart)
187
185
.def (" on_running" , &Py_StatefulActionNode::onRunning)
188
186
.def (" on_halted" , &Py_StatefulActionNode::onHalted);
0 commit comments