Skip to content

Commit b163653

Browse files
Add example SAM 2 behavior and objective (#23)
* Add SAM2 behavior and ONNX models --------- Signed-off-by: Paul Gesel <[email protected]> Co-authored-by: Griswald Brooks <[email protected]>
1 parent 3cb843a commit b163653

File tree

8 files changed

+271
-9
lines changed

8 files changed

+271
-9
lines changed

.gitattributes

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
*.onnx filter=lfs diff=lfs merge=lfs -text

src/example_behaviors/CMakeLists.txt

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,22 +15,24 @@ example_interfaces)
1515
foreach(package IN ITEMS ${THIS_PACKAGE_INCLUDE_DEPENDS})
1616
find_package(${package} REQUIRED)
1717
endforeach()
18+
find_package(moveit_pro_ml REQUIRED)
1819

1920
add_library(
2021
example_behaviors
2122
SHARED
2223
src/add_two_ints_service_client.cpp
2324
src/convert_mtc_solution_to_joint_trajectory.cpp
2425
src/delayed_message.cpp
25-
src/get_string_from_topic.cpp
2626
src/fibonacci_action_client.cpp
27+
src/get_string_from_topic.cpp
2728
src/hello_world.cpp
29+
src/ndt_registration.cpp
2830
src/publish_color_rgba.cpp
31+
src/ransac_registration.cpp
32+
src/sam2_segmentation.cpp
2933
src/setup_mtc_pick_from_pose.cpp
3034
src/setup_mtc_place_from_pose.cpp
3135
src/setup_mtc_wave_hand.cpp
32-
src/ndt_registration.cpp
33-
src/ransac_registration.cpp
3436
src/register_behaviors.cpp)
3537
target_include_directories(
3638
example_behaviors
@@ -39,6 +41,7 @@ target_include_directories(
3941
PRIVATE ${PCL_INCLUDE_DIRS})
4042
ament_target_dependencies(example_behaviors
4143
${THIS_PACKAGE_INCLUDE_DEPENDS})
44+
target_link_libraries(example_behaviors onnx_sam2)
4245

4346
# Install Libraries
4447
install(
@@ -50,7 +53,7 @@ install(
5053
INCLUDES
5154
DESTINATION include)
5255

53-
install(DIRECTORY config DESTINATION share/${PROJECT_NAME})
56+
install(DIRECTORY config models DESTINATION share/${PROJECT_NAME})
5457

5558
if(BUILD_TESTING)
5659
moveit_pro_behavior_test(example_behaviors)
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
#pragma once
2+
3+
#include <future>
4+
#include <memory>
5+
#include <string>
6+
7+
#include <moveit_pro_ml/onnx_sam2.hpp>
8+
#include <moveit_pro_ml/onnx_sam2_types.hpp>
9+
#include <moveit_studio_behavior_interface/async_behavior_base.hpp>
10+
#include <moveit_studio_vision_msgs/msg/mask2_d.hpp>
11+
#include <sensor_msgs/msg/image.hpp>
12+
#include <tl_expected/expected.hpp>
13+
14+
15+
namespace example_behaviors
16+
{
17+
/**
18+
* @brief Segment an image using the SAM 2 model
19+
*/
20+
class SAM2Segmentation : public moveit_studio::behaviors::AsyncBehaviorBase
21+
{
22+
public:
23+
/**
24+
* @brief Constructor for the SAM2Segmentation behavior.
25+
* @param name The name of a particular instance of this Behavior. This will be set by the behavior tree factory when this Behavior is created within a new behavior tree.
26+
* @param config This contains runtime configuration info for this Behavior, such as the mapping between the Behavior's data ports on the behavior tree's blackboard. This will be set by the behavior tree factory when this Behavior is created within a new behavior tree.
27+
* @details An important limitation is that the members of the base Behavior class are not instantiated until after the initialize() function is called, so these classes should not be used within the constructor.
28+
*/
29+
SAM2Segmentation(const std::string& name, const BT::NodeConfiguration& config,
30+
const std::shared_ptr<moveit_studio::behaviors::BehaviorContext>& shared_resources);
31+
32+
/**
33+
* @brief Implementation of the required providedPorts() function for the Behavior.
34+
* @details The BehaviorTree.CPP library requires that Behaviors must implement a static function named providedPorts() which defines their input and output ports. If the Behavior does not use any ports, this function must return an empty BT::PortsList.
35+
* This function returns a list of ports with their names and port info, which is used internally by the behavior tree.
36+
* @return List of ports for the behavior.
37+
*/
38+
static BT::PortsList providedPorts();
39+
40+
/**
41+
* @brief Implementation of the metadata() function for displaying metadata, such as Behavior description and
42+
* subcategory, in the MoveIt Studio Developer Tool.
43+
* @return A BT::KeyValueVector containing the Behavior metadata.
44+
*/
45+
static BT::KeyValueVector metadata();
46+
47+
protected:
48+
tl::expected<bool, std::string> doWork() override;
49+
50+
51+
private:
52+
std::unique_ptr<moveit_pro_ml::SAM2> sam2_;
53+
moveit_pro_ml::ONNXImage onnx_image_;
54+
sensor_msgs::msg::Image mask_image_msg_;
55+
moveit_studio_vision_msgs::msg::Mask2D mask_msg_;
56+
57+
/** @brief Classes derived from AsyncBehaviorBase must implement getFuture() so that it returns a shared_future class member */
58+
std::shared_future<tl::expected<bool, std::string>>& getFuture() override
59+
{
60+
return future_;
61+
}
62+
63+
/** @brief Classes derived from AsyncBehaviorBase must have this shared_future as a class member */
64+
std::shared_future<tl::expected<bool, std::string>> future_;
65+
66+
};
67+
} // namespace sam2_segmentation
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:1f448cdb479e6ec14e61c4756138eb4081ce7f8a11ca43a0a24856d5e8b61b6f
3+
size 20665365
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:c99ab89a38385753aff7ea9155f0808ad5535bc55ea2a49320254e39e4011630
3+
size 889364590

src/example_behaviors/src/register_behaviors.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,19 @@
22
#include <moveit_studio_behavior_interface/behavior_context.hpp>
33
#include <moveit_studio_behavior_interface/shared_resources_node_loader.hpp>
44

5-
#include <example_behaviors/hello_world.hpp>
5+
#include <example_behaviors/add_two_ints_service_client.hpp>
66
#include <example_behaviors/convert_mtc_solution_to_joint_trajectory.hpp>
77
#include <example_behaviors/delayed_message.hpp>
8-
#include <example_behaviors/setup_mtc_wave_hand.hpp>
9-
#include <example_behaviors/add_two_ints_service_client.hpp>
108
#include <example_behaviors/fibonacci_action_client.hpp>
119
#include <example_behaviors/get_string_from_topic.hpp>
10+
#include <example_behaviors/hello_world.hpp>
11+
#include <example_behaviors/ndt_registration.hpp>
1212
#include <example_behaviors/publish_color_rgba.hpp>
13+
#include <example_behaviors/ransac_registration.hpp>
14+
#include <example_behaviors/sam2_segmentation.hpp>
1315
#include <example_behaviors/setup_mtc_pick_from_pose.hpp>
1416
#include <example_behaviors/setup_mtc_place_from_pose.hpp>
15-
#include <example_behaviors/ndt_registration.hpp>
16-
#include <example_behaviors/ransac_registration.hpp>
17+
#include <example_behaviors/setup_mtc_wave_hand.hpp>
1718

1819
#include <pluginlib/class_list_macros.hpp>
1920

@@ -35,6 +36,7 @@ class ExampleBehaviorsLoader : public moveit_studio::behaviors::SharedResourcesN
3536
moveit_studio::behaviors::registerBehavior<FibonacciActionClient>(factory, "FibonacciActionClient",
3637
shared_resources);
3738
moveit_studio::behaviors::registerBehavior<PublishColorRGBA>(factory, "PublishColorRGBA", shared_resources);
39+
moveit_studio::behaviors::registerBehavior<SAM2Segmentation>(factory, "SAM2Segmentation", shared_resources);
3840
moveit_studio::behaviors::registerBehavior<SetupMtcPickFromPose>(factory, "SetupMtcPickFromPose", shared_resources);
3941
moveit_studio::behaviors::registerBehavior<SetupMtcPlaceFromPose>(factory, "SetupMtcPlaceFromPose",
4042
shared_resources);
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
#include <future>
2+
#include <memory>
3+
#include <string>
4+
5+
#include <ament_index_cpp/get_package_share_directory.hpp>
6+
#include <example_behaviors/sam2_segmentation.hpp>
7+
#include <geometry_msgs/msg/point_stamped.hpp>
8+
#include <moveit_pro_ml/onnx_sam2.hpp>
9+
#include <moveit_studio_behavior_interface/async_behavior_base.hpp>
10+
#include <moveit_studio_behavior_interface/get_required_ports.hpp>
11+
#include <moveit_studio_vision_msgs/msg/mask2_d.hpp>
12+
#include <sensor_msgs/msg/image.hpp>
13+
#include <tl_expected/expected.hpp>
14+
15+
namespace
{
// Port names, each followed by the default blackboard entry the port is bound to.
constexpr const char* kPortImage = "image";
constexpr const char* kPortImageDefault = "{image}";
constexpr const char* kPortPoint = "pixel_coords";
constexpr const char* kPortPointDefault = "{pixel_coords}";
constexpr const char* kPortMasks = "masks2d";
constexpr const char* kPortMasksDefault = "{masks2d}";

// Factors by which the x and y coordinates of the input points are multiplied before they are
// handed to the model as point prompts.
constexpr int kImageInferenceWidth = 1024;
constexpr int kImageInferenceHeight = 1024;
}  // namespace
27+
28+
namespace example_behaviors
29+
{
30+
// Convert a ROS image message to the ONNX image format used by the SAM 2 model
31+
void set_onnx_image_from_ros_image(const sensor_msgs::msg::Image& image_msg,
32+
moveit_pro_ml::ONNXImage& onnx_image)
33+
{
34+
onnx_image.shape = {1, image_msg.height, image_msg.width, 3};
35+
onnx_image.data.resize(image_msg.height * image_msg.width * 3);
36+
const int stride = image_msg.encoding != "rgb8" ? 3: 4;
37+
for (size_t i = 0; i < onnx_image.data.size(); i+=stride)
38+
{
39+
onnx_image.data[i] = static_cast<float>(image_msg.data[i]) / 255.0f;
40+
onnx_image.data[i+1] = static_cast<float>(image_msg.data[i+1]) / 255.0f;
41+
onnx_image.data[i+2] = static_cast<float>(image_msg.data[i+2]) / 255.0f;
42+
}
43+
}
44+
45+
// Converts a single channel ONNX image mask to a ROS mask message.
46+
void set_ros_mask_from_onnx_mask(const moveit_pro_ml::ONNXImage& onnx_image, sensor_msgs::msg::Image& mask_image_msg, moveit_studio_vision_msgs::msg::Mask2D& mask_msg)
47+
{
48+
mask_image_msg.height = static_cast<uint32_t>(onnx_image.shape[0]);
49+
mask_image_msg.width = static_cast<uint32_t>(onnx_image.shape[1]);
50+
mask_image_msg.encoding = "mono8";
51+
mask_image_msg.data.resize(mask_image_msg.height * mask_image_msg.width);
52+
mask_image_msg.step = mask_image_msg.width;
53+
for (size_t i = 0; i < onnx_image.data.size(); ++i)
54+
{
55+
mask_image_msg.data[i] = onnx_image.data[i] > 0.5 ? 255: 0;
56+
}
57+
mask_msg.pixels = mask_image_msg;
58+
mask_msg.x = 0;
59+
mask_msg.y = 0;
60+
}
61+
62+
SAM2Segmentation::SAM2Segmentation(const std::string& name, const BT::NodeConfiguration& config,
                                   const std::shared_ptr<moveit_studio::behaviors::BehaviorContext>& shared_resources)
  : moveit_studio::behaviors::AsyncBehaviorBase(name, config, shared_resources)
{
  // The ONNX model files are looked up in the "models" folder of this package's share directory.
  const std::filesystem::path share_directory = ament_index_cpp::get_package_share_directory("example_behaviors");
  const std::filesystem::path models_directory = share_directory / "models";

  const std::filesystem::path encoder_path = models_directory / "sam2_hiera_large_encoder.onnx";
  const std::filesystem::path decoder_path = models_directory / "decoder.onnx";

  // Create the model wrapper from the encoder and decoder files.
  sam2_ = std::make_unique<moveit_pro_ml::SAM2>(encoder_path, decoder_path);
}
72+
73+
BT::PortsList SAM2Segmentation::providedPorts()
{
  using PointVector = std::vector<geometry_msgs::msg::PointStamped>;
  using MaskVector = std::vector<moveit_studio_vision_msgs::msg::Mask2D>;

  BT::PortsList ports;

  // Inputs: the image to segment and the points that select what to segment.
  ports.insert(BT::InputPort<sensor_msgs::msg::Image>(kPortImage, kPortImageDefault,
                                                      "The Image to run segmentation on."));
  ports.insert(BT::InputPort<PointVector>(
      kPortPoint, kPortPointDefault,
      "The input points, as a vector of <code>geometry_msgs/PointStamped</code> messages to be used for segmentation."));

  // Output: the resulting masks.
  ports.insert(BT::OutputPort<MaskVector>(
      kPortMasks, kPortMasksDefault,
      "The masks contained in a vector of <code>moveit_studio_vision_msgs::msg::Mask2D</code> messages."));

  return ports;
}
85+
86+
tl::expected<bool, std::string> SAM2Segmentation::doWork()
{
  // Read both required input ports at once; a missing port is reported as an error.
  const auto required_inputs = moveit_studio::behaviors::getRequiredInputs(
      getInput<sensor_msgs::msg::Image>(kPortImage),
      getInput<std::vector<geometry_msgs::msg::PointStamped>>(kPortPoint));
  if (!required_inputs.has_value())
  {
    return tl::make_unexpected(
        fmt::format("Failed to get required values from input data ports:\n{}", required_inputs.error()));
  }
  const auto& [input_image, input_points] = required_inputs.value();

  // Only 8-bit RGB and RGBA images are accepted.
  const bool encoding_supported = input_image.encoding == "rgb8" || input_image.encoding == "rgba8";
  if (!encoding_supported)
  {
    return tl::make_unexpected(
        fmt::format("Invalid image message format. Expected `(rgb8, rgba8)` got :\n{}", input_image.encoding));
  }

  // Fill the ONNX formatted image tensor from the ROS image.
  set_onnx_image_from_ros_image(input_image, onnx_image_);

  // Build one point prompt per input point. The coordinates are multiplied by the inference image
  // size (presumably the incoming points are normalized to [0, 1] - confirm against the producer of
  // the points). Every prompt gets the same label value.
  std::vector<moveit_pro_ml::PointPrompt> prompts;
  for (const auto& stamped_point : input_points)
  {
    const float prompt_x = kImageInferenceWidth * static_cast<float>(stamped_point.point.x);
    const float prompt_y = kImageInferenceHeight * static_cast<float>(stamped_point.point.y);
    prompts.push_back({ { prompt_x, prompt_y }, { 1.0f } });
  }

  try
  {
    const auto predicted_mask = sam2_->predict(onnx_image_, prompts);

    // The mask image carries the header of the image it was computed from.
    mask_image_msg_.header = input_image.header;
    set_ros_mask_from_onnx_mask(predicted_mask, mask_image_msg_, mask_msg_);

    setOutput<std::vector<moveit_studio_vision_msgs::msg::Mask2D>>(kPortMasks, { mask_msg_ });
  }
  catch (const std::invalid_argument& error)
  {
    return tl::make_unexpected(fmt::format("Invalid argument: {}", error.what()));
  }

  return true;
}
132+
133+
BT::KeyValueVector SAM2Segmentation::metadata()
{
  // The only metadata entry is the human-readable description of the Behavior.
  static constexpr auto kDescription =
      "Segments a ROS image message using the provided points represented as a vector of <code>geometry_msgs/PointStamped</code> messages.";

  return BT::KeyValueVector{ { "description", kDescription } };
}
142+
} // namespace example_behaviors
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
<?xml version="1.0" encoding="utf-8" ?>
<root BTCPP_format="4" main_tree_to_execute="Segment Point Cloud from Clicked Point">
  <!--//////////-->
  <BehaviorTree
    ID="Segment Point Cloud from Clicked Point"
    _description="Captures a point cloud and requests the user to click an object in the image to be segmented. The point cloud is then filtered to only include the selected object."
    _favorite="true"
  >
    <Control ID="Sequence">
      <Action ID="ClearSnapshot" />
      <!-- Grab a colour image and let the user pick the object in that view. -->
      <Action ID="GetImage" topic_name="/wrist_camera/color" />
      <Action ID="GetPointsFromUser" point_prompts="Select the object to be segmented;" point_names="Point1;" view_name="/wrist_camera/color" />
      <!-- SAM2Segmentation is used with the default values of all of its ports. -->
      <Action ID="SAM2Segmentation" />
      <!-- Fetch the point cloud and the camera info from the wrist camera topics. -->
      <Action ID="GetPointCloud" topic_name="/wrist_camera/points" />
      <Action ID="GetCameraInfo" topic_name="/wrist_camera/camera_info" message_out="{camera_info}" timeout_sec="5.000000" />
      <Action ID="GetMasks3DFromMasks2D" />
      <Decorator ID="ForEachMask3D" vector_in="{masks3d}" out="{mask3d}">
        <Action ID="GetPointCloudFromMask3D" point_cloud="{point_cloud}" />
      </Decorator>
      <!-- Show and publish the resulting point cloud fragment. -->
      <Action ID="SendPointCloudToUI" point_cloud="{point_cloud_fragment}" />
      <Action ID="PublishPointCloud" point_cloud="{point_cloud_fragment}" />
      <Action ID="SwitchUIPrimaryView" />
    </Control>
  </BehaviorTree>
  <TreeNodesModel>
    <SubTree ID="Segment Point Cloud from Clicked Point" />
  </TreeNodesModel>
</root>

0 commit comments

Comments
 (0)