forked from PickNikRobotics/moveit_pro_empty_ws
-
Notifications
You must be signed in to change notification settings - Fork 6
Add example SAM 2 behavior and objective #23
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 3 commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
c44583a
add ONNX models
pac48 ac8d439
Add SAM2 behavior
pac48 e0a98f7
cleanup
pac48 0d1dd3f
Update src/example_behaviors/include/example_behaviors/sam2_segmentat…
pac48 3e237eb
Update src/example_behaviors/src/sam2_segmentation.cpp
pac48 994dee9
Update src/example_behaviors/src/sam2_segmentation.cpp
pac48 8f43342
Update src/example_behaviors/src/sam2_segmentation.cpp
pac48 fc6dc1a
Update src/example_behaviors/src/sam2_segmentation.cpp
pac48 5c80b9a
Update src/example_behaviors/include/example_behaviors/sam2_segmentat…
pac48 cbb3267
Update src/example_behaviors/src/sam2_segmentation.cpp
pac48 08339ff
Update src/example_behaviors/src/sam2_segmentation.cpp
pac48 28f921d
Update src/example_behaviors/src/sam2_segmentation.cpp
pac48 2c3ec18
address feedback
pac48 1f89a5b
add headers and fix comments
pac48 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| *.onnx filter=lfs diff=lfs merge=lfs -text |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
76 changes: 76 additions & 0 deletions
76
src/example_behaviors/include/example_behaviors/sam2_segmentation.hpp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,76 @@ | ||
| #pragma once | ||
|
|
||
|
|
||
| #include <moveit_studio_behavior_interface/async_behavior_base.hpp> | ||
| #include <moveit_pro_ml/onnx_sam2.hpp> | ||
| #include <sensor_msgs/msg/image.hpp> | ||
| #include <moveit_studio_vision_msgs/msg/mask2_d.hpp> | ||
| #include <fmt/format.h> | ||
griswaldbrooks marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| namespace example_behaviors | ||
| { | ||
| /** | ||
| * @brief Segment an image using the SAM 2 model | ||
| */ | ||
| class SAM2Segmentation : public moveit_studio::behaviors::AsyncBehaviorBase | ||
| { | ||
| public: | ||
| /** | ||
| * @brief Constructor for the sam2_segmentation behavior. | ||
| * @param name The name of a particular instance of this Behavior. This will be set by the behavior tree factory when this Behavior is created within a new behavior tree. | ||
| * @param config This contains runtime configuration info for this Behavior, such as the mapping between the Behavior's data ports on the behavior tree's blackboard. This will be set by the behavior tree factory when this Behavior is created within a new behavior tree. | ||
| * @details An important limitation is that the members of the base Behavior class are not instantiated until after the initialize() function is called, so these classes should not be used within the constructor. | ||
| */ | ||
| SAM2Segmentation(const std::string& name, const BT::NodeConfiguration& config, | ||
| const std::shared_ptr<moveit_studio::behaviors::BehaviorContext>& shared_resources); | ||
|
|
||
| /** | ||
| * @brief Implementation of the required providedPorts() function for the sam2_segmentation Behavior. | ||
| * @details The BehaviorTree.CPP library requires that Behaviors must implement a static function named providedPorts() which defines their input and output ports. If the Behavior does not use any ports, this function must return an empty BT::PortsList. | ||
| * This function returns a list of ports with their names and port info, which is used internally by the behavior tree. | ||
| * @return sam2_segmentation does not use expose any ports, so this function returns an empty list. | ||
griswaldbrooks marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| */ | ||
| static BT::PortsList providedPorts(); | ||
|
|
||
| /** | ||
| * @brief Implementation of the metadata() function for displaying metadata, such as Behavior description and | ||
| * subcategory, in the MoveIt Studio Developer Tool. | ||
| * @return A BT::KeyValueVector containing the Behavior metadata. | ||
| */ | ||
| static BT::KeyValueVector metadata(); | ||
|
|
||
| protected: | ||
| tl::expected<bool, std::string> doWork() override; | ||
|
|
||
|
|
||
| private: | ||
| /** | ||
| * @brief Convert a ROS image message to the ONNX image format used by the SAM 2 model. | ||
| * @param image_msg The ROS message to be converted | ||
| * @param onnx_image The ONNX image | ||
| */ | ||
| void set_onnx_image_from_ros_image(const sensor_msgs::msg::Image& image_msg, moveit_pro_ml::ONNXImage& onnx_image); | ||
|
|
||
| /** | ||
| * @brief Convert an ONNX image to a ROS image message. | ||
| * @param onnx_image The ONNX image to be converted | ||
| * @param image_msg The ROS message | ||
| */ | ||
| void set_ros_image_from_onnx_image(const moveit_pro_ml::ONNXImage& onnx_image, sensor_msgs::msg::Image& image_msg); | ||
griswaldbrooks marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| std::shared_ptr<moveit_pro_ml::SAM2> sam2_; | ||
pac48 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| moveit_pro_ml::ONNXImage onnx_image_; | ||
| sensor_msgs::msg::Image mask_image_msg_; | ||
| moveit_studio_vision_msgs::msg::Mask2D mask_; | ||
|
|
||
| /** @brief Classes derived from AsyncBehaviorBase must implement getFuture() so that it returns a shared_future class member */ | ||
| std::shared_future<tl::expected<bool, std::string>>& getFuture() override | ||
| { | ||
| return future_; | ||
| } | ||
|
|
||
| /** @brief Classes derived from AsyncBehaviorBase must have this shared_future as a class member */ | ||
| std::shared_future<tl::expected<bool, std::string>> future_; | ||
|
|
||
| }; | ||
| } // namespace sam2_segmentation | ||
Git LFS file not shown
Git LFS file not shown
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,137 @@ | ||
| #include <spdlog/spdlog.h> | ||
griswaldbrooks marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| #include <example_behaviors/sam2_segmentation.hpp> | ||
|
|
||
pac48 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| #include <geometry_msgs/msg/point_stamped.hpp> | ||
| #include <moveit_studio_behavior_interface/get_required_ports.hpp> | ||
| #include <ament_index_cpp/get_package_share_directory.hpp> | ||
|
|
||
| namespace | ||
| { | ||
| constexpr auto kPortImage = "image"; | ||
| constexpr auto kPortImageDefault = "{image}"; | ||
| constexpr auto kPortPoint = "pixel_coords"; | ||
| constexpr auto kPortPointDefault = "{pixel_coords}"; | ||
| constexpr auto kPortMasks = "masks2d"; | ||
| constexpr auto kPortMasksDefault = "{masks2d}"; | ||
|
|
||
| constexpr auto kImageInferenceWidth = 1024; | ||
| constexpr auto kImageInferenceHeight = 1024; | ||
| } // namespace | ||
|
|
||
| namespace example_behaviors | ||
| { | ||
| SAM2Segmentation::SAM2Segmentation(const std::string& name, const BT::NodeConfiguration& config, | ||
| const std::shared_ptr<moveit_studio::behaviors::BehaviorContext>& shared_resources) | ||
| : moveit_studio::behaviors::AsyncBehaviorBase(name, config, shared_resources) | ||
| { | ||
|
|
||
| std::filesystem::path package_path = ament_index_cpp::get_package_share_directory("example_behaviors"); | ||
pac48 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| const std::filesystem::path encoder_onnx_file = package_path / "models" / "sam2_hiera_large_encoder.onnx"; | ||
| const std::filesystem::path decoder_onnx_file = package_path / "models" / "decoder.onnx"; | ||
| sam2_ = std::make_shared<moveit_pro_ml::SAM2>(encoder_onnx_file, decoder_onnx_file); | ||
| } | ||
|
|
||
| BT::PortsList SAM2Segmentation::providedPorts() | ||
| { | ||
| return { | ||
| BT::InputPort<sensor_msgs::msg::Image>(kPortImage, kPortImageDefault, | ||
| "The Image to run segmentation on."), | ||
| BT::InputPort<std::vector<geometry_msgs::msg::PointStamped>>(kPortPoint, kPortPointDefault, | ||
| "The input points, as a vector of <code>geometry_msgs/PointStamped</code> messages to be used for segmentation."), | ||
|
|
||
| BT::OutputPort<std::vector<moveit_studio_vision_msgs::msg::Mask2D>>(kPortMasks, kPortMasksDefault, | ||
griswaldbrooks marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| "The masks contained in a vector of <code>moveit_studio_vision_msgs::msg::Mask2D</code> messages.") | ||
|
|
||
|
|
||
| }; | ||
| } | ||
|
|
||
| void SAM2Segmentation::set_onnx_image_from_ros_image(const sensor_msgs::msg::Image& image_msg, | ||
| moveit_pro_ml::ONNXImage& onnx_image) | ||
| { | ||
| onnx_image.shape = {1, image_msg.height, image_msg.width, 3}; | ||
| onnx_image.data.resize(image_msg.height * image_msg.width * 3); | ||
| int stride = image_msg.encoding != "rgb8" ? 3: 4; | ||
pac48 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| for (size_t i = 0; i < onnx_image.data.size(); i+=stride) | ||
| { | ||
| onnx_image.data[i] = static_cast<float>(image_msg.data[i]) / 255.0f; | ||
| onnx_image.data[i+1] = static_cast<float>(image_msg.data[i+1]) / 255.0f; | ||
| onnx_image.data[i+2] = static_cast<float>(image_msg.data[i+2]) / 255.0f; | ||
| } | ||
| } | ||
|
|
||
|
|
||
| void SAM2Segmentation::set_ros_image_from_onnx_image(const moveit_pro_ml::ONNXImage& onnx_image, sensor_msgs::msg::Image& image_msg) | ||
| { | ||
| image_msg.height = static_cast<uint32_t>(onnx_image.shape[0]); | ||
| image_msg.width = static_cast<uint32_t>(onnx_image.shape[1]); | ||
| image_msg.encoding = "mono8"; | ||
griswaldbrooks marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| image_msg.data.resize(image_msg.height * image_msg.width); | ||
| image_msg.step = image_msg.width; | ||
| for (size_t i = 0; i < onnx_image.data.size(); ++i) | ||
| { | ||
| image_msg.data[i] = onnx_image.data[i] > 0.5 ? 255: 0; | ||
| } | ||
| } | ||
|
|
||
| tl::expected<bool, std::string> SAM2Segmentation::doWork() | ||
| { | ||
| const auto ports = moveit_studio::behaviors::getRequiredInputs(getInput<sensor_msgs::msg::Image>(kPortImage), | ||
| getInput<std::vector< | ||
| geometry_msgs::msg::PointStamped>>(kPortPoint)); | ||
|
|
||
| // Check that all required input data ports were set. | ||
| if (!ports.has_value()) | ||
| { | ||
| auto error_message = fmt::format("Failed to get required values from input data ports:\n{}", ports.error()); | ||
| return tl::make_unexpected(error_message); | ||
| } | ||
| const auto& [image_msg, points_2d] = ports.value(); | ||
|
|
||
| if (image_msg.encoding != "rgb8" && image_msg.encoding != "rgba8") | ||
| { | ||
| auto error_message = fmt::format("Invalid image message format. Expected `(rgb8, rgba8)` got :\n{}", image_msg.encoding); | ||
| return tl::make_unexpected(error_message); | ||
| } | ||
|
|
||
| // Create ONNX formatted image tensor from ROS image | ||
| set_onnx_image_from_ros_image(image_msg, onnx_image_); | ||
|
|
||
| std::vector<moveit_pro_ml::PointPrompt> point_prompts; | ||
| for (auto const& point : points_2d) | ||
| { | ||
| // Assume all point are the same label | ||
pac48 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| point_prompts.push_back({{kImageInferenceWidth*static_cast<float>(point.point.x), kImageInferenceHeight*static_cast<float>(point.point.y)}, {1.0f}}); | ||
| } | ||
|
|
||
| try | ||
| { | ||
| auto masks = sam2_->predict(onnx_image_, point_prompts); | ||
pac48 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| mask_image_msg_.header = image_msg.header; | ||
| set_ros_image_from_onnx_image(masks, mask_image_msg_); | ||
griswaldbrooks marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| mask_.pixels = mask_image_msg_; | ||
| mask_.x = 0; | ||
| mask_.y = 0; | ||
| setOutput<std::vector<moveit_studio_vision_msgs::msg::Mask2D>>(kPortMasks, {mask_}); | ||
| } | ||
| catch (std::invalid_argument& e) | ||
pac48 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| { | ||
| auto error_message = fmt::format("Invalid argument: {}", e.what()); | ||
| return tl::make_unexpected(error_message); | ||
pac48 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
|
|
||
| return true; | ||
| } | ||
|
|
||
| BT::KeyValueVector SAM2Segmentation::metadata() | ||
| { | ||
| return { | ||
| { | ||
| "description", | ||
| "Segments a ROS image message using the provided points represented as a vector of <code>geometry_msgs/PointStamped</code> messages." | ||
| } | ||
| }; | ||
| } | ||
| } // namespace sam2_segmentation | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,41 @@ | ||
| <?xml version="1.0" encoding="utf-8" ?> | ||
| <root | ||
| BTCPP_format="4" | ||
| main_tree_to_execute="Segment Point Cloud from Clicked Point" | ||
| > | ||
| <!--//////////--> | ||
| <BehaviorTree | ||
| ID="Segment Point Cloud from Clicked Point" | ||
| _description="Captures a point cloud and requests the user to click an object in the image to be segmented. The point cloud is then filtered to only include the selected object." | ||
| _favorite="true" | ||
| > | ||
| <Control ID="Sequence"> | ||
| <Action ID="ClearSnapshot" /> | ||
| <Action ID="GetImage" topic_name="/wrist_camera/color" /> | ||
| <Action | ||
| ID="GetPointsFromUser" | ||
| point_prompts="Select the object to be segmented;" | ||
| point_names="Point1;" | ||
| view_name="/wrist_camera/color" | ||
| /> | ||
| <Action ID="SAM2Segmentation" /> | ||
| <Action ID="GetPointCloud" topic_name="/wrist_camera/points" /> | ||
| <Action | ||
| ID="GetCameraInfo" | ||
| topic_name="/wrist_camera/camera_info" | ||
| message_out="{camera_info}" | ||
| timeout_sec="5.000000" | ||
| /> | ||
| <Action ID="GetMasks3DFromMasks2D" /> | ||
| <Decorator ID="ForEachMask3D" vector_in="{masks3d}" out="{mask3d}"> | ||
| <Action ID="GetPointCloudFromMask3D" point_cloud="{point_cloud}" /> | ||
| </Decorator> | ||
| <Action ID="SendPointCloudToUI" point_cloud="{point_cloud_fragment}" /> | ||
| <Action ID="PublishPointCloud" point_cloud="{point_cloud_fragment}" /> | ||
| <Action ID="SwitchUIPrimaryView" /> | ||
| </Control> | ||
| </BehaviorTree> | ||
| <TreeNodesModel> | ||
| <SubTree ID="Segment Point Cloud from Clicked Point" /> | ||
| </TreeNodesModel> | ||
| </root> |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.