Skip to content

Commit f560500

Browse files
committed
Add stateful action bindings.
1 parent 9fdcda1 commit f560500

File tree

2 files changed

+136
-2
lines changed

2 files changed

+136
-2
lines changed
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
#!/usr/bin/env python3
2+
3+
"""
4+
Demonstration of stateful action nodes.
5+
6+
To run, ensure that the `btpy_cpp` Python extension is on your `PYTHONPATH`
7+
variable. It is probably located in your build directory if you're building from
8+
source.
9+
"""
10+
11+
import numpy as np
12+
from btpy import (
13+
BehaviorTreeFactory,
14+
StatefulActionNode,
15+
SyncActionNode,
16+
NodeStatus,
17+
ports,
18+
)
19+
20+
21+
xml_text = """
22+
<root BTCPP_format="4" >
23+
24+
<BehaviorTree ID="MainTree">
25+
<Sequence>
26+
<!-- Initialize the interpolated position -->
27+
<SetBlackboard output_key="interpolated" value="None" />
28+
29+
<!-- Interpolate from the initial position to the final one printing
30+
at each step. -->
31+
<ReactiveSequence name="root">
32+
<Print value="{interpolated}" />
33+
<Interpolate x0="[1.0, 0.0]" x1="[0.0, 1.0]" out="{interpolated}" />
34+
</ReactiveSequence>
35+
</Sequence>
36+
</BehaviorTree>
37+
38+
</root>
39+
"""
40+
41+
42+
@ports(inputs=["x0", "x1"], outputs=["out"])
43+
class Interpolate(StatefulActionNode):
44+
def on_start(self):
45+
self.t = 0.0
46+
self.x0 = np.asarray(self.get_input("x0"))
47+
self.x1 = np.asarray(self.get_input("x1"))
48+
return NodeStatus.RUNNING
49+
50+
def on_running(self):
51+
if self.t < 1.0:
52+
x = (1.0 - self.t) * self.x0 + self.t * self.x1
53+
self.set_output("out", x)
54+
self.t += 0.1
55+
return NodeStatus.RUNNING
56+
else:
57+
return NodeStatus.SUCCESS
58+
59+
def on_halted(self):
60+
pass
61+
62+
63+
@ports(inputs=["value"])
64+
class Print(SyncActionNode):
65+
def tick(self):
66+
print(self.get_input("value"))
67+
return NodeStatus.SUCCESS
68+
69+
70+
factory = BehaviorTreeFactory()
71+
factory.register(Interpolate)
72+
factory.register(Print)
73+
74+
tree = factory.create_tree_from_text(xml_text)
75+
tree.tick_while_running()

src/python_bindings.cpp

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include <memory>
2+
#include <stdexcept>
23

34
#include <pybind11/pybind11.h>
45
#include <pybind11/gil.h>
@@ -43,6 +44,46 @@ class Py_SyncActionNode : public SyncActionNode
4344
}
4445
};
4546

47+
class Py_StatefulActionNode final : public StatefulActionNode
48+
{
49+
public:
50+
Py_StatefulActionNode(const std::string& name, const NodeConfig& config) :
51+
StatefulActionNode(name, config)
52+
{}
53+
54+
NodeStatus onStart() override
55+
{
56+
py::gil_scoped_acquire gil;
57+
return py::get_overload(this, "on_start")().cast<NodeStatus>();
58+
}
59+
60+
NodeStatus onRunning() override
61+
{
62+
py::gil_scoped_acquire gil;
63+
return py::get_overload(this, "on_running")().cast<NodeStatus>();
64+
}
65+
66+
void onHalted() override
67+
{
68+
py::gil_scoped_acquire gil;
69+
py::get_overload(this, "on_halted")();
70+
}
71+
72+
// TODO: Share these duplicated methods with other node types.
73+
py::object Py_getInput(const std::string& name)
74+
{
75+
py::object obj;
76+
getInput(name, obj);
77+
return obj;
78+
}
79+
80+
// TODO: Share these duplicated methods with other node types.
81+
void Py_setOutput(const std::string& name, const py::object& value)
82+
{
83+
setOutput(name, value);
84+
}
85+
};
86+
4687
// Add a conversion specialization from string values into general py::objects
4788
// by evaluating as a Python expression.
4889
template <>
@@ -112,8 +153,18 @@ PYBIND11_MODULE(btpy_cpp, m)
112153
// is destroyed, then the object will live forever.
113154
obj.inc_ref();
114155

115-
return std::unique_ptr<Py_SyncActionNode>(
116-
obj.cast<Py_SyncActionNode*>());
156+
if (py::isinstance<Py_SyncActionNode>(obj))
157+
{
158+
return std::unique_ptr<TreeNode>(obj.cast<Py_SyncActionNode*>());
159+
}
160+
else if (py::isinstance<Py_StatefulActionNode>(obj))
161+
{
162+
return std::unique_ptr<TreeNode>(obj.cast<Py_StatefulActionNode*>());
163+
}
164+
else
165+
{
166+
throw std::runtime_error("invalid node type of " + name);
167+
}
117168
});
118169
})
119170
.def("create_tree_from_text",
@@ -142,6 +193,14 @@ PYBIND11_MODULE(btpy_cpp, m)
142193
.def("tick", &Py_SyncActionNode::tick)
143194
.def("get_input", &Py_SyncActionNode::Py_getInput)
144195
.def("set_output", &Py_SyncActionNode::Py_setOutput);
196+
197+
py::class_<Py_StatefulActionNode>(m, "StatefulActionNode")
198+
.def(py::init<const std::string&, const NodeConfig&>())
199+
.def("on_start", &Py_StatefulActionNode::onStart)
200+
.def("on_running", &Py_StatefulActionNode::onRunning)
201+
.def("on_halted", &Py_StatefulActionNode::onHalted)
202+
.def("get_input", &Py_StatefulActionNode::Py_getInput)
203+
.def("set_output", &Py_StatefulActionNode::Py_setOutput);
145204
}
146205

147206
} // namespace BT

0 commit comments

Comments
 (0)