From ec2e310be27b73e172c77cf8920f68e758dbe247 Mon Sep 17 00:00:00 2001 From: Maciej Majek Date: Tue, 2 Sep 2025 14:07:24 +0200 Subject: [PATCH 01/13] feat(temp): debug 3d det --- .../tools/segmentation_tools.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py index 16c6fc2df..a395d986c 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py @@ -291,8 +291,25 @@ def _process_mask( masked_depth_image, intrinsic[0], intrinsic[1], intrinsic[2], intrinsic[3] ) - # TODO: Filter out outliers points = pcd + # publish resulting pointcloud + import time + + from geometry_msgs.msg import Point32 + from sensor_msgs.msg import PointCloud + + msg = PointCloud() + msg.header.frame_id = "egofront_rgbd_camera_depth_optical_frame" + msg.points = [Point32(x=p[0], y=p[1], z=p[2]) for p in points] + pub = self.connector.node.create_publisher( + PointCloud, "/debug/get_grabbing_point_pointcloud", 10 + ) + while True: + self.connector.node.get_logger().info( + f"publishing pointcloud to /debug/get_grabbing_point_pointcloud: {len(msg.points)} points, mean: {np.array(points.mean(axis=0))}." + ) + pub.publish(msg) + time.sleep(0.1) # https://github.com/ycheng517/tabletop-handybot/blob/6d401e577e41ea86529d091b406fbfc936f37a8d/tabletop_handybot/tabletop_handybot/tabletop_handybot_node.py#L413-L424 grasp_z = points[:, 2].max() From 142bb58c788608ae6f0ea482d68b71f3b0ae24fc Mon Sep 17 00:00:00 2001 From: Maciej Majek Date: Tue, 2 Sep 2025 16:03:46 +0200 Subject: [PATCH 02/13] feat: all --- src/rai_core/rai/tools/ros2/detection/pcl.py | 606 +++++++++++++++++++ xd.py | 31 + 2 files changed, 637 insertions(+) create mode 100644 src/rai_core/rai/tools/ros2/detection/pcl.py create mode 100644 xd.py diff --git a/src/rai_core/rai/tools/ros2/detection/pcl.py b/src/rai_core/rai/tools/ros2/detection/pcl.py new file mode 100644 index 000000000..c9f59d16f --- /dev/null +++ b/src/rai_core/rai/tools/ros2/detection/pcl.py @@ -0,0 +1,606 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
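+"""Point cloud utilities for gripping-point detection.
+
+A rough sketch of the intended pipeline (usage sketch only; the topic and
+frame names below are illustrative placeholders, not fixed defaults):
+
+    pc_gen = PointCloudFromSegmentation(
+        connector=connector,
+        camera_topic="/rgbd_camera/camera_image_color",
+        depth_topic="/rgbd_camera/camera_image_depth",
+        camera_info_topic="/rgbd_camera/camera_info",
+        source_frame="camera_depth_optical_frame",
+        target_frame="base_link",
+    )
+    clouds = pc_gen.run("box")  # one (N, 3) array per detected instance
+    clouds = PointCloudFilter(strategy="dbscan").run(clouds)
+    gripping_points = GrippingPointEstimator(strategy="centroid").run(clouds)
+"""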
+from typing import List, Literal, Optional, cast + +import numpy as np +import sensor_msgs.msg +from numpy.typing import NDArray +from rai_open_set_vision import GDINO_SERVICE_NAME +from rclpy import Future +from rclpy.exceptions import ( + ParameterNotDeclaredException, + ParameterUninitializedException, +) + +from rai.communication.ros2.api import ( + convert_ros_img_to_ndarray, # type: ignore[reportUnknownVariableType] +) +from rai.communication.ros2.connectors import ROS2Connector +from rai.communication.ros2.ros_async import get_future_result +from rai_interfaces.srv import RAIGroundedSam, RAIGroundingDino + + +def depth_to_point_cloud( + depth_image: NDArray[np.float32], fx: float, fy: float, cx: float, cy: float +) -> NDArray[np.float32]: + height, width = depth_image.shape + x_coords = np.arange(width, dtype=np.float32) + y_coords = np.arange(height, dtype=np.float32) + x_grid, y_grid = np.meshgrid(x_coords, y_coords) + z = depth_image + x = (x_grid - float(cx)) * z / float(fx) + y = (y_grid - float(cy)) * z / float(fy) + points = np.stack((x, y, z), axis=-1).reshape(-1, 3) + points = points[points[:, 2] > 0] + return points.astype(np.float32, copy=False) + + +class PointCloudFromSegmentation: + """Generate a masked point cloud for an object and transform it to a target frame. + + Configure with source/target TF frames and ROS2 topics. Call run(object_name) to + get an Nx3 numpy array of points [X, Y, Z] expressed in the target frame. + """ + + connector: ROS2Connector + camera_topic: str + depth_topic: str + camera_info_topic: str + source_frame: str + target_frame: str + + box_threshold: float = 0.35 + text_threshold: float = 0.45 + + def __init__( + self, + *, + connector: ROS2Connector, + camera_topic: str, + depth_topic: str, + camera_info_topic: str, + source_frame: str, + target_frame: str, + box_threshold: float = 0.35, + text_threshold: float = 0.45, + ) -> None: + self.connector = connector + self.camera_topic = camera_topic + self.depth_topic = depth_topic + self.camera_info_topic = camera_info_topic + self.source_frame = source_frame + self.target_frame = target_frame + self.box_threshold = box_threshold + self.text_threshold = text_threshold + + # --------------------- ROS helpers --------------------- + def _get_image_message(self, topic: str) -> sensor_msgs.msg.Image: + msg = self.connector.receive_message(topic).payload + if isinstance(msg, sensor_msgs.msg.Image): + return msg + raise TypeError("Received wrong message type for Image") + + def _get_camera_info_message(self, topic: str) -> sensor_msgs.msg.CameraInfo: + for _ in range(3): + msg = self.connector.receive_message(topic, timeout_sec=3.0).payload + if isinstance(msg, sensor_msgs.msg.CameraInfo): + return msg + self.connector.node.get_logger().warn( # type: ignore[reportUnknownMemberType] + "Received wrong CameraInfo message type. Retrying..." 
+ ) + raise RuntimeError("Failed to receive correct CameraInfo after 3 attempts") + + def _get_intrinsic_from_camera_info( + self, camera_info: sensor_msgs.msg.CameraInfo + ) -> tuple[float, float, float, float]: + k = camera_info.k # type: ignore[reportUnknownMemberType] + fx = float(k[0]) + fy = float(k[4]) + cx = float(k[2]) + cy = float(k[5]) + return fx, fy, cx, cy + + def _call_gdino_node( + self, camera_img_message: sensor_msgs.msg.Image, object_name: str + ) -> Future: + cli = self.connector.node.create_client(RAIGroundingDino, GDINO_SERVICE_NAME) # type: ignore[reportUnknownMemberType] + while not cli.wait_for_service(timeout_sec=1.0): + self.connector.node.get_logger().info( # type: ignore[reportUnknownMemberType] + f"service {GDINO_SERVICE_NAME} not available, waiting again..." + ) + req = RAIGroundingDino.Request() + req.source_img = camera_img_message + req.classes = object_name + req.box_threshold = self.box_threshold + req.text_threshold = self.text_threshold + return cli.call_async(req) + + def _call_gsam_node( + self, camera_img_message: sensor_msgs.msg.Image, data: RAIGroundingDino.Response + ) -> Future: + cli = self.connector.node.create_client(RAIGroundedSam, "grounded_sam_segment") # type: ignore[reportUnknownMemberType] + while not cli.wait_for_service(timeout_sec=1.0): + self.connector.node.get_logger().info( # type: ignore[reportUnknownMemberType] + "service grounded_sam_segment not available, waiting again..." + ) + req = RAIGroundedSam.Request() + req.detections = data.detections # type: ignore[reportUnknownMemberType] + req.source_img = camera_img_message + return cli.call_async(req) + + # --------------------- Geometry helpers --------------------- + @staticmethod + def _quaternion_to_rotation_matrix( + qx: float, qy: float, qz: float, qw: float + ) -> NDArray[np.float64]: + xx = qx * qx + yy = qy * qy + zz = qz * qz + xy = qx * qy + xz = qx * qz + yz = qy * qz + wx = qw * qx + wy = qw * qy + wz = qw * qz + + return np.array( + [ + [1.0 - 2.0 * (yy + zz), 2.0 * (xy - wz), 2.0 * (xz + wy)], + [2.0 * (xy + wz), 1.0 - 2.0 * (xx + zz), 2.0 * (yz - wx)], + [2.0 * (xz - wy), 2.0 * (yz + wx), 1.0 - 2.0 * (xx + yy)], + ], + dtype=np.float64, + ) + + def _transform_points_source_to_target( + self, points_xyz: NDArray[np.float32] + ) -> NDArray[np.float64]: + if points_xyz.size == 0: + return points_xyz.reshape(0, 3).astype(np.float64) + + transform = self.connector.get_transform(self.target_frame, self.source_frame) + t = transform.transform.translation # type: ignore[reportUnknownMemberType] + r = transform.transform.rotation # type: ignore[reportUnknownMemberType] + qw = float(r.w) # type: ignore[reportUnknownMemberType] + qx = float(r.x) # type: ignore[reportUnknownMemberType] + qy = float(r.y) # type: ignore[reportUnknownMemberType] + qz = float(r.z) # type: ignore[reportUnknownMemberType] + rotation_matrix = self._quaternion_to_rotation_matrix(qx, qy, qz, qw) + translation = np.array([float(t.x), float(t.y), float(t.z)], dtype=np.float64) # type: ignore[reportUnknownMemberType] + + return (points_xyz.astype(np.float64) @ rotation_matrix.T) + translation + + # --------------------- Public API --------------------- + def run(self, object_name: str) -> list[NDArray[np.float32]]: + """Return Nx3 numpy array [X, Y, Z] of the object's masked point cloud in target frame.""" + + camera_img_msg = self._get_image_message(self.camera_topic) + depth_msg = self.connector.receive_message(self.depth_topic).payload + camera_info = 
self._get_camera_info_message(self.camera_info_topic) + + fx, fy, cx, cy = self._get_intrinsic_from_camera_info(camera_info) + + gdino_future = self._call_gdino_node(camera_img_msg, object_name) + + logger = self.connector.node.get_logger() + try: + conversion_ratio_value = self.connector.node.get_parameter( + "conversion_ratio" + ).value # type: ignore[reportUnknownMemberType] + conversion_ratio: float + if isinstance(conversion_ratio_value, float): + conversion_ratio = conversion_ratio_value + else: + logger.error( # type: ignore[reportUnknownMemberType] + "Parameter conversion_ratio has wrong type. Using default 0.001" + ) + conversion_ratio = 0.001 + except (ParameterUninitializedException, ParameterNotDeclaredException): + logger.warning("Parameter conversion_ratio not found. Using default 0.001") # type: ignore[reportUnknownMemberType] + conversion_ratio = 0.001 + + gdino_resolved = get_future_result(gdino_future) + if gdino_resolved is None: + return [] + + gsam_future = self._call_gsam_node(camera_img_msg, gdino_resolved) + gsam_resolved = get_future_result(gsam_future) + if gsam_resolved is None or len(gsam_resolved.masks) == 0: + return [] + + depth = convert_ros_img_to_ndarray(depth_msg).astype(np.float32) + all_points: List[NDArray[np.float32]] = [] + for mask_msg in gsam_resolved.masks: + mask = cast(NDArray[np.uint8], convert_ros_img_to_ndarray(mask_msg)) + binary_mask = mask == 255 + masked_depth_image: NDArray[np.float32] = np.zeros_like( + depth, dtype=np.float32 + ) + masked_depth_image[binary_mask] = depth[binary_mask] + masked_depth_image = masked_depth_image * float(conversion_ratio) + + points_camera: NDArray[np.float32] = depth_to_point_cloud( + masked_depth_image, fx, fy, cx, cy + ) + if points_camera.size: + all_points.append(points_camera) + + if not all_points: + return [] + + points_target = [ + self._transform_points_source_to_target(points_source).astype(np.float32) + for points_source in all_points + ] + return points_target + + +class GrippingPointEstimator: + """Estimate gripping points from segmented point clouds using different strategies. + + This class operates on the output of `PointCloudFromSegmentation.run`, which is + a list of numpy arrays, one per segmented instance, each of shape (N, 3). 
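+
+    Input points are assumed to be expressed in the target frame chosen for
+    PointCloudFromSegmentation; the "top_plane" strategy additionally assumes
+    that frame's +Z axis points up.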
+
+    Supported strategies:
+    - "centroid": centroid of all points
+    - "top_plane": centroid of points in the top-Z percentile (proxy for top plane)
+    - "biggest_plane": centroid of the inliers of the largest plane found via RANSAC (any orientation)
+    """
+
+    strategy: Literal["centroid", "top_plane", "biggest_plane"]
+    top_percentile: float
+    plane_bin_size_m: float  # retained from an earlier binning approach; unused by the RANSAC path
+    ransac_iterations: int
+    distance_threshold_m: float
+    min_points: int
+
+    def __init__(
+        self,
+        *,
+        strategy: Literal["centroid", "top_plane", "biggest_plane"] = "centroid",
+        top_percentile: float = 0.05,
+        plane_bin_size_m: float = 0.01,
+        ransac_iterations: int = 200,
+        distance_threshold_m: float = 0.01,
+        min_points: int = 10,
+    ) -> None:
+        self.strategy = strategy
+        self.top_percentile = top_percentile
+        self.plane_bin_size_m = plane_bin_size_m
+        self.ransac_iterations = int(max(1, ransac_iterations))
+        self.distance_threshold_m = float(max(1e-6, distance_threshold_m))
+        self.min_points = min_points
+
+    def _centroid(self, points: NDArray[np.float32]) -> Optional[NDArray[np.float32]]:
+        if points.size == 0:
+            return None
+        return points.mean(axis=0).astype(np.float32)
+
+    def _top_plane_centroid(
+        self, points: NDArray[np.float32]
+    ) -> Optional[NDArray[np.float32]]:
+        if points.shape[0] < self.min_points:
+            return self._centroid(points)
+        z_vals = points[:, 2]
+        threshold = np.quantile(z_vals, 1.0 - self.top_percentile)
+        mask = z_vals >= threshold
+        top_points = points[mask]
+        if top_points.shape[0] == 0:
+            return self._centroid(points)
+        return top_points.mean(axis=0).astype(np.float32)
+
+    def _biggest_plane_centroid(
+        self, points: NDArray[np.float32]
+    ) -> Optional[NDArray[np.float32]]:
+        # RANSAC plane detection: not restricted to horizontal planes
+        num_points = points.shape[0]
+        if num_points < self.min_points:
+            return self._centroid(points)
+
+        best_inlier_count = 0
+        best_inlier_mask: Optional[NDArray[np.bool_]] = None
+
+        # Precompute for speed
+        pts64 = points.astype(np.float64, copy=False)
+        threshold = float(self.distance_threshold_m)
+
+        rng = np.random.default_rng()
+
+        for _ in range(self.ransac_iterations):
+            # Sample 3 unique points
+            idxs = rng.choice(num_points, size=3, replace=False)
+            p0, p1, p2 = pts64[idxs[0]], pts64[idxs[1]], pts64[idxs[2]]
+            v1 = p1 - p0
+            v2 = p2 - p0
+            normal = np.cross(v1, v2)
+            norm_len = np.linalg.norm(normal)
+            if norm_len < 1e-9:
+                continue  # degenerate triplet
+            normal /= norm_len
+            # Distance from points to plane
+            # Plane eq: normal · (x - p0) = 0 -> distance = |normal · (x - p0)|
+            diffs = pts64 - p0
+            dists = np.abs(diffs @ normal)
+            inliers = dists <= threshold
+            count = int(inliers.sum())
+            if count > best_inlier_count:
+                best_inlier_count = count
+                best_inlier_mask = inliers
+
+        if best_inlier_mask is None or best_inlier_count < self.min_points:
+            return self._centroid(points)
+
+        inlier_points = points[best_inlier_mask]
+        if inlier_points.shape[0] == 0:
+            return self._centroid(points)
+        return inlier_points.mean(axis=0).astype(np.float32)
+
+    def run(
+        self, segmented_point_clouds: list[NDArray[np.float32]]
+    ) -> list[NDArray[np.float32]]:
+        """Compute gripping points for each segmented point cloud.
+
+        Parameters
+        ----------
+        segmented_point_clouds: list of (N, 3) arrays in target frame.
+
+        Returns
+        -------
+        list of np.array points [[x, y, z], ...], one per input cloud.
+ """ + gripping_points: list[NDArray[np.float32]] = [] + + for pts in segmented_point_clouds: + if pts.size == 0: + continue + if self.strategy == "centroid": + gp = self._centroid(pts) + elif self.strategy == "top_plane": + gp = self._top_plane_centroid(pts) + elif self.strategy == "biggest_plane": + gp = self._biggest_plane_centroid(pts) + else: + gp = self._centroid(pts) + + if gp is not None: + gripping_points.append(gp.astype(np.float32)) + + return gripping_points + + +class PointCloudFilter: + """Filter segmented point clouds using various sklearn strategies. + + Strategies: + - "dbscan": keep the largest DBSCAN cluster (exclude label -1) + - "kmeans_largest_cluster": keep the largest KMeans cluster + - "isolation_forest": keep inliers (pred == 1) + - "lof": keep inliers (pred == 1) + """ + + strategy: Literal["dbscan", "kmeans_largest_cluster", "isolation_forest", "lof"] + min_points: int + # DBSCAN + dbscan_eps: float + dbscan_min_samples: int + # KMeans + kmeans_k: int + # Isolation Forest + if_max_samples: int | float | Literal["auto"] + if_contamination: float + # LOF + lof_n_neighbors: int + lof_contamination: float + + def __init__( + self, + *, + strategy: Literal[ + "dbscan", "kmeans_largest_cluster", "isolation_forest", "lof" + ] = "dbscan", + min_points: int = 20, + dbscan_eps: float = 0.02, + dbscan_min_samples: int = 10, + kmeans_k: int = 2, + if_max_samples: int | float | Literal["auto"] = "auto", + if_contamination: float = 0.05, + lof_n_neighbors: int = 20, + lof_contamination: float = 0.05, + ) -> None: + self.strategy = strategy + self.min_points = min_points + self.dbscan_eps = dbscan_eps + self.dbscan_min_samples = dbscan_min_samples + self.kmeans_k = kmeans_k + self.if_max_samples = if_max_samples + self.if_contamination = if_contamination + self.lof_n_neighbors = lof_n_neighbors + self.lof_contamination = lof_contamination + + def _filter_dbscan(self, pts: NDArray[np.float32]) -> NDArray[np.float32]: + from sklearn.cluster import DBSCAN # type: ignore[reportMissingImports] + + if pts.shape[0] < self.min_points: + return pts + db = DBSCAN(eps=self.dbscan_eps, min_samples=self.dbscan_min_samples) + labels = cast(NDArray[np.int64], db.fit_predict(pts)) # type: ignore[no-any-return] + if labels.size == 0: + return pts + valid = labels >= 0 + if not np.any(valid): + return pts + labels_valid = labels[valid] + unique_labels, counts = np.unique(labels_valid, return_counts=True) + dominant = unique_labels[np.argmax(counts)] + mask = labels == dominant + return pts[mask] + + def _filter_kmeans_largest(self, pts: NDArray[np.float32]) -> NDArray[np.float32]: + from sklearn.cluster import KMeans # type: ignore[reportMissingImports] + + if pts.shape[0] < max(self.min_points, self.kmeans_k): + return pts + kmeans = KMeans(n_clusters=self.kmeans_k, n_init="auto") + labels = cast(NDArray[np.int64], kmeans.fit_predict(pts)) # type: ignore[no-any-return] + unique_labels, counts = np.unique(labels, return_counts=True) + dominant = unique_labels[np.argmax(counts)] + mask = labels == dominant + return pts[mask] + + def _filter_isolation_forest(self, pts: NDArray[np.float32]) -> NDArray[np.float32]: + from sklearn.ensemble import ( + IsolationForest, # type: ignore[reportMissingImports] + ) + + if pts.shape[0] < self.min_points: + return pts + iso = IsolationForest( + max_samples=self.if_max_samples, + contamination=self.if_contamination, + random_state=42, + ) + pred = cast(NDArray[np.int64], iso.fit_predict(pts)) # type: ignore[no-any-return] # 1 inlier, -1 outlier + mask = pred 
== 1 + if not np.any(mask): + return pts + return pts[mask] + + def _filter_lof(self, pts: NDArray[np.float32]) -> NDArray[np.float32]: + from sklearn.neighbors import ( + LocalOutlierFactor, # type: ignore[reportMissingImports] + ) + + if pts.shape[0] < max(self.min_points, self.lof_n_neighbors + 1): + return pts + lof = LocalOutlierFactor( + n_neighbors=self.lof_n_neighbors, contamination=self.lof_contamination + ) + pred = cast(NDArray[np.int64], lof.fit_predict(pts)) # type: ignore[no-any-return] # 1 inlier, -1 outlier + mask = pred == 1 + if not np.any(mask): + return pts + return pts[mask] + + def run( + self, segmented_point_clouds: list[NDArray[np.float32]] + ) -> list[NDArray[np.float32]]: + filtered: list[NDArray[np.float32]] = [] + for pts in segmented_point_clouds: + if pts.size == 0: + continue + if self.strategy == "dbscan": + f = self._filter_dbscan(pts) + elif self.strategy == "kmeans_largest_cluster": + f = self._filter_kmeans_largest(pts) + elif self.strategy == "isolation_forest": + f = self._filter_isolation_forest(pts) + elif self.strategy == "lof": + f = self._filter_lof(pts) + else: + f = pts + filtered.append(f.astype(np.float32, copy=False)) + return filtered + + +import time + +from rai.communication.ros2 import ROS2Context + +ROS2Context() + + +def main(): + from rai.communication.ros2.connectors import ROS2Connector + + connector = ROS2Connector() + connector.node.declare_parameter("conversion_ratio", 1.0) + time.sleep(5) + est = GrippingPointEstimator( + strategy="biggest_plane", ransac_iterations=400, distance_threshold_m=0.008 + ) + + pc_gen = PointCloudFromSegmentation( + connector=connector, + camera_topic="/rgbd_camera/camera_image_color", + depth_topic="/rgbd_camera/camera_image_depth", + camera_info_topic="/rgbd_camera/camera_info", + source_frame="egofront_rgbd_camera_depth_optical_frame", + target_frame="egoarm_base_link", + ) + points_xyz = pc_gen.run( + object_name="box" + ) # ndarray of shape (N, 3) in target frame + print(points_xyz) + filt = PointCloudFilter(strategy="dbscan", dbscan_eps=0.02, dbscan_min_samples=10) + points_xyz = filt.run(points_xyz) # same list shape, each np.float32 (N,3) + grip_points = est.run(points_xyz) + + print(grip_points) + from geometry_msgs.msg import Point32 + from sensor_msgs.msg import PointCloud + + points = ( + np.concatenate(points_xyz, axis=0) + if points_xyz + else np.zeros((0, 3), dtype=np.float32) + ) + + msg = PointCloud() # type: ignore[reportUnknownArgumentType] + msg.header.frame_id = "egoarm_base_link" # type: ignore[reportUnknownMemberType] + msg.points = [Point32(x=float(p[0]), y=float(p[1]), z=float(p[2])) for p in points] # type: ignore[reportUnknownArgumentType] + pub = connector.node.create_publisher( # type: ignore[reportUnknownMemberType] + PointCloud, "/debug/get_grabbing_point_pointcloud", 10 + ) + from geometry_msgs.msg import Point, Pose, Vector3 + from std_msgs.msg import Header + from visualization_msgs.msg import Marker, MarkerArray + + marker_pub = connector.node.create_publisher( # type: ignore[reportUnknownMemberType] + MarkerArray, "/debug/get_grabbing_point_marker_array", 10 + ) + marker_array = MarkerArray() + header = Header() + header.frame_id = "egoarm_base_link" + # header.stamp = connector.node.get_clock().now().to_msg() + markers = [] + for i, p in enumerate(grip_points): + m = Marker() + m.header = header + m.type = Marker.SPHERE + m.action = Marker.ADD + m.pose = Pose(position=Point(x=float(p[0]), y=float(p[1]), z=float(p[2]))) + m.scale = Vector3(x=0.04, y=0.04, z=0.04) + 
m.id = i + m.color.r = 1.0 # type: ignore[reportUnknownMemberType] + m.color.g = 0.0 # type: ignore[reportUnknownMemberType] + m.color.b = 0.0 # type: ignore[reportUnknownMemberType] + m.color.a = 1.0 # type: ignore[reportUnknownMemberType] + + # m.ns = str(i) + + markers.append(m) # type: ignore[reportUnknownArgumentType] + marker_array.markers = markers + + while True: + connector.node.get_logger().info( # type: ignore[reportUnknownMemberType] + f"publishing pointcloud to /debug/get_grabbing_point_pointcloud: {len(msg.points)} points, mean: {np.array(points.mean(axis=0))}." + ) + + marker_pub.publish(marker_array) + pub.publish(msg) + time.sleep(0.1) + + +if __name__ == "__main__": + main() diff --git a/xd.py b/xd.py new file mode 100644 index 000000000..fec2dba68 --- /dev/null +++ b/xd.py @@ -0,0 +1,31 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from rai.agents import wait_for_shutdown +from rai.communication.ros2 import ROS2Context +from rai_open_set_vision.agents import GroundedSamAgent, GroundingDinoAgent + + +@ROS2Context() +def main(): + agent1 = GroundingDinoAgent() + agent2 = GroundedSamAgent() + agent1.run() + agent2.run() + wait_for_shutdown([agent1, agent2]) + + +if __name__ == "__main__": + main() From 910004032509308e3baec45cc7b87c8a5c8077d8 Mon Sep 17 00:00:00 2001 From: Maciej Majek Date: Tue, 2 Sep 2025 22:05:22 +0200 Subject: [PATCH 03/13] feat: GetGrippingPointTool --- .../rai/tools/ros2/detection/__init__.py | 21 +++++++ .../rai/tools/ros2/detection/tools.py | 63 +++++++++++++++++++ .../rai/tools/ros2/manipulation/custom.py | 2 + 3 files changed, 86 insertions(+) create mode 100644 src/rai_core/rai/tools/ros2/detection/__init__.py create mode 100644 src/rai_core/rai/tools/ros2/detection/tools.py diff --git a/src/rai_core/rai/tools/ros2/detection/__init__.py b/src/rai_core/rai/tools/ros2/detection/__init__.py new file mode 100644 index 000000000..b168382fa --- /dev/null +++ b/src/rai_core/rai/tools/ros2/detection/__init__.py @@ -0,0 +1,21 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .tools import ( + GetGrippingPointTool, +) + +__all__ = [ + "GetGrippingPointTool", +] diff --git a/src/rai_core/rai/tools/ros2/detection/tools.py b/src/rai_core/rai/tools/ros2/detection/tools.py new file mode 100644 index 000000000..f64a81190 --- /dev/null +++ b/src/rai_core/rai/tools/ros2/detection/tools.py @@ -0,0 +1,63 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Type + +from pydantic import BaseModel, Field + +from rai.tools.ros2.base import BaseROS2Tool +from rai.tools.ros2.detection.pcl import ( + GrippingPointEstimator, + PointCloudFilter, + PointCloudFromSegmentation, +) + + +class GetGrippingPointToolInput(BaseModel): + object_name: str = Field( + ..., + description="The name of the object to get the gripping point of e.g. 'box', 'apple', 'screwdriver'", + ) + + +# TODO(maciejmajek): Configuration system configurable with namespacing +class GetGrippingPointTool(BaseROS2Tool): + name: str = "get_gripping_point" + description: str = "Get gripping points for specified object/objects. Returns 3D coordinates where a robot gripper can grasp the object." + + point_cloud_from_segmentation: PointCloudFromSegmentation + gripping_point_estimator: GrippingPointEstimator + point_cloud_filter: PointCloudFilter + + args_schema: Type[GetGrippingPointToolInput] = GetGrippingPointToolInput + + def _run(self, object_name: str) -> str: + pcl = self.point_cloud_from_segmentation.run(object_name) + pcl = self.point_cloud_filter.run(pcl) + gps = self.gripping_point_estimator.run(pcl) + + message = "" + if len(gps) == 0: + message += f"No gripping point found for the object {object_name}\n" + elif len(gps) == 1: + message += f"The gripping point of the object {object_name} is {gps[0]}\n" + else: + message += f"Multiple gripping points found for the object {object_name}\n" + + for i, gp in enumerate(gps): + message += ( + f"The gripping point of the object {i + 1} {object_name} is {gp}\n" + ) + + return message diff --git a/src/rai_core/rai/tools/ros2/manipulation/custom.py b/src/rai_core/rai/tools/ros2/manipulation/custom.py index 6e0d9655c..b2d56e8bb 100644 --- a/src/rai_core/rai/tools/ros2/manipulation/custom.py +++ b/src/rai_core/rai/tools/ros2/manipulation/custom.py @@ -16,6 +16,7 @@ from typing import Literal, Type import numpy as np +from deprecated import deprecated from geometry_msgs.msg import Point, Pose, PoseStamped, Quaternion from pydantic import BaseModel, Field from tf2_geometry_msgs import do_transform_pose @@ -259,6 +260,7 @@ class GetObjectPositionsToolInput(BaseModel): ) +@deprecated("Use GetGrippingPointTool from rai.tools.ros2.detection instead") class GetObjectPositionsTool(BaseROS2Tool): name: str = "get_object_positions" description: str = ( From 54287ce8b47c7013fbf2c361ed313572364587c3 Mon Sep 17 00:00:00 2001 From: Maciej Majek Date: Tue, 2 Sep 2025 22:15:35 +0200 Subject: [PATCH 04/13] feat: timeout --- src/rai_core/rai/__init__.py | 2 + src/rai_core/rai/tools/__init__.py | 2 + 
.../rai/tools/ros2/detection/tools.py | 57 ++++--- src/rai_core/rai/tools/timeout.py | 139 ++++++++++++++++++ 4 files changed, 182 insertions(+), 18 deletions(-) create mode 100644 src/rai_core/rai/tools/timeout.py diff --git a/src/rai_core/rai/__init__.py b/src/rai_core/rai/__init__.py index b0d27851a..426515e4a 100644 --- a/src/rai_core/rai/__init__.py +++ b/src/rai_core/rai/__init__.py @@ -20,6 +20,7 @@ get_llm_model_direct, get_tracing_callbacks, ) +from .utils import timeout __all__ = [ "AgentRunner", @@ -29,4 +30,5 @@ "get_llm_model_config_and_vendor", "get_llm_model_direct", "get_tracing_callbacks", + "timeout", ] diff --git a/src/rai_core/rai/tools/__init__.py b/src/rai_core/rai/tools/__init__.py index ef74fc891..4b58b900b 100644 --- a/src/rai_core/rai/tools/__init__.py +++ b/src/rai_core/rai/tools/__init__.py @@ -11,3 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from .timeout import timeout, timeout_method diff --git a/src/rai_core/rai/tools/ros2/detection/tools.py b/src/rai_core/rai/tools/ros2/detection/tools.py index f64a81190..2ab49f895 100644 --- a/src/rai_core/rai/tools/ros2/detection/tools.py +++ b/src/rai_core/rai/tools/ros2/detection/tools.py @@ -16,6 +16,7 @@ from pydantic import BaseModel, Field +from rai.tools import timeout from rai.tools.ros2.base import BaseROS2Tool from rai.tools.ros2.detection.pcl import ( GrippingPointEstimator, @@ -40,24 +41,44 @@ class GetGrippingPointTool(BaseROS2Tool): gripping_point_estimator: GrippingPointEstimator point_cloud_filter: PointCloudFilter + timeout_sec: float = Field( + default=10.0, description="Timeout in seconds to get the gripping point" + ) + args_schema: Type[GetGrippingPointToolInput] = GetGrippingPointToolInput def _run(self, object_name: str) -> str: - pcl = self.point_cloud_from_segmentation.run(object_name) - pcl = self.point_cloud_filter.run(pcl) - gps = self.gripping_point_estimator.run(pcl) - - message = "" - if len(gps) == 0: - message += f"No gripping point found for the object {object_name}\n" - elif len(gps) == 1: - message += f"The gripping point of the object {object_name} is {gps[0]}\n" - else: - message += f"Multiple gripping points found for the object {object_name}\n" - - for i, gp in enumerate(gps): - message += ( - f"The gripping point of the object {i + 1} {object_name} is {gp}\n" - ) - - return message + @timeout( + self.timeout_sec, + f"Gripping point detection for object '{object_name}' exceeded {self.timeout_sec} seconds", + ) + def _run_with_timeout(): + pcl = self.point_cloud_from_segmentation.run(object_name) + pcl = self.point_cloud_filter.run(pcl) + gps = self.gripping_point_estimator.run(pcl) + + message = "" + if len(gps) == 0: + message += f"No gripping point found for the object {object_name}\n" + elif len(gps) == 1: + message += ( + f"The gripping point of the object {object_name} is {gps[0]}\n" + ) + else: + message += ( + f"Multiple gripping points found for the object {object_name}\n" + ) + + for i, gp in enumerate(gps): + message += ( + f"The gripping point of the object {i + 1} {object_name} is {gp}\n" + ) + + return message + + try: + return _run_with_timeout() + except Exception as e: + if "timed out" in str(e).lower(): + return f"Timeout: Gripping point detection for object '{object_name}' exceeded {self.timeout_sec} seconds" + raise diff --git a/src/rai_core/rai/tools/timeout.py b/src/rai_core/rai/tools/timeout.py new file mode 100644 
index 000000000..3f660c7e9
--- /dev/null
+++ b/src/rai_core/rai/tools/timeout.py
@@ -0,0 +1,139 @@
+# Copyright (C) 2025 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import signal
+from functools import wraps
+from typing import Any, Callable, Optional, TypeVar
+
+F = TypeVar("F", bound=Callable[..., Any])
+
+
+class TimeoutError(Exception):
+    """Raised when an operation times out."""
+
+    pass
+
+
+def timeout(seconds: float, timeout_message: Optional[str] = None) -> Callable[[F], F]:
+    """
+    Decorator that adds timeout functionality to a function.
+
+    Parameters
+    ----------
+    seconds : float
+        Timeout duration in seconds
+    timeout_message : str, optional
+        Custom timeout message. If not provided, a default message will be used.
+
+    Returns
+    -------
+    Callable
+        Decorated function with timeout functionality
+
+    Raises
+    ------
+    TimeoutError
+        When the decorated function exceeds the specified timeout
+
+    Examples
+    --------
+    >>> @timeout(5.0, "Operation timed out")
+    ... def slow_operation():
+    ...     import time
+    ...     time.sleep(10)
+    ...     return "Done"
+    >>>
+    >>> try:
+    ...     result = slow_operation()
+    ... except TimeoutError as e:
+    ...     print(f"Timeout: {e}")
+    """
+
+    def decorator(func: F) -> F:
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            def timeout_handler(signum, frame):
+                message = (
+                    timeout_message
+                    or f"Function '{func.__name__}' timed out after {seconds} seconds"
+                )
+                raise TimeoutError(message)
+
+            # Set up timeout (note: signal.alarm has whole-second resolution,
+            # so fractional seconds are truncated)
+            old_handler = signal.signal(signal.SIGALRM, timeout_handler)
+            signal.alarm(int(seconds))
+
+            try:
+                return func(*args, **kwargs)
+            finally:
+                # Clean up timeout
+                signal.alarm(0)
+                signal.signal(signal.SIGALRM, old_handler)
+
+        return wrapper
+
+    return decorator
+
+
+def timeout_method(seconds: float, timeout_message: Optional[str] = None) -> Callable[[F], F]:
+    """
+    Decorator that adds timeout functionality to a method.
+    Similar to timeout but designed for class methods.
+
+    Parameters
+    ----------
+    seconds : float
+        Timeout duration in seconds
+    timeout_message : str, optional
+        Custom timeout message. If not provided, a default message will be used.
+
+    Returns
+    -------
+    Callable
+        Decorated method with timeout functionality
+
+    Examples
+    --------
+    >>> class MyClass:
+    ...     @timeout_method(3.0, "Method timed out")
+    ...     def slow_method(self):
+    ...         import time
+    ...         time.sleep(5)
+    ...         return "Done"
+    """
+
+    def decorator(func: F) -> F:
+        @wraps(func)
+        def wrapper(self, *args, **kwargs):
+            def timeout_handler(signum, frame):
+                message = (
+                    timeout_message
+                    or f"Method '{func.__name__}' of {self.__class__.__name__} timed out after {seconds} seconds"
+                )
+                raise TimeoutError(message)
+
+            # Set up timeout
+            old_handler = signal.signal(signal.SIGALRM, timeout_handler)
+            signal.alarm(int(seconds))
+
+            try:
+                return func(self, *args, **kwargs)
+            finally:
+                # Clean up timeout
+                signal.alarm(0)
+                signal.signal(signal.SIGALRM, old_handler)
+
+        return wrapper
+
+    return decorator
From 1ec5e982bd38e28619ebf08f8afb56296b0e1f3e Mon Sep 17 00:00:00 2001
From: Julia Jia
Date: Sat, 20 Sep 2025 15:19:45 -0700
Subject: [PATCH 05/13] Refactor GetGrippingPointTool and introduce unit tests
 and manual tests

---
 pyproject.toml                                |   3 +-
 src/rai_core/rai/__init__.py                  |   2 +-
 src/rai_core/rai/tools/ros2/detection/pcl.py  | 165 ++++-----
 .../rai/tools/ros2/detection/tools.py         |  37 +-
 .../tools/segmentation_tools.py               |  35 +-
 tests/tools/ros2/test_detection_tools.py      | 218 +++++++++++
 tests/tools/ros2/test_gripping_points.py      | 346 ++++++++++++++++++
 7 files changed, 687 insertions(+), 119 deletions(-)
 create mode 100644 tests/tools/ros2/test_detection_tools.py
 create mode 100644 tests/tools/ros2/test_gripping_points.py

diff --git a/pyproject.toml b/pyproject.toml
index 5ba7eb2ed..2673f89a0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -71,7 +71,8 @@ build-backend = "poetry.core.masonry.api"
 markers = [
     "billable: marks test as billable (deselect with '-m \"not billable\"')",
     "ci_only: marks test as cli only (deselect with '-m \"not ci_only\"')",
+    "manual: marks tests as manual (may require demo app to be running)",
 ]
-addopts = "-m 'not billable and not ci_only' --ignore=src"
+addopts = "-m 'not billable and not ci_only and not manual' --ignore=src"
 log_cli = true
 log_cli_level = "INFO"
diff --git a/src/rai_core/rai/__init__.py b/src/rai_core/rai/__init__.py
index 426515e4a..931aed9ec 100644
--- a/src/rai_core/rai/__init__.py
+++ b/src/rai_core/rai/__init__.py
@@ -20,7 +20,7 @@
     get_llm_model_direct,
     get_tracing_callbacks,
 )
-from .utils import timeout
+from .tools import timeout
 
 __all__ = [
     "AgentRunner",
diff --git a/src/rai_core/rai/tools/ros2/detection/pcl.py b/src/rai_core/rai/tools/ros2/detection/pcl.py
index c9f59d16f..10dd2884b 100644
--- a/src/rai_core/rai/tools/ros2/detection/pcl.py
+++ b/src/rai_core/rai/tools/ros2/detection/pcl.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import time
 from typing import List, Literal, Optional, cast
 
 import numpy as np
@@ -46,6 +47,77 @@ def depth_to_point_cloud(
     return points.astype(np.float32, copy=False)
 
 
+def _publish_gripping_point_debug_data(
+    connector: ROS2Connector,
+    obj_points_xyz: NDArray[np.float32],
+    gripping_points_xyz: list[NDArray[np.float32]],
+    base_frame_id: str = "egoarm_base_link",
+    publish_duration: float = 10.0,
+) -> None:
+    """Publish the gripping point debug data for visualization in RViz via point cloud and marker array.
+
+    Args:
+        connector: The ROS2 connector.
+        obj_points_xyz: The point clouds of the objects found in the image.
+        gripping_points_xyz: The list of gripping points in the base frame.
+        base_frame_id: The base frame id.
+        publish_duration: Duration in seconds to publish the data (default: 10.0).
+ """ + + from geometry_msgs.msg import Point, Point32, Pose, Vector3 + from sensor_msgs.msg import PointCloud + from std_msgs.msg import Header + from visualization_msgs.msg import Marker, MarkerArray + + points = ( + np.concatenate(obj_points_xyz, axis=0) + if obj_points_xyz + else np.zeros((0, 3), dtype=np.float32) + ) + + msg = PointCloud() # type: ignore[reportUnknownArgumentType] + msg.header.frame_id = base_frame_id # type: ignore[reportUnknownMemberType] + msg.points = [Point32(x=float(p[0]), y=float(p[1]), z=float(p[2])) for p in points] # type: ignore[reportUnknownArgumentType] + pub = connector.node.create_publisher( # type: ignore[reportUnknownMemberType] + PointCloud, "/debug_gripping_points_pointcloud", 10 + ) + + marker_pub = connector.node.create_publisher( # type: ignore[reportUnknownMemberType] + MarkerArray, "/debug_gripping_points_markerarray", 10 + ) + marker_array = MarkerArray() + header = Header() + header.frame_id = base_frame_id + header.stamp = connector.node.get_clock().now().to_msg() + markers = [] + for i, p in enumerate(gripping_points_xyz): + m = Marker() + m.header = header + m.type = Marker.SPHERE + m.action = Marker.ADD + m.pose = Pose(position=Point(x=float(p[0]), y=float(p[1]), z=float(p[2]))) + m.scale = Vector3(x=0.04, y=0.04, z=0.04) + m.id = i + m.color.r = 1.0 # type: ignore[reportUnknownMemberType] + m.color.g = 0.0 # type: ignore[reportUnknownMemberType] + m.color.b = 0.0 # type: ignore[reportUnknownMemberType] + m.color.a = 1.0 # type: ignore[reportUnknownMemberType] + + # m.ns = str(i) + + markers.append(m) # type: ignore[reportUnknownArgumentType] + marker_array.markers = markers + + start_time = time.time() + publish_rate = 10.0 # Hz + sleep_duration = 1.0 / publish_rate + + while time.time() - start_time < publish_duration: + marker_pub.publish(marker_array) + pub.publish(msg) + time.sleep(sleep_duration) + + class PointCloudFromSegmentation: """Generate a masked point cloud for an object and transform it to a target frame. 
@@ -511,96 +583,3 @@ def run( f = pts filtered.append(f.astype(np.float32, copy=False)) return filtered - - -import time - -from rai.communication.ros2 import ROS2Context - -ROS2Context() - - -def main(): - from rai.communication.ros2.connectors import ROS2Connector - - connector = ROS2Connector() - connector.node.declare_parameter("conversion_ratio", 1.0) - time.sleep(5) - est = GrippingPointEstimator( - strategy="biggest_plane", ransac_iterations=400, distance_threshold_m=0.008 - ) - - pc_gen = PointCloudFromSegmentation( - connector=connector, - camera_topic="/rgbd_camera/camera_image_color", - depth_topic="/rgbd_camera/camera_image_depth", - camera_info_topic="/rgbd_camera/camera_info", - source_frame="egofront_rgbd_camera_depth_optical_frame", - target_frame="egoarm_base_link", - ) - points_xyz = pc_gen.run( - object_name="box" - ) # ndarray of shape (N, 3) in target frame - print(points_xyz) - filt = PointCloudFilter(strategy="dbscan", dbscan_eps=0.02, dbscan_min_samples=10) - points_xyz = filt.run(points_xyz) # same list shape, each np.float32 (N,3) - grip_points = est.run(points_xyz) - - print(grip_points) - from geometry_msgs.msg import Point32 - from sensor_msgs.msg import PointCloud - - points = ( - np.concatenate(points_xyz, axis=0) - if points_xyz - else np.zeros((0, 3), dtype=np.float32) - ) - - msg = PointCloud() # type: ignore[reportUnknownArgumentType] - msg.header.frame_id = "egoarm_base_link" # type: ignore[reportUnknownMemberType] - msg.points = [Point32(x=float(p[0]), y=float(p[1]), z=float(p[2])) for p in points] # type: ignore[reportUnknownArgumentType] - pub = connector.node.create_publisher( # type: ignore[reportUnknownMemberType] - PointCloud, "/debug/get_grabbing_point_pointcloud", 10 - ) - from geometry_msgs.msg import Point, Pose, Vector3 - from std_msgs.msg import Header - from visualization_msgs.msg import Marker, MarkerArray - - marker_pub = connector.node.create_publisher( # type: ignore[reportUnknownMemberType] - MarkerArray, "/debug/get_grabbing_point_marker_array", 10 - ) - marker_array = MarkerArray() - header = Header() - header.frame_id = "egoarm_base_link" - # header.stamp = connector.node.get_clock().now().to_msg() - markers = [] - for i, p in enumerate(grip_points): - m = Marker() - m.header = header - m.type = Marker.SPHERE - m.action = Marker.ADD - m.pose = Pose(position=Point(x=float(p[0]), y=float(p[1]), z=float(p[2]))) - m.scale = Vector3(x=0.04, y=0.04, z=0.04) - m.id = i - m.color.r = 1.0 # type: ignore[reportUnknownMemberType] - m.color.g = 0.0 # type: ignore[reportUnknownMemberType] - m.color.b = 0.0 # type: ignore[reportUnknownMemberType] - m.color.a = 1.0 # type: ignore[reportUnknownMemberType] - - # m.ns = str(i) - - markers.append(m) # type: ignore[reportUnknownArgumentType] - marker_array.markers = markers - - while True: - connector.node.get_logger().info( # type: ignore[reportUnknownMemberType] - f"publishing pointcloud to /debug/get_grabbing_point_pointcloud: {len(msg.points)} points, mean: {np.array(points.mean(axis=0))}." - ) - - marker_pub.publish(marker_array) - pub.publish(msg) - time.sleep(0.1) - - -if __name__ == "__main__": - main() diff --git a/src/rai_core/rai/tools/ros2/detection/tools.py b/src/rai_core/rai/tools/ros2/detection/tools.py index 2ab49f895..c35b0672c 100644 --- a/src/rai_core/rai/tools/ros2/detection/tools.py +++ b/src/rai_core/rai/tools/ros2/detection/tools.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Type
+from typing import Any, Optional, Type
 
 from pydantic import BaseModel, Field
 
-from rai.tools import timeout
 from rai.tools.ros2.base import BaseROS2Tool
 from rai.tools.ros2.detection.pcl import (
     GrippingPointEstimator,
@@ -33,27 +32,51 @@ class GetGrippingPointToolInput(BaseModel):
 
 
 # TODO(maciejmajek): Configuration system configurable with namespacing
+# TODO(juliajia): Question for Maciej: regarding the comment above on a configuration system with namespacing, can you provide a use case for this?
 class GetGrippingPointTool(BaseROS2Tool):
     name: str = "get_gripping_point"
     description: str = "Get gripping points for specified object/objects. Returns 3D coordinates where a robot gripper can grasp the object."
 
-    point_cloud_from_segmentation: PointCloudFromSegmentation
+    target_frame: str
+    source_frame: str
+    camera_topic: str  # rgb camera topic
+    depth_topic: str
+    camera_info_topic: str  # rgb camera info topic
+
     gripping_point_estimator: GrippingPointEstimator
     point_cloud_filter: PointCloudFilter
 
+    # Auto-initialized in model_post_init
+    point_cloud_from_segmentation: Optional[PointCloudFromSegmentation] = None
+
     timeout_sec: float = Field(
         default=10.0, description="Timeout in seconds to get the gripping point"
     )
 
     args_schema: Type[GetGrippingPointToolInput] = GetGrippingPointToolInput
 
-    def _run(self, object_name: str) -> str:
-        @timeout(
-            self.timeout_sec,
-            f"Gripping point detection for object '{object_name}' exceeded {self.timeout_sec} seconds",
+    def model_post_init(self, __context: Any) -> None:
+        """Initialize PointCloudFromSegmentation with the provided camera parameters."""
+        self.point_cloud_from_segmentation = PointCloudFromSegmentation(
+            connector=self.connector,
+            camera_topic=self.camera_topic,
+            depth_topic=self.depth_topic,
+            camera_info_topic=self.camera_info_topic,
+            source_frame=self.source_frame,
+            target_frame=self.target_frame,
         )
+
+    def _run(self, object_name: str) -> str:
+        # This will not work in the agent scenario because signal-based timeouts
+        # must be armed in the main thread; commented out for now.
+        # @timeout(
+        #     self.timeout_sec,
+        #     f"Gripping point detection for object '{object_name}' exceeded {self.timeout_sec} seconds",
+        # )
         def _run_with_timeout():
             pcl = self.point_cloud_from_segmentation.run(object_name)
+            if len(pcl) == 0:
+                return f"No {object_name}s detected."
+
             pcl = self.point_cloud_filter.run(pcl)
             gps = self.gripping_point_estimator.run(pcl)
diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py
index a395d986c..b168943c0 100644
--- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py
+++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py
@@ -293,23 +293,24 @@ def _process_mask(
         points = pcd
         # publish resulting pointcloud
-        import time
-
-        from geometry_msgs.msg import Point32
-        from sensor_msgs.msg import PointCloud
-
-        msg = PointCloud()
-        msg.header.frame_id = "egofront_rgbd_camera_depth_optical_frame"
-        msg.points = [Point32(x=p[0], y=p[1], z=p[2]) for p in points]
-        pub = self.connector.node.create_publisher(
-            PointCloud, "/debug/get_grabbing_point_pointcloud", 10
-        )
-        while True:
-            self.connector.node.get_logger().info(
-                f"publishing pointcloud to /debug/get_grabbing_point_pointcloud: {len(msg.points)} points, mean: {np.array(points.mean(axis=0))}."
- ) - pub.publish(msg) - time.sleep(0.1) + # TODO(juliajia): remove this after debugging + # import time + + # from geometry_msgs.msg import Point32 + # from sensor_msgs.msg import PointCloud + + # msg = PointCloud() + # msg.header.frame_id = "egofront_rgbd_camera_depth_optical_frame" + # msg.points = [Point32(x=p[0], y=p[1], z=p[2]) for p in points] + # pub = self.connector.node.create_publisher( + # PointCloud, "/debug/get_grabbing_point_pointcloud", 10 + # ) + # while True: + # self.connector.node.get_logger().info( + # f"publishing pointcloud to /debug/get_grabbing_point_pointcloud: {len(msg.points)} points, mean: {np.array(points.mean(axis=0))}." + # ) + # pub.publish(msg) + # time.sleep(0.1) # https://github.com/ycheng517/tabletop-handybot/blob/6d401e577e41ea86529d091b406fbfc936f37a8d/tabletop_handybot/tabletop_handybot/tabletop_handybot_node.py#L413-L424 grasp_z = points[:, 2].max() diff --git a/tests/tools/ros2/test_detection_tools.py b/tests/tools/ros2/test_detection_tools.py new file mode 100644 index 000000000..b2ecb6e44 --- /dev/null +++ b/tests/tools/ros2/test_detection_tools.py @@ -0,0 +1,218 @@ +# Copyright (C) 2025 Julia Jia +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +try: + import rclpy # noqa: F401 + + _ = rclpy # noqa: F841 +except ImportError: + pytest.skip("ROS2 is not installed", allow_module_level=True) + +from unittest.mock import Mock + +import numpy as np +from rai.communication.ros2.connectors import ROS2Connector +from rai.tools.ros2.detection import GetGrippingPointTool +from rai.tools.ros2.detection.pcl import ( + GrippingPointEstimator, + PointCloudFilter, + PointCloudFromSegmentation, + depth_to_point_cloud, +) + + +def test_depth_to_point_cloud(): + """Test depth image to point cloud conversion algorithm.""" + # Create a simple 2x2 depth image with known values + depth_image = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) + + # Camera intrinsics + fx, fy, cx, cy = 100.0, 100.0, 1.0, 1.0 + + # Convert to point cloud + points = depth_to_point_cloud(depth_image, fx, fy, cx, cy) + + # Should have 4 points (2x2 image) + assert points.shape[0] == 4 + assert points.shape[1] == 3 # X, Y, Z coordinates + + # Check that all Z values match the depth image + expected_z_values = [1.0, 2.0, 3.0, 4.0] + actual_z_values = sorted(points[:, 2].tolist()) + np.testing.assert_array_almost_equal(actual_z_values, expected_z_values) + + # Verify no points with zero depth are included + zero_depth = np.zeros((2, 2), dtype=np.float32) + points_zero = depth_to_point_cloud(zero_depth, fx, fy, cx, cy) + assert points_zero.shape[0] == 0 + + +def test_gripping_point_estimator(): + """Test gripping point estimation strategies.""" + # Create test point cloud data - a simple box shape + points1 = np.array( + [ + [1.0, 1.0, 1.0], + [1.0, 1.0, 2.0], + [2.0, 1.0, 1.0], + [2.0, 1.0, 2.0], + [1.0, 2.0, 1.0], + [1.0, 2.0, 2.0], + [2.0, 2.0, 1.0], + [2.0, 2.0, 2.0], + ], + dtype=np.float32, + ) + + points2 = np.array( + [ + [5.0, 5.0, 5.0], + [5.0, 5.0, 6.0], + [6.0, 
5.0, 5.0], + [6.0, 5.0, 6.0], + ], + dtype=np.float32, + ) + + segmented_clouds = [points1, points2] + + # Test centroid strategy + estimator = GrippingPointEstimator(strategy="centroid") + grip_points = estimator.run(segmented_clouds) + + assert len(grip_points) == 2 + # Check centroid of first cloud + expected_centroid1 = np.array([1.5, 1.5, 1.5], dtype=np.float32) + np.testing.assert_array_almost_equal(grip_points[0], expected_centroid1) + + # Test top_plane strategy + estimator_top = GrippingPointEstimator(strategy="top_plane", top_percentile=0.5) + grip_points_top = estimator_top.run(segmented_clouds) + + assert len(grip_points_top) == 2 + # Top plane should have higher Z values + assert grip_points_top[0][2] >= grip_points[0][2] + + # Test with empty point cloud + empty_clouds = [np.array([]).reshape(0, 3).astype(np.float32)] + grip_points_empty = estimator.run(empty_clouds) + assert len(grip_points_empty) == 0 + + +def test_point_cloud_filter(): + """Test point cloud filtering strategies.""" + # Create test data with noise points + main_cluster = np.random.normal([0, 0, 0], 0.1, (50, 3)).astype(np.float32) + noise_points = np.random.normal([5, 5, 5], 0.1, (5, 3)).astype(np.float32) + noisy_cloud = np.vstack([main_cluster, noise_points]) + + clouds = [noisy_cloud] + + # Test DBSCAN filtering + filter_dbscan = PointCloudFilter( + strategy="dbscan", dbscan_eps=0.5, dbscan_min_samples=5 + ) + filtered_dbscan = filter_dbscan.run(clouds) + + assert len(filtered_dbscan) == 1 + # Should remove most noise points + assert filtered_dbscan[0].shape[0] < noisy_cloud.shape[0] + assert filtered_dbscan[0].shape[0] >= 40 # Should keep most of main cluster + + # Test with too few points (should return original) + small_cloud = np.array([[1, 1, 1], [2, 2, 2]], dtype=np.float32) + filter_small = PointCloudFilter(strategy="dbscan", min_points=20) + filtered_small = filter_small.run([small_cloud]) + + assert len(filtered_small) == 1 + np.testing.assert_array_equal(filtered_small[0], small_cloud) + + # Test kmeans_largest_cluster strategy + filter_kmeans = PointCloudFilter(strategy="kmeans_largest_cluster", kmeans_k=2) + filtered_kmeans = filter_kmeans.run(clouds) + + assert len(filtered_kmeans) == 1 + assert filtered_kmeans[0].shape[0] > 0 + + +def test_get_gripping_point_tool_timeout(): + """Test GetGrippingPointTool timeout behavior.""" + # Mock the connector and components + mock_connector = Mock(spec=ROS2Connector) + + # Create mock components that will simulate timeout + mock_pcl_gen = Mock(spec=PointCloudFromSegmentation) + mock_pcl_gen.run.side_effect = lambda x: [] # Return empty to simulate no detection + + mock_filter = Mock(spec=PointCloudFilter) + mock_filter.run.return_value = [] + + mock_estimator = Mock(spec=GrippingPointEstimator) + mock_estimator.run.return_value = [] + + # Create tool with short timeout + tool = GetGrippingPointTool( + connector=mock_connector, + point_cloud_from_segmentation=mock_pcl_gen, + point_cloud_filter=mock_filter, + gripping_point_estimator=mock_estimator, + timeout_sec=0.1, + ) + + # Test successful run with no gripping points found + result = tool._run("test_object") + assert "No gripping point found" in result + assert "test_object" in result + + # Test with mock that simulates found gripping points + mock_estimator.run.return_value = [np.array([1.0, 2.0, 3.0], dtype=np.float32)] + result = tool._run("test_object") + assert "gripping point of the object test_object is" in result + assert "[1. 2. 
3.]" in result
+
+    # Test with multiple gripping points
+    mock_estimator.run.return_value = [
+        np.array([1.0, 2.0, 3.0], dtype=np.float32),
+        np.array([4.0, 5.0, 6.0], dtype=np.float32),
+    ]
+    result = tool._run("test_object")
+    assert "Multiple gripping points found" in result
+
+
+def test_get_gripping_point_tool_validation():
+    """Test GetGrippingPointTool input validation."""
+    mock_connector = Mock(spec=ROS2Connector)
+    mock_pcl_gen = Mock(spec=PointCloudFromSegmentation)
+    mock_filter = Mock(spec=PointCloudFilter)
+    mock_estimator = Mock(spec=GrippingPointEstimator)
+
+    # Test tool creation
+    tool = GetGrippingPointTool(
+        connector=mock_connector,
+        point_cloud_from_segmentation=mock_pcl_gen,
+        point_cloud_filter=mock_filter,
+        gripping_point_estimator=mock_estimator,
+    )
+
+    # Verify tool properties
+    assert tool.name == "get_gripping_point"
+    assert "gripping points" in tool.description
+    assert tool.timeout_sec == 10.0  # default value
+
+    # Test args schema
+    from rai.tools.ros2.detection.tools import GetGrippingPointToolInput
+
+    assert tool.args_schema == GetGrippingPointToolInput
diff --git a/tests/tools/ros2/test_gripping_points.py b/tests/tools/ros2/test_gripping_points.py
new file mode 100644
index 000000000..64c959d5d
--- /dev/null
+++ b/tests/tools/ros2/test_gripping_points.py
@@ -0,0 +1,346 @@
+# Copyright (C) 2025 Julia Jia
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+#!/usr/bin/env python3
+"""
+Manual test for GetGrippingPointTool with various demo scenarios. Each test:
+- Finds gripping points of the specified object in the target frame.
+- Publishes debug data for visualization.
+- Saves an annotated image of the gripping points.
+
+The demo app and rviz2 need to be started before running the test. The test will fail if the gripping points are not found.
+ +Usage: +pytest tests/tools/ros2/test_gripping_points.py::test_gripping_points_manipulation_demo -m "" -s -v +""" + +import cv2 +import numpy as np +import pytest +import rclpy +from cv_bridge import CvBridge +from rai.communication.ros2 import wait_for_ros2_services, wait_for_ros2_topics +from rai.communication.ros2.connectors import ROS2Connector +from rai.tools.ros2.detection import GetGrippingPointTool +from rai.tools.ros2.detection.pcl import ( + GrippingPointEstimator, + PointCloudFilter, + _publish_gripping_point_debug_data, +) + +# Test configurations +TEST_CONFIGS = { + "manipulation-demo": { + "services": ["/grounded_sam_segment", "/grounding_dino_classify"], + "topics": { + "color_image": "/color_image5", + "depth_image": "/depth_image5", + "camera_info": "/color_camera_info5", + }, + "frames": {"target": "panda_link0", "source": "RGBDCamera5"}, + "algorithms": { + "filter": { + "strategy": "dbscan", + "dbscan_eps": 0.02, + "dbscan_min_samples": 5, + }, + "estimator": {"strategy": "centroid"}, + }, + }, + "maciej-test-demo": { + "services": ["/grounded_sam_segment", "/grounding_dino_classify"], + "topics": { + "color_image": "/rgbd_camera/camera_image_color", + "depth_image": "/rgbd_camera/camera_image_depth", + "camera_info": "/rgbd_camera/camera_info", + }, + "frames": { + "target": "egoarm_base_link", + "source": "egofront_rgbd_camera_depth_optical_frame", + }, + "algorithms": { + "filter": { + "strategy": "dbscan", + "dbscan_eps": 0.02, + "dbscan_min_samples": 10, + }, + "estimator": { + "strategy": "biggest_plane", + "ransac_iterations": 400, + "distance_threshold_m": 0.008, + }, + }, + }, + "dummy-example-with-default-algorithm-parameters": { + "services": ["/grounded_sam_segment", "/grounding_dino_classify"], + "topics": { + "color_image": "/color_image5", + "depth_image": "/depth_image5", + "camera_info": "/color_camera_info5", + }, + "frames": {"target": "panda_link0", "source": "RGBDCamera5"}, + "algorithms": { + "filter": { + "strategy": "dbscan", + "min_points": 100, + "dbscan_eps": 0.02, + "dbscan_min_samples": 5, + "kmeans_k": 3, + "if_max_samples": 100, + "if_contamination": 0.1, + "lof_n_neighbors": 20, + "lof_contamination": 0.1, + }, + "estimator": { + "strategy": "centroid", + "top_percentile": 0.8, + "plane_bin_size_m": 0.01, + "ransac_iterations": 100, + "distance_threshold_m": 0.01, + "min_points": 10, + }, + }, + }, +} + + +def draw_points_on_image(image_msg, points, camera_info): + """Draw points on the camera image.""" + # Convert ROS image to OpenCV + bridge = CvBridge() + cv_image = bridge.imgmsg_to_cv2(image_msg, "bgr8") + + # Get camera intrinsics + fx = camera_info.k[0] + fy = camera_info.k[4] + cx = camera_info.k[2] + cy = camera_info.k[5] + + # Project 3D points to 2D + for i, point in enumerate(points): + x, y, z = point[0], point[1], point[2] + + # Check if point is in front of camera + if z <= 0: + continue + + # Project to pixel coordinates + u = int((x * fx / z) + cx) + v = int((y * fy / z) + cy) + + # Check if point is within image bounds + if 0 <= u < cv_image.shape[1] and 0 <= v < cv_image.shape[0]: + # Draw circle and label + cv2.circle(cv_image, (u, v), 10, (0, 0, 255), -1) # Red filled circle + cv2.circle(cv_image, (u, v), 15, (0, 255, 0), 2) # Green outline + cv2.putText( + cv_image, + f"GP{i + 1}", + (u + 20, v - 10), + cv2.FONT_HERSHEY_SIMPLEX, + 0.6, + (255, 255, 255), + 2, + ) + + return cv_image + + +def extract_gripping_points(result: str) -> list[np.ndarray]: + """Extract gripping points from the result.""" + gripping_points = 
[]
+    lines = result.split("\n")
+    for line in lines:
+        if "gripping point" in line and "is [" in line:
+            # Extract coordinates from a line like "is [0.39972728 0.16179778 0.04179673]"
+            start = line.find("[") + 1
+            end = line.find("]")
+            if start > 0 and end > start:
+                coords_str = line[start:end]
+                coords = [float(x) for x in coords_str.split()]
+                gripping_points.append(np.array(coords))
+    return gripping_points
+
+
+def transform_points_to_target_frame(connector, points, source_frame, target_frame):
+    """Transform points from the source frame (e.g. camera frame) to the target frame (e.g. robot frame)."""
+    try:
+        # Get transform from target frame to source frame
+        transform = connector.get_transform(source_frame, target_frame)
+
+        # Extract translation and rotation
+        t = transform.transform.translation
+        r = transform.transform.rotation
+
+        # Convert quaternion to rotation matrix
+        qw, qx, qy, qz = float(r.w), float(r.x), float(r.y), float(r.z)
+
+        rotation_matrix = np.array(
+            [
+                [
+                    1 - 2 * (qy * qy + qz * qz),
+                    2 * (qx * qy - qw * qz),
+                    2 * (qx * qz + qw * qy),
+                ],
+                [
+                    2 * (qx * qy + qw * qz),
+                    1 - 2 * (qx * qx + qz * qz),
+                    2 * (qy * qz - qw * qx),
+                ],
+                [
+                    2 * (qx * qz - qw * qy),
+                    2 * (qy * qz + qw * qx),
+                    1 - 2 * (qx * qx + qy * qy),
+                ],
+            ]
+        )
+
+        translation = np.array([float(t.x), float(t.y), float(t.z)])
+
+        # Transform points: R * point + translation (forward transform)
+        transformed_points = []
+        for point in points:
+            # Apply forward transform: R * point + translation
+            transformed_point = rotation_matrix @ point + translation
+            transformed_points.append(transformed_point)
+
+        return transformed_points
+    except Exception as e:
+        print(f"Transform error: {e}")
+        return points
+
+
+def save_annotated_image(
+    connector, gripping_points, config, filename: str = "gripping_points_annotated.jpg"
+):
+    camera_frame_points = transform_points_to_target_frame(
+        connector,
+        gripping_points,
+        config["frames"]["source"],
+        config["frames"]["target"],
+    )
+
+    # Get current camera image and draw points
+    image_msg = connector.receive_message(config["topics"]["color_image"]).payload
+    camera_info_msg = connector.receive_message(config["topics"]["camera_info"]).payload
+
+    # Draw gripping points on image
+    annotated_image = draw_points_on_image(
+        image_msg, camera_frame_points, camera_info_msg
+    )
+
+    cv2.imwrite(filename, annotated_image)
+
+
+def main(config_name: str = "manipulation-demo", test_object: str = "cube"):
+    """Enhanced test with visualization and better error handling."""
+
+    # Get test configuration
+    config = TEST_CONFIGS[config_name]
+
+    print(f"Config: {config_name}")
+
+    # Initialize ROS2
+    rclpy.init()
+    connector = ROS2Connector(executor_type="single_threaded")
+
+    try:
+        # Wait for required services and topics
+        print("Waiting for ROS2 services and topics...")
+        wait_for_ros2_services(connector, config["services"])
+        wait_for_ros2_topics(connector, list(config["topics"].values()))
+        print("✅ All services and topics available")
+
+        # Set up node parameters
+        node = connector.node
+        node.declare_parameter("conversion_ratio", 1.0)
+
+        # Create tool components
+        algo_config = config["algorithms"]
+
+        # Create gripping estimator with strategy-specific parameters
+        estimator_config = algo_config["estimator"]
+        gripping_estimator = GrippingPointEstimator(**estimator_config)
+
+        # Create point cloud filter
+        filter_config = algo_config["filter"]
+        point_cloud_filter = PointCloudFilter(**filter_config)
+
+        # Create the tool
+
gripping_tool = GetGrippingPointTool( + connector=connector, + target_frame=config["frames"]["target"], + source_frame=config["frames"]["source"], + camera_topic=config["topics"]["color_image"], + depth_topic=config["topics"]["depth_image"], + camera_info_topic=config["topics"]["camera_info"], + gripping_point_estimator=gripping_estimator, + point_cloud_filter=point_cloud_filter, + ) + + # Test the tool directly + print(f"\nTesting GetGrippingPointTool with object '{test_object}'") + + result = gripping_tool._run(test_object) + gripping_points = extract_gripping_points(result) + print(f"\nFound {len(gripping_points)} gripping points in target frame:") + + for i, gp in enumerate(gripping_points): + print(f" GP{i + 1}: [{gp[0]:.3f}, {gp[1]:.3f}, {gp[2]:.3f}]") + + if gripping_points: + # Call the function in pcl.py to publish the gripping point for visualization + segmented_clouds = gripping_tool.point_cloud_from_segmentation.run( + test_object + ) + print( + "\nPublishing debug data to /debug_gripping_points_pointcloud and /debug_gripping_points_markerarray" + ) + _publish_gripping_point_debug_data( + connector, segmented_clouds, gripping_points, config["frames"]["target"] + ) + print("✅ Debug data published") + + annotated_image_path = f"{test_object}_gripping_points.jpg" + save_annotated_image( + connector, gripping_points, config, annotated_image_path + ) + print(f"✅ Saved annotated image as '{annotated_image_path}'") + + else: + print("❌ No gripping points found") + + except Exception as e: + print(f"❌ Setup failed: {e}") + import traceback + + traceback.print_exc() + + finally: + if hasattr(connector, "executor"): + connector.executor.shutdown() + connector.shutdown() + + +@pytest.mark.manual +def test_gripping_points_manipulation_demo(): + """Manual test requiring manipulation-demo app to be started.""" + main("manipulation-demo", "apple") + + +@pytest.mark.manual +def test_gripping_points_maciej_demo(): + """Manual test requiring demo app to be started.""" + main("maciej-test-demo", "box") From 6385b23f1224748cd9897f7ae8a75d05aabb0139 Mon Sep 17 00:00:00 2001 From: Julia Jia Date: Sat, 20 Sep 2025 22:52:11 -0700 Subject: [PATCH 06/13] Change timeout implementation and add unit test for timeout --- src/rai_core/rai/tools/__init__.py | 3 +- .../rai/tools/ros2/detection/tools.py | 16 ++-- src/rai_core/rai/tools/timeout.py | 77 +++++++++--------- tests/tools/ros2/test_detection_tools.py | 78 +++++++------------ tests/tools/ros2/test_gripping_points.py | 11 ++- 5 files changed, 86 insertions(+), 99 deletions(-) diff --git a/src/rai_core/rai/tools/__init__.py b/src/rai_core/rai/tools/__init__.py index 4b58b900b..a2f5a1099 100644 --- a/src/rai_core/rai/tools/__init__.py +++ b/src/rai_core/rai/tools/__init__.py @@ -12,4 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. 
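For reference, a minimal sketch of the string parsing that `extract_gripping_points` in the test above relies on; the exact message wording ("gripping point of the object ... is [...]") is an assumption taken from the assertions in these tests, not a documented format:

```python
# Minimal sketch of the parsing done by extract_gripping_points above.
# The message wording is assumed from the test assertions.
result = "The gripping point of the object cube is [0.39972728 0.16179778 0.04179673]"

for line in result.split("\n"):
    if "gripping point" in line and "is [" in line:
        start, end = line.find("[") + 1, line.find("]")
        coords = [float(x) for x in line[start:end].split()]
        print(coords)  # -> [0.39972728, 0.16179778, 0.04179673]
```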
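The `__init__.py` change just below switches to the `import x as x` form, the explicit re-export idiom from PEP 484: when implicit re-exports are disabled (e.g. mypy `--no-implicit-reexport`, pyright), names imported into a package `__init__.py` are only treated as public API if re-exported explicitly. A sketch of the equivalent `__all__`-based alternative:

```python
# Equivalent alternative to the "import x as x" idiom used below:
# declare the public names via __all__ instead.
from .timeout import timeout, timeout_method

__all__ = ["timeout", "timeout_method"]
```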
-from .timeout import timeout, timeout_method
+from .timeout import timeout as timeout
+from .timeout import timeout_method as timeout_method
diff --git a/src/rai_core/rai/tools/ros2/detection/tools.py b/src/rai_core/rai/tools/ros2/detection/tools.py
index c35b0672c..adabc72f6 100644
--- a/src/rai_core/rai/tools/ros2/detection/tools.py
+++ b/src/rai_core/rai/tools/ros2/detection/tools.py
@@ -22,6 +22,7 @@
     PointCloudFilter,
     PointCloudFromSegmentation,
 )
+from rai.tools.timeout import TimeoutError, timeout
 
 
 class GetGrippingPointToolInput(BaseModel):
@@ -32,7 +33,6 @@ class GetGrippingPointToolInput(BaseModel):
 
 
 # TODO(maciejmajek): Configuration system configurable with namespacing
-# TODO(juliajia): Question for Maciej: for comments above on configuration system with namespacing, can you provide a use case for this?
 class GetGrippingPointTool(BaseROS2Tool):
     name: str = "get_gripping_point"
     description: str = "Get gripping points for specified object/objects. Returns 3D coordinates where a robot gripper can grasp the object."
@@ -68,10 +68,10 @@ def model_post_init(self, __context: Any) -> None:
 
     def _run(self, object_name: str) -> str:
         # NOTE: a signal-based timeout must run in the main thread, which does not work in agent scenarios
-        # @timeout(
-        #     self.timeout_sec,
-        #     f"Gripping point detection for object '{object_name}' exceeded {self.timeout_sec} seconds",
-        # )
+        @timeout(
+            self.timeout_sec,
+            f"Gripping point detection for object '{object_name}' exceeded {self.timeout_sec} seconds",
+        )
         def _run_with_timeout():
             pcl = self.point_cloud_from_segmentation.run(object_name)
             if len(pcl) == 0:
@@ -101,7 +101,7 @@ def _run_with_timeout():
 
         try:
             return _run_with_timeout()
-        except Exception as e:
-            if "timed out" in str(e).lower():
-                return f"Timeout: Gripping point detection for object '{object_name}' exceeded {self.timeout_sec} seconds"
+        except TimeoutError:
+            return f"Timeout: Gripping point detection for object '{object_name}' exceeded {self.timeout_sec} seconds"
+        except Exception:
             raise
diff --git a/src/rai_core/rai/tools/timeout.py b/src/rai_core/rai/tools/timeout.py
index 3f660c7e9..13c788839 100644
--- a/src/rai_core/rai/tools/timeout.py
+++ b/src/rai_core/rai/tools/timeout.py
@@ -12,7 +12,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import signal +""" +Design considerations: + +Primary use case: +- 3D object detection pipeline (image → point cloud → segmentation → gripping points) +- Timeout long-running ROS2 service calls from agent tools + +RAI concurrency model: +- `multiprocessing`: Process isolation (ROS2 launch) +- `threading`: Agent execution and callbacks (LangChain agents in worker threads) +- `asyncio`: Limited ROS2 coordination (LaunchManager) + +Timeout implementation: +- Signal-based (SIGALRM): Only works in main thread, unsuitable for RAI's worker threads +- ThreadPoolExecutor: Works in any thread, provides clean resource management + +Alternatives considered: +- asyncio.wait_for(): Requires async context, conflicts with sync tool interface +- threading.Timer: Potential resource leaks, less robust cleanup +""" + +import concurrent.futures from functools import wraps from typing import Any, Callable, TypeVar @@ -63,23 +84,16 @@ def timeout(seconds: float, timeout_message: str = None) -> Callable[[F], F]: def decorator(func: F) -> F: @wraps(func) def wrapper(*args, **kwargs): - def timeout_handler(signum, frame): - message = ( - timeout_message - or f"Function '{func.__name__}' timed out after {seconds} seconds" - ) - raise TimeoutError(message) - - # Set up timeout - old_handler = signal.signal(signal.SIGALRM, timeout_handler) - signal.alarm(int(seconds)) - - try: - return func(*args, **kwargs) - finally: - # Clean up timeout - signal.alarm(0) - signal.signal(signal.SIGALRM, old_handler) + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(func, *args, **kwargs) + try: + return future.result(timeout=seconds) + except concurrent.futures.TimeoutError: + message = ( + timeout_message + or f"Function '{func.__name__}' timed out after {seconds} seconds" + ) + raise TimeoutError(message) return wrapper @@ -116,23 +130,16 @@ def timeout_method(seconds: float, timeout_message: str = None) -> Callable[[F], def decorator(func: F) -> F: @wraps(func) def wrapper(self, *args, **kwargs): - def timeout_handler(signum, frame): - message = ( - timeout_message - or f"Method '{func.__name__}' of {self.__class__.__name__} timed out after {seconds} seconds" - ) - raise TimeoutError(message) - - # Set up timeout - old_handler = signal.signal(signal.SIGALRM, timeout_handler) - signal.alarm(int(seconds)) - - try: - return func(self, *args, **kwargs) - finally: - # Clean up timeout - signal.alarm(0) - signal.signal(signal.SIGALRM, old_handler) + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(func, self, *args, **kwargs) + try: + return future.result(timeout=seconds) + except concurrent.futures.TimeoutError: + message = ( + timeout_message + or f"Method '{func.__name__}' of {self.__class__.__name__} timed out after {seconds} seconds" + ) + raise TimeoutError(message) return wrapper diff --git a/tests/tools/ros2/test_detection_tools.py b/tests/tools/ros2/test_detection_tools.py index b2ecb6e44..22099be5a 100644 --- a/tests/tools/ros2/test_detection_tools.py +++ b/tests/tools/ros2/test_detection_tools.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
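One caveat worth noting about the executor-based decorator above: leaving the `with ThreadPoolExecutor(...)` block calls `executor.shutdown(wait=True)`, so although `future.result(timeout=seconds)` raises `concurrent.futures.TimeoutError` on time, the decorator only returns after the runaway call itself completes. A minimal sketch of a variant that propagates the timeout immediately, at the cost of leaving the worker thread running in the background (the name `timeout_nonblocking` is illustrative, not part of the module; the builtin `TimeoutError` stands in for the module's own):

```python
# Sketch only, not the committed implementation: raise promptly on timeout by
# skipping the blocking shutdown; the worker thread keeps running until the
# wrapped call returns on its own.
import concurrent.futures
from functools import wraps


def timeout_nonblocking(seconds: float, timeout_message: str | None = None):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
            future = executor.submit(func, *args, **kwargs)
            try:
                return future.result(timeout=seconds)
            except concurrent.futures.TimeoutError:
                raise TimeoutError(
                    timeout_message
                    or f"Function '{func.__name__}' timed out after {seconds} seconds"
                )
            finally:
                # wait=False returns immediately instead of joining the worker
                executor.shutdown(wait=False)

        return wrapper

    return decorator
```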
+import time + import pytest try: @@ -149,70 +151,42 @@ def test_point_cloud_filter(): def test_get_gripping_point_tool_timeout(): - """Test GetGrippingPointTool timeout behavior.""" - # Mock the connector and components + # Complete mock setup mock_connector = Mock(spec=ROS2Connector) - - # Create mock components that will simulate timeout mock_pcl_gen = Mock(spec=PointCloudFromSegmentation) - mock_pcl_gen.run.side_effect = lambda x: [] # Return empty to simulate no detection - mock_filter = Mock(spec=PointCloudFilter) - mock_filter.run.return_value = [] - mock_estimator = Mock(spec=GrippingPointEstimator) + + # Test 1: No timeout (fast execution) + mock_pcl_gen.run.return_value = [] + mock_filter.run.return_value = [] mock_estimator.run.return_value = [] - # Create tool with short timeout tool = GetGrippingPointTool( connector=mock_connector, - point_cloud_from_segmentation=mock_pcl_gen, - point_cloud_filter=mock_filter, + target_frame="base", + source_frame="camera", + camera_topic="/image", + depth_topic="/depth", + camera_info_topic="/info", gripping_point_estimator=mock_estimator, - timeout_sec=0.1, + point_cloud_filter=mock_filter, + timeout_sec=5.0, ) + tool.point_cloud_from_segmentation = mock_pcl_gen # Connect the mock - # Test successful run with no gripping points found - result = tool._run("test_object") - assert "No gripping point found" in result - assert "test_object" in result - - # Test with mock that simulates found gripping points - mock_estimator.run.return_value = [np.array([1.0, 2.0, 3.0], dtype=np.float32)] - result = tool._run("test_object") - assert "gripping point of the object test_object is" in result - assert "[1. 2. 3.]" in result - - # Test with multiple gripping points - mock_estimator.run.return_value = [ - np.array([1.0, 2.0, 3.0], dtype=np.float32), - np.array([4.0, 5.0, 6.0], dtype=np.float32), - ] + # Test fast execution - should complete without timeout result = tool._run("test_object") - assert "Multiple gripping points found" in result - - -def test_get_gripping_point_tool_validation(): - """Test GetGrippingPointTool input validation.""" - mock_connector = Mock(spec=ROS2Connector) - mock_pcl_gen = Mock(spec=PointCloudFromSegmentation) - mock_filter = Mock(spec=PointCloudFilter) - mock_estimator = Mock(spec=GrippingPointEstimator) - - # Test tool creation - tool = GetGrippingPointTool( - connector=mock_connector, - point_cloud_from_segmentation=mock_pcl_gen, - point_cloud_filter=mock_filter, - gripping_point_estimator=mock_estimator, - ) + assert "No test_objects detected" in result + assert "timed out" not in result.lower() - # Verify tool properties - assert tool.name == "get_gripping_point" - assert "gripping points" in tool.description - assert tool.timeout_sec == 10.0 # default value + # Test 2: Actual timeout behavior + def slow_operation(obj_name): + time.sleep(2.0) # Longer than timeout + return [] - # Test args schema - from rai.tools.ros2.detection.tools import GetGrippingPointToolInput + mock_pcl_gen.run.side_effect = slow_operation + tool.timeout_sec = 1.0 # Short timeout - assert tool.args_schema == GetGrippingPointToolInput + result = tool._run("test") + assert "timed out" in result.lower() or "timeout" in result.lower() diff --git a/tests/tools/ros2/test_gripping_points.py b/tests/tools/ros2/test_gripping_points.py index 64c959d5d..2f7cbc13d 100644 --- a/tests/tools/ros2/test_gripping_points.py +++ b/tests/tools/ros2/test_gripping_points.py @@ -25,6 +25,8 @@ pytest 
tests/tools/ros2/test_gripping_points.py::test_gripping_points_manipulation_demo -m "" -s -v
 """
 
+import time
+
 import cv2
 import numpy as np
 import pytest
@@ -278,6 +280,8 @@ def main(config_name: str = "manipulation-demo", test_object: str = "cube"):
         filter_config = algo_config["filter"]
         point_cloud_filter = PointCloudFilter(**filter_config)
 
+        start_time = time.time()
+
         # Create the tool
         gripping_tool = GetGrippingPointTool(
             connector=connector,
@@ -288,7 +292,9 @@ def main(config_name: str = "manipulation-demo", test_object: str = "cube"):
             camera_info_topic=config["topics"]["camera_info"],
             gripping_point_estimator=gripping_estimator,
             point_cloud_filter=point_cloud_filter,
+            timeout_sec=15.0,
         )
+        print(f"elapsed time: {time.time() - start_time} seconds")
 
         # Test the tool directly
         print(f"\nTesting GetGrippingPointTool with object '{test_object}'")
@@ -300,6 +306,8 @@ def main(config_name: str = "manipulation-demo", test_object: str = "cube"):
         for i, gp in enumerate(gripping_points):
             print(f"  GP{i + 1}: [{gp[0]:.3f}, {gp[1]:.3f}, {gp[2]:.3f}]")
 
+        assert len(gripping_points) > 0, "No gripping points found"
+
         if gripping_points:
             # Call the function in pcl.py to publish the gripping point for visualization
             segmented_clouds = gripping_tool.point_cloud_from_segmentation.run(
@@ -319,9 +327,6 @@ def main(config_name: str = "manipulation-demo", test_object: str = "cube"):
             )
             print(f"✅ Saved annotated image as '{annotated_image_path}'")
 
-        else:
-            print("❌ No gripping points found")
-
     except Exception as e:
         print(f"❌ Setup failed: {e}")
         import traceback
From 9958ca7df6a66e31b3166f76de08154533c87f88 Mon Sep 17 00:00:00 2001
From: Julia Jia
Date: Mon, 22 Sep 2025 23:07:43 -0700
Subject: [PATCH 07/13] Parameterize estimator strategy to facilitate manual
 testing

---
 tests/conftest.py                        | 12 ++++++++++
 tests/tools/ros2/test_gripping_points.py | 29 ++++++++++++------------
 2 files changed, 27 insertions(+), 14 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index adb9e1850..9c6b25e08 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -18,6 +18,18 @@
 import pytest
 
 
+# 3D gripping point detection strategy
+def pytest_addoption(parser):
+    parser.addoption(
+        "--strategy", action="store", default="centroid", help="Gripping point strategy"
+    )
+
+
+@pytest.fixture
+def strategy(request):
+    return request.config.getoption("--strategy")
+
+
 @pytest.fixture
 def test_config_toml():
     """
diff --git a/tests/tools/ros2/test_gripping_points.py b/tests/tools/ros2/test_gripping_points.py
index 2f7cbc13d..6bb90e641 100644
--- a/tests/tools/ros2/test_gripping_points.py
+++ b/tests/tools/ros2/test_gripping_points.py
@@ -22,7 +22,7 @@ The demo app and rviz2 need to be started before running the test. The test will fail if the gripping points are not found.
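For reference, a small sketch of how the `--strategy` option added to conftest.py above reaches a test: pytest resolves the `strategy` fixture from the command line, e.g. `pytest tests/tools/ros2/test_gripping_points.py -m "" -s -v --strategy top_plane`. The test name below is hypothetical:

```python
# Sketch: any test that names the `strategy` fixture receives the CLI value
# (defaults to "centroid" when --strategy is not given).
def test_strategy_option(strategy):
    assert strategy in ("centroid", "top_plane", "biggest_plane")
```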
Usage: -pytest tests/tools/ros2/test_gripping_points.py::test_gripping_points_manipulation_demo -m "" -s -v +pytest tests/tools/ros2/test_gripping_points.py::test_gripping_points_manipulation_demo -m "manual" -s -v --strategy """ import time @@ -246,14 +246,9 @@ def save_annotated_image( cv2.imwrite(filename, annotated_image) -def main(config_name: str = "manipulation-demo", test_object: str = "cube"): +def main(config: dict, test_object: str = "cube", strategy: str = None): """Enhanced test with visualization and better error handling.""" - # Get test configuration - config = TEST_CONFIGS[config_name] - - print(f"Config: {config_name}") - # Initialize ROS2 rclpy.init() connector = ROS2Connector(executor_type="single_threaded") @@ -273,7 +268,9 @@ def main(config_name: str = "manipulation-demo", test_object: str = "cube"): algo_config = config["algorithms"] # Create gripping estimator with strategy-specific parameters - estimator_config = algo_config["estimator"] + estimator_config = algo_config["estimator"].copy() + if strategy: + estimator_config["strategy"] = strategy gripping_estimator = GrippingPointEstimator(**estimator_config) # Create point cloud filter @@ -297,7 +294,9 @@ def main(config_name: str = "manipulation-demo", test_object: str = "cube"): print(f"elapsed time: {time.time() - start_time} seconds") # Test the tool directly - print(f"\nTesting GetGrippingPointTool with object '{test_object}'") + print( + f"\nTesting GetGrippingPointTool with object '{test_object}', strategy '{strategy}'" + ) result = gripping_tool._run(test_object) gripping_points = extract_gripping_points(result) @@ -321,7 +320,7 @@ def main(config_name: str = "manipulation-demo", test_object: str = "cube"): ) print("✅ Debug data published") - annotated_image_path = f"{test_object}_gripping_points.jpg" + annotated_image_path = f"{test_object}_{strategy}_gripping_points.jpg" save_annotated_image( connector, gripping_points, config, annotated_image_path ) @@ -340,12 +339,14 @@ def main(config_name: str = "manipulation-demo", test_object: str = "cube"): @pytest.mark.manual -def test_gripping_points_manipulation_demo(): +def test_gripping_points_manipulation_demo(strategy): """Manual test requiring manipulation-demo app to be started.""" - main("manipulation-demo", "apple") + config = TEST_CONFIGS["manipulation-demo"] + main(config, "cube", strategy) @pytest.mark.manual -def test_gripping_points_maciej_demo(): +def test_gripping_points_maciej_demo(strategy): """Manual test requiring demo app to be started.""" - main("maciej-test-demo", "box") + config = TEST_CONFIGS["maciej-test-demo"] + main(config, "box", strategy) From ea865e18df431035953bde8a9a42f6e8ec0660db Mon Sep 17 00:00:00 2001 From: Julia Jia Date: Tue, 23 Sep 2025 00:28:50 -0700 Subject: [PATCH 08/13] Support configuration for detection pipeline --- examples/manipulation-demo-v2.py | 128 +++++++++++ src/rai_core/rai/tools/ros2/detection/pcl.py | 165 ++++++-------- .../rai/tools/ros2/detection/tools.py | 93 ++++++-- tests/tools/ros2/test_detection_tools.py | 38 ++-- tests/tools/ros2/test_gripping_points.py | 208 ++++++++---------- 5 files changed, 391 insertions(+), 241 deletions(-) create mode 100644 examples/manipulation-demo-v2.py diff --git a/examples/manipulation-demo-v2.py b/examples/manipulation-demo-v2.py new file mode 100644 index 000000000..f384a82b3 --- /dev/null +++ b/examples/manipulation-demo-v2.py @@ -0,0 +1,128 @@ +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use 
this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+from typing import List
+
+import rclpy
+import rclpy.qos
+from langchain_core.messages import BaseMessage, HumanMessage
+from langchain_core.tools import BaseTool
+from rai import get_llm_model
+from rai.agents.langchain.core import create_conversational_agent
+from rai.communication.ros2 import wait_for_ros2_services, wait_for_ros2_topics
+from rai.communication.ros2.connectors import ROS2Connector
+from rai.tools.ros2.detection.pcl import (
+    GrippingPointEstimatorConfig,
+    PointCloudFilterConfig,
+    PointCloudFromSegmentationConfig,
+)
+from rai.tools.ros2.detection.tools import GetGrippingPointTool
+from rai.tools.ros2.manipulation import (
+    MoveObjectFromToTool,
+    ResetArmTool,
+)
+from rai.tools.ros2.simple import GetROS2ImageConfiguredTool
+
+from rai_whoami.models import EmbodimentInfo
+
+logger = logging.getLogger(__name__)
+
+
+def create_agent():
+    rclpy.init()
+    connector = ROS2Connector(executor_type="single_threaded")
+
+    required_services = ["/grounded_sam_segment", "/grounding_dino_classify"]
+    required_topics = ["/color_image5", "/depth_image5", "/color_camera_info5"]
+    wait_for_ros2_services(connector, required_services)
+    wait_for_ros2_topics(connector, required_topics)
+
+    node = connector.node
+
+    # Declare and set parameters for GetGrippingPointTool
+    # These can also be set in the launch file or at runtime
+    parameters_to_set = [
+        ("conversion_ratio", 1.0),
+        ("detection_tools.gripping_point.target_frame", "panda_link0"),
+        ("detection_tools.gripping_point.source_frame", "RGBDCamera5"),
+        ("detection_tools.gripping_point.camera_topic", "/color_image5"),
+        ("detection_tools.gripping_point.depth_topic", "/depth_image5"),
+        ("detection_tools.gripping_point.camera_info_topic", "/color_camera_info5"),
+    ]
+
+    # Declare and set each parameter (timeout_sec handled by tool internally)
+    for param_name, param_value in parameters_to_set:
+        node.declare_parameter(param_name, param_value)
+
+    # Configure gripping point detection algorithms
+    segmentation_config = PointCloudFromSegmentationConfig(
+        box_threshold=0.35,
+        text_threshold=0.45,
+    )
+
+    estimator_config = GrippingPointEstimatorConfig(
+        strategy="biggest_plane",  # Options: "centroid", "top_plane", "biggest_plane"
+        top_percentile=0.05,
+        plane_bin_size_m=0.01,
+        ransac_iterations=200,
+        distance_threshold_m=0.01,
+        min_points=10,
+    )
+
+    filter_config = PointCloudFilterConfig(
+        strategy="dbscan",
+        min_points=20,
+        dbscan_eps=0.02,
+        dbscan_min_samples=10,
+    )
+
+    tools: List[BaseTool] = [
+        GetGrippingPointTool(
+            connector=connector,
+            segmentation_config=segmentation_config,
+            estimator_config=estimator_config,
+            filter_config=filter_config,
+        ),
+        MoveObjectFromToTool(connector=connector, manipulator_frame="panda_link0"),
+        ResetArmTool(connector=connector, manipulator_frame="panda_link0"),
+        GetROS2ImageConfiguredTool(connector=connector, topic="/color_image5"),
+    ]
+
+    llm = get_llm_model(model_type="complex_model", streaming=True)
+    embodiment_info = EmbodimentInfo.from_file(
+
"examples/embodiments/manipulation_embodiment.json" + ) + agent = create_conversational_agent( + llm=llm, + tools=tools, + system_prompt=embodiment_info.to_langchain(), + ) + return agent + + +def main(): + agent = create_agent() + messages: List[BaseMessage] = [] + + while True: + prompt = input("Enter a prompt: ") + messages.append(HumanMessage(content=prompt)) + output = agent.invoke({"messages": messages}) + output["messages"][-1].pretty_print() + + +if __name__ == "__main__": + main() diff --git a/src/rai_core/rai/tools/ros2/detection/pcl.py b/src/rai_core/rai/tools/ros2/detection/pcl.py index 10dd2884b..2adee2ff3 100644 --- a/src/rai_core/rai/tools/ros2/detection/pcl.py +++ b/src/rai_core/rai/tools/ros2/detection/pcl.py @@ -17,6 +17,7 @@ import numpy as np import sensor_msgs.msg from numpy.typing import NDArray +from pydantic import BaseModel from rai_open_set_vision import GDINO_SERVICE_NAME from rclpy import Future from rclpy.exceptions import ( @@ -32,6 +33,38 @@ from rai_interfaces.srv import RAIGroundedSam, RAIGroundingDino +class PointCloudFromSegmentationConfig(BaseModel): + box_threshold: float = 0.35 + text_threshold: float = 0.45 + + +class GrippingPointEstimatorConfig(BaseModel): + strategy: Literal["centroid", "top_plane", "biggest_plane"] = "centroid" + top_percentile: float = 0.05 + plane_bin_size_m: float = 0.01 + ransac_iterations: int = 200 + distance_threshold_m: float = 0.01 + min_points: int = 10 + + +class PointCloudFilterConfig(BaseModel): + strategy: Literal["dbscan", "kmeans_largest_cluster", "isolation_forest", "lof"] = ( + "dbscan" + ) + min_points: int = 20 + # DBSCAN + dbscan_eps: float = 0.02 + dbscan_min_samples: int = 10 + # KMeans + kmeans_k: int = 2 + # Isolation Forest + if_max_samples: int | float | Literal["auto"] = "auto" + if_contamination: float = 0.05 + # LOF + lof_n_neighbors: int = 20 + lof_contamination: float = 0.05 + + def depth_to_point_cloud( depth_image: NDArray[np.float32], fx: float, fy: float, cx: float, cy: float ) -> NDArray[np.float32]: @@ -125,16 +158,6 @@ class PointCloudFromSegmentation: get an Nx3 numpy array of points [X, Y, Z] expressed in the target frame. 
""" - connector: ROS2Connector - camera_topic: str - depth_topic: str - camera_info_topic: str - source_frame: str - target_frame: str - - box_threshold: float = 0.35 - text_threshold: float = 0.45 - def __init__( self, *, @@ -144,8 +167,7 @@ def __init__( camera_info_topic: str, source_frame: str, target_frame: str, - box_threshold: float = 0.35, - text_threshold: float = 0.45, + config: PointCloudFromSegmentationConfig, ) -> None: self.connector = connector self.camera_topic = camera_topic @@ -153,8 +175,7 @@ def __init__( self.camera_info_topic = camera_info_topic self.source_frame = source_frame self.target_frame = target_frame - self.box_threshold = box_threshold - self.text_threshold = text_threshold + self.config = config # --------------------- ROS helpers --------------------- def _get_image_message(self, topic: str) -> sensor_msgs.msg.Image: @@ -194,8 +215,8 @@ def _call_gdino_node( req = RAIGroundingDino.Request() req.source_img = camera_img_message req.classes = object_name - req.box_threshold = self.box_threshold - req.text_threshold = self.text_threshold + req.box_threshold = self.config.box_threshold + req.text_threshold = self.config.text_threshold return cli.call_async(req) def _call_gsam_node( @@ -330,29 +351,8 @@ class GrippingPointEstimator: - "biggest_plane": centroid of the most populated horizontal plane bin (RANSAC-free) """ - strategy: Literal["centroid", "top_plane", "biggest_plane"] - top_percentile: float - plane_bin_size_m: float - ransac_iterations: int - distance_threshold_m: float - min_points: int - - def __init__( - self, - *, - strategy: Literal["centroid", "top_plane", "biggest_plane"] = "centroid", - top_percentile: float = 0.05, - plane_bin_size_m: float = 0.01, - ransac_iterations: int = 200, - distance_threshold_m: float = 0.01, - min_points: int = 10, - ) -> None: - self.strategy = strategy - self.top_percentile = top_percentile - self.plane_bin_size_m = plane_bin_size_m - self.ransac_iterations = int(max(1, ransac_iterations)) - self.distance_threshold_m = float(max(1e-6, distance_threshold_m)) - self.min_points = min_points + def __init__(self, config: GrippingPointEstimatorConfig) -> None: + self.config = config def _centroid(self, points: NDArray[np.float32]) -> Optional[NDArray[np.float32]]: if points.size == 0: @@ -362,10 +362,10 @@ def _centroid(self, points: NDArray[np.float32]) -> Optional[NDArray[np.float32] def _top_plane_centroid( self, points: NDArray[np.float32] ) -> Optional[NDArray[np.float32]]: - if points.shape[0] < self.min_points: + if points.shape[0] < self.config.min_points: return self._centroid(points) z_vals = points[:, 2] - threshold = np.quantile(z_vals, 1.0 - self.top_percentile) + threshold = np.quantile(z_vals, 1.0 - self.config.top_percentile) mask = z_vals >= threshold top_points = points[mask] if top_points.shape[0] == 0: @@ -377,7 +377,7 @@ def _biggest_plane_centroid( ) -> Optional[NDArray[np.float32]]: # RANSAC plane detection: not restricted to horizontal planes num_points = points.shape[0] - if num_points < self.min_points: + if num_points < self.config.min_points: return self._centroid(points) best_inlier_count = 0 @@ -385,11 +385,11 @@ def _biggest_plane_centroid( # Precompute for speed pts64 = points.astype(np.float64, copy=False) - threshold = float(self.distance_threshold_m) + threshold = float(self.config.distance_threshold_m) rng = np.random.default_rng() - for _ in range(self.ransac_iterations): + for _ in range(self.config.ransac_iterations): # Sample 3 unique points idxs = rng.choice(num_points, 
size=3, replace=False) p0, p1, p2 = pts64[idxs[0]], pts64[idxs[1]], pts64[idxs[2]] @@ -410,7 +410,7 @@ def _biggest_plane_centroid( best_inlier_count = count best_inlier_mask = inliers - if best_inlier_mask is None or best_inlier_count < self.min_points: + if best_inlier_mask is None or best_inlier_count < self.config.min_points: return self._centroid(points) inlier_points = points[best_inlier_mask] @@ -436,11 +436,11 @@ def run( for pts in segmented_point_clouds: if pts.size == 0: continue - if self.strategy == "centroid": + if self.config.strategy == "centroid": gp = self._centroid(pts) - elif self.strategy == "top_plane": + elif self.config.strategy == "top_plane": gp = self._top_plane_centroid(pts) - elif self.strategy == "biggest_plane": + elif self.config.strategy == "biggest_plane": gp = self._biggest_plane_centroid(pts) else: gp = self._centroid(pts) @@ -461,51 +461,17 @@ class PointCloudFilter: - "lof": keep inliers (pred == 1) """ - strategy: Literal["dbscan", "kmeans_largest_cluster", "isolation_forest", "lof"] - min_points: int - # DBSCAN - dbscan_eps: float - dbscan_min_samples: int - # KMeans - kmeans_k: int - # Isolation Forest - if_max_samples: int | float | Literal["auto"] - if_contamination: float - # LOF - lof_n_neighbors: int - lof_contamination: float - - def __init__( - self, - *, - strategy: Literal[ - "dbscan", "kmeans_largest_cluster", "isolation_forest", "lof" - ] = "dbscan", - min_points: int = 20, - dbscan_eps: float = 0.02, - dbscan_min_samples: int = 10, - kmeans_k: int = 2, - if_max_samples: int | float | Literal["auto"] = "auto", - if_contamination: float = 0.05, - lof_n_neighbors: int = 20, - lof_contamination: float = 0.05, - ) -> None: - self.strategy = strategy - self.min_points = min_points - self.dbscan_eps = dbscan_eps - self.dbscan_min_samples = dbscan_min_samples - self.kmeans_k = kmeans_k - self.if_max_samples = if_max_samples - self.if_contamination = if_contamination - self.lof_n_neighbors = lof_n_neighbors - self.lof_contamination = lof_contamination + def __init__(self, config: PointCloudFilterConfig) -> None: + self.config = config def _filter_dbscan(self, pts: NDArray[np.float32]) -> NDArray[np.float32]: from sklearn.cluster import DBSCAN # type: ignore[reportMissingImports] - if pts.shape[0] < self.min_points: + if pts.shape[0] < self.config.min_points: return pts - db = DBSCAN(eps=self.dbscan_eps, min_samples=self.dbscan_min_samples) + db = DBSCAN( + eps=self.config.dbscan_eps, min_samples=self.config.dbscan_min_samples + ) labels = cast(NDArray[np.int64], db.fit_predict(pts)) # type: ignore[no-any-return] if labels.size == 0: return pts @@ -521,9 +487,9 @@ def _filter_dbscan(self, pts: NDArray[np.float32]) -> NDArray[np.float32]: def _filter_kmeans_largest(self, pts: NDArray[np.float32]) -> NDArray[np.float32]: from sklearn.cluster import KMeans # type: ignore[reportMissingImports] - if pts.shape[0] < max(self.min_points, self.kmeans_k): + if pts.shape[0] < max(self.config.min_points, self.config.kmeans_k): return pts - kmeans = KMeans(n_clusters=self.kmeans_k, n_init="auto") + kmeans = KMeans(n_clusters=self.config.kmeans_k, n_init="auto") labels = cast(NDArray[np.int64], kmeans.fit_predict(pts)) # type: ignore[no-any-return] unique_labels, counts = np.unique(labels, return_counts=True) dominant = unique_labels[np.argmax(counts)] @@ -535,11 +501,11 @@ def _filter_isolation_forest(self, pts: NDArray[np.float32]) -> NDArray[np.float IsolationForest, # type: ignore[reportMissingImports] ) - if pts.shape[0] < self.min_points: + if 
pts.shape[0] < self.config.min_points: return pts iso = IsolationForest( - max_samples=self.if_max_samples, - contamination=self.if_contamination, + max_samples=self.config.if_max_samples, + contamination=self.config.if_contamination, random_state=42, ) pred = cast(NDArray[np.int64], iso.fit_predict(pts)) # type: ignore[no-any-return] # 1 inlier, -1 outlier @@ -553,10 +519,11 @@ def _filter_lof(self, pts: NDArray[np.float32]) -> NDArray[np.float32]: LocalOutlierFactor, # type: ignore[reportMissingImports] ) - if pts.shape[0] < max(self.min_points, self.lof_n_neighbors + 1): + if pts.shape[0] < max(self.config.min_points, self.config.lof_n_neighbors + 1): return pts lof = LocalOutlierFactor( - n_neighbors=self.lof_n_neighbors, contamination=self.lof_contamination + n_neighbors=self.config.lof_n_neighbors, + contamination=self.config.lof_contamination, ) pred = cast(NDArray[np.int64], lof.fit_predict(pts)) # type: ignore[no-any-return] # 1 inlier, -1 outlier mask = pred == 1 @@ -571,13 +538,13 @@ def run( for pts in segmented_point_clouds: if pts.size == 0: continue - if self.strategy == "dbscan": + if self.config.strategy == "dbscan": f = self._filter_dbscan(pts) - elif self.strategy == "kmeans_largest_cluster": + elif self.config.strategy == "kmeans_largest_cluster": f = self._filter_kmeans_largest(pts) - elif self.strategy == "isolation_forest": + elif self.config.strategy == "isolation_forest": f = self._filter_isolation_forest(pts) - elif self.strategy == "lof": + elif self.config.strategy == "lof": f = self._filter_lof(pts) else: f = pts diff --git a/src/rai_core/rai/tools/ros2/detection/tools.py b/src/rai_core/rai/tools/ros2/detection/tools.py index adabc72f6..04feb7bee 100644 --- a/src/rai_core/rai/tools/ros2/detection/tools.py +++ b/src/rai_core/rai/tools/ros2/detection/tools.py @@ -19,8 +19,11 @@ from rai.tools.ros2.base import BaseROS2Tool from rai.tools.ros2.detection.pcl import ( GrippingPointEstimator, + GrippingPointEstimatorConfig, PointCloudFilter, + PointCloudFilterConfig, PointCloudFromSegmentation, + PointCloudFromSegmentationConfig, ) from rai.tools.timeout import TimeoutError, timeout @@ -32,31 +35,84 @@ class GetGrippingPointToolInput(BaseModel): ) -# TODO(maciejmajek): Configuration system configurable with namespacing class GetGrippingPointTool(BaseROS2Tool): name: str = "get_gripping_point" description: str = "Get gripping points for specified object/objects. Returns 3D coordinates where a robot gripper can grasp the object." 
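To make the config-driven `_biggest_plane_centroid` path above concrete, here is a self-contained sketch of the same RANSAC idea on synthetic data (data, thresholds, and iteration counts are illustrative only):

```python
# RANSAC plane sketch: sample 3 points, fit a plane, count inliers within a
# distance threshold, keep the best plane, return the centroid of its inliers.
import numpy as np

rng = np.random.default_rng(0)
pts = rng.uniform(-0.05, 0.05, (500, 3)).astype(np.float32)
pts[:, 2] = rng.normal(0.0, 0.001, 500)  # points roughly on the plane z = 0

best_count, best_mask = 0, None
for _ in range(200):
    p0, p1, p2 = pts[rng.choice(len(pts), size=3, replace=False)]
    normal = np.cross(p1 - p0, p2 - p0)
    norm = np.linalg.norm(normal)
    if norm < 1e-9:  # degenerate (collinear) sample
        continue
    normal /= norm
    dist = np.abs((pts - p0) @ normal)  # point-to-plane distances
    mask = dist <= 0.01
    if mask.sum() > best_count:
        best_count, best_mask = int(mask.sum()), mask

centroid = pts[best_mask].mean(axis=0)  # candidate gripping point
print(best_count, centroid)
```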
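Similarly, a small sketch of the DBSCAN-based filtering strategy above: cluster the cloud, drop noise (label -1), and keep the most populated cluster (scikit-learn assumed; parameters illustrative):

```python
import numpy as np
from sklearn.cluster import DBSCAN

rng = np.random.default_rng(0)
pts = np.vstack([
    rng.normal([0.0, 0.0, 0.0], 0.005, (200, 3)),  # object surface
    rng.normal([1.0, 1.0, 1.0], 0.005, (5, 3)),    # stray depth noise
]).astype(np.float32)

labels = DBSCAN(eps=0.02, min_samples=10).fit_predict(pts)
valid = labels[labels != -1]            # ignore noise points
dominant = np.bincount(valid).argmax()  # most populated cluster
filtered = pts[labels == dominant]
print(len(pts), "->", len(filtered))    # 205 -> ~200
```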
- target_frame: str - source_frame: str - camera_topic: str # rgb camera topic - depth_topic: str - camera_info_topic: str # rgb camera info topic - - gripping_point_estimator: GrippingPointEstimator - point_cloud_filter: PointCloudFilter - - # Auto-initialized in model_post_init + # Configuration for PCL components + segmentation_config: PointCloudFromSegmentationConfig + estimator_config: GrippingPointEstimatorConfig + filter_config: PointCloudFilterConfig + + # Auto-initialized in model_post_init from ROS2 parameters + target_frame: Optional[str] = None + source_frame: Optional[str] = None + camera_topic: Optional[str] = None + depth_topic: Optional[str] = None + camera_info_topic: Optional[str] = None + timeout_sec: Optional[float] = None + + # Components initialized in model_post_init + gripping_point_estimator: Optional[GrippingPointEstimator] = None + point_cloud_filter: Optional[PointCloudFilter] = None point_cloud_from_segmentation: Optional[PointCloudFromSegmentation] = None - timeout_sec: float = Field( - default=10.0, description="Timeout in seconds to get the gripping point" - ) - args_schema: Type[GetGrippingPointToolInput] = GetGrippingPointToolInput def model_post_init(self, __context: Any) -> None: - """Initialize PointCloudFromSegmentation with the provided camera parameters.""" + """Initialize tool with ROS2 parameters and components.""" + self._load_parameters() + self._initialize_components() + + def _load_parameters(self) -> None: + """Load configuration from ROS2 parameters.""" + node = self.connector.node + param_prefix = "detection_tools.gripping_point" + + # Declare required parameters + required_params = [ + f"{param_prefix}.target_frame", + f"{param_prefix}.source_frame", + f"{param_prefix}.camera_topic", + f"{param_prefix}.depth_topic", + f"{param_prefix}.camera_info_topic", + ] + + for param_name in required_params: + if not node.has_parameter(param_name): + raise ValueError( + f"Required parameter '{param_name}' must be set before initializing GetGrippingPointTool" + ) + + # Optional parameter with default + node.declare_parameter(f"{param_prefix}.timeout_sec", 10.0) + + # Load parameters + self.target_frame = node.get_parameter(f"{param_prefix}.target_frame").value + self.source_frame = node.get_parameter(f"{param_prefix}.source_frame").value + self.camera_topic = node.get_parameter(f"{param_prefix}.camera_topic").value + self.depth_topic = node.get_parameter(f"{param_prefix}.depth_topic").value + self.camera_info_topic = node.get_parameter( + f"{param_prefix}.camera_info_topic" + ).value + self.timeout_sec = node.get_parameter(f"{param_prefix}.timeout_sec").value + + # Validate required parameters are not empty + if not all( + [ + self.target_frame, + self.source_frame, + self.camera_topic, + self.depth_topic, + self.camera_info_topic, + ] + ): + raise ValueError( + "Required ROS2 parameters for GetGrippingPointTool cannot be empty" + ) + + def _initialize_components(self) -> None: + """Initialize PCL components with loaded parameters.""" self.point_cloud_from_segmentation = PointCloudFromSegmentation( connector=self.connector, camera_topic=self.camera_topic, @@ -64,7 +120,12 @@ def model_post_init(self, __context: Any) -> None: camera_info_topic=self.camera_info_topic, source_frame=self.source_frame, target_frame=self.target_frame, + config=self.segmentation_config, + ) + self.gripping_point_estimator = GrippingPointEstimator( + config=self.estimator_config ) + self.point_cloud_filter = PointCloudFilter(config=self.filter_config) def _run(self, 
object_name: str) -> str:
         # NOTE: a signal-based timeout must run in the main thread, which does not work in agent scenarios
diff --git a/tests/tools/ros2/test_detection_tools.py b/tests/tools/ros2/test_detection_tools.py
index 22099be5a..914abfdc9 100644
--- a/tests/tools/ros2/test_detection_tools.py
+++ b/tests/tools/ros2/test_detection_tools.py
@@ -30,8 +30,11 @@
 from rai.tools.ros2.detection import GetGrippingPointTool
 from rai.tools.ros2.detection.pcl import (
     GrippingPointEstimator,
+    GrippingPointEstimatorConfig,
     PointCloudFilter,
+    PointCloudFilterConfig,
     PointCloudFromSegmentation,
+    PointCloudFromSegmentationConfig,
     depth_to_point_cloud,
 )
 
@@ -92,7 +95,9 @@ def test_gripping_point_estimator():
     segmented_clouds = [points1, points2]
 
     # Test centroid strategy
-    estimator = GrippingPointEstimator(strategy="centroid")
+    estimator = GrippingPointEstimator(
+        config=GrippingPointEstimatorConfig(strategy="centroid")
+    )
     grip_points = estimator.run(segmented_clouds)
 
     assert len(grip_points) == 2
@@ -101,7 +106,9 @@
     np.testing.assert_array_almost_equal(grip_points[0], expected_centroid1)
 
     # Test top_plane strategy
-    estimator_top = GrippingPointEstimator(strategy="top_plane", top_percentile=0.5)
+    estimator_top = GrippingPointEstimator(
+        config=GrippingPointEstimatorConfig(strategy="top_plane", top_percentile=0.5)
+    )
     grip_points_top = estimator_top.run(segmented_clouds)
 
     assert len(grip_points_top) == 2
@@ -125,7 +132,9 @@
 
     # Test DBSCAN filtering
     filter_dbscan = PointCloudFilter(
-        strategy="dbscan", dbscan_eps=0.5, dbscan_min_samples=5
+        config=PointCloudFilterConfig(
+            strategy="dbscan", dbscan_eps=0.5, dbscan_min_samples=5
+        )
     )
     filtered_dbscan = filter_dbscan.run(clouds)
 
@@ -136,14 +145,18 @@
 
     # Test with too few points (should return original)
     small_cloud = np.array([[1, 1, 1], [2, 2, 2]], dtype=np.float32)
-    filter_small = PointCloudFilter(strategy="dbscan", min_points=20)
+    filter_small = PointCloudFilter(
+        config=PointCloudFilterConfig(strategy="dbscan", min_points=20)
+    )
     filtered_small = filter_small.run([small_cloud])
 
     assert len(filtered_small) == 1
     np.testing.assert_array_equal(filtered_small[0], small_cloud)
 
     # Test kmeans_largest_cluster strategy
-    filter_kmeans = PointCloudFilter(strategy="kmeans_largest_cluster", kmeans_k=2)
+    filter_kmeans = PointCloudFilter(
+        config=PointCloudFilterConfig(strategy="kmeans_largest_cluster", kmeans_k=2)
+    )
     filtered_kmeans = filter_kmeans.run(clouds)
 
     assert len(filtered_kmeans) == 1
@@ -164,15 +177,14 @@ def test_get_gripping_point_tool_timeout():
 
     tool = GetGrippingPointTool(
         connector=mock_connector,
-        target_frame="base",
-        source_frame="camera",
-        camera_topic="/image",
-        depth_topic="/depth",
-        camera_info_topic="/info",
-        gripping_point_estimator=mock_estimator,
-        point_cloud_filter=mock_filter,
-        timeout_sec=5.0,
+        segmentation_config=PointCloudFromSegmentationConfig(),
+        estimator_config=GrippingPointEstimatorConfig(),
+        filter_config=PointCloudFilterConfig(),
     )
+    # Mock the initialized components
+    tool.gripping_point_estimator = mock_estimator
+    tool.point_cloud_filter = mock_filter
+    tool.timeout_sec = 5.0
     tool.point_cloud_from_segmentation = mock_pcl_gen  # Connect the mock
 
     # Test fast execution - should complete without timeout
diff --git a/tests/tools/ros2/test_gripping_points.py b/tests/tools/ros2/test_gripping_points.py
index 6bb90e641..7c1106af6 100644
--- a/tests/tools/ros2/test_gripping_points.py
+++
b/tests/tools/ros2/test_gripping_points.py @@ -31,91 +31,18 @@ import numpy as np import pytest import rclpy +import rclpy.parameter from cv_bridge import CvBridge from rai.communication.ros2 import wait_for_ros2_services, wait_for_ros2_topics from rai.communication.ros2.connectors import ROS2Connector from rai.tools.ros2.detection import GetGrippingPointTool from rai.tools.ros2.detection.pcl import ( - GrippingPointEstimator, - PointCloudFilter, + GrippingPointEstimatorConfig, + PointCloudFilterConfig, + PointCloudFromSegmentationConfig, _publish_gripping_point_debug_data, ) -# Test configurations -TEST_CONFIGS = { - "manipulation-demo": { - "services": ["/grounded_sam_segment", "/grounding_dino_classify"], - "topics": { - "color_image": "/color_image5", - "depth_image": "/depth_image5", - "camera_info": "/color_camera_info5", - }, - "frames": {"target": "panda_link0", "source": "RGBDCamera5"}, - "algorithms": { - "filter": { - "strategy": "dbscan", - "dbscan_eps": 0.02, - "dbscan_min_samples": 5, - }, - "estimator": {"strategy": "centroid"}, - }, - }, - "maciej-test-demo": { - "services": ["/grounded_sam_segment", "/grounding_dino_classify"], - "topics": { - "color_image": "/rgbd_camera/camera_image_color", - "depth_image": "/rgbd_camera/camera_image_depth", - "camera_info": "/rgbd_camera/camera_info", - }, - "frames": { - "target": "egoarm_base_link", - "source": "egofront_rgbd_camera_depth_optical_frame", - }, - "algorithms": { - "filter": { - "strategy": "dbscan", - "dbscan_eps": 0.02, - "dbscan_min_samples": 10, - }, - "estimator": { - "strategy": "biggest_plane", - "ransac_iterations": 400, - "distance_threshold_m": 0.008, - }, - }, - }, - "dummy-example-with-default-algorithm-parameters": { - "services": ["/grounded_sam_segment", "/grounding_dino_classify"], - "topics": { - "color_image": "/color_image5", - "depth_image": "/depth_image5", - "camera_info": "/color_camera_info5", - }, - "frames": {"target": "panda_link0", "source": "RGBDCamera5"}, - "algorithms": { - "filter": { - "strategy": "dbscan", - "min_points": 100, - "dbscan_eps": 0.02, - "dbscan_min_samples": 5, - "kmeans_k": 3, - "if_max_samples": 100, - "if_contamination": 0.1, - "lof_n_neighbors": 20, - "lof_contamination": 0.1, - }, - "estimator": { - "strategy": "centroid", - "top_percentile": 0.8, - "plane_bin_size_m": 0.01, - "ransac_iterations": 100, - "distance_threshold_m": 0.01, - "min_points": 10, - }, - }, - }, -} - def draw_points_on_image(image_msg, points, camera_info): """Draw points on the camera image.""" @@ -225,18 +152,24 @@ def transform_points_to_target_frame(connector, points, source_frame, target_fra def save_annotated_image( - connector, gripping_points, config, filename: str = "gripping_points_annotated.jpg" + connector, + gripping_points, + camera_topic, + camera_info_topic, + source_frame, + target_frame, + filename: str = "gripping_points_annotated.jpg", ): camera_frame_points = transform_points_to_target_frame( connector, gripping_points, - config["frames"]["source"], - config["frames"]["target"], + source_frame, + target_frame, ) # Get current camera image and draw points - image_msg = connector.receive_message(config["topics"]["color_image"]).payload - camera_info_msg = connector.receive_message(config["topics"]["camera_info"]).payload + image_msg = connector.receive_message(camera_topic).payload + camera_info_msg = connector.receive_message(camera_info_topic).payload # Draw gripping points on image annotated_image = draw_points_on_image( @@ -246,8 +179,36 @@ def save_annotated_image( 
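The projection inside `draw_points_on_image` above is the standard pinhole camera model; a tiny self-contained sketch with made-up intrinsics (the values are illustrative, not the demo camera's actual calibration):

```python
# Pinhole projection as used in draw_points_on_image: a 3D point in the camera
# optical frame maps to pixel (u, v) via the intrinsics fx, fy, cx, cy.
fx, fy, cx, cy = 600.0, 600.0, 320.0, 240.0
x, y, z = 0.1, -0.05, 0.8  # point in the optical frame (z forward)

assert z > 0, "points behind the camera are skipped"
u = int(x * fx / z + cx)  # pixel column
v = int(y * fy / z + cy)  # pixel row
print(u, v)  # 395 202
```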
cv2.imwrite(filename, annotated_image) -def main(config: dict, test_object: str = "cube", strategy: str = None): - """Enhanced test with visualization and better error handling.""" +def main( + test_object: str = "cube", + strategy: str = "centroid", + topics: dict = None, + frames: dict = None, + estimator_config: dict = None, + filter_config: dict = None, +): + # Default configuration for manipulation-demo + if topics is None: + topics = { + "camera": "/color_image5", + "depth": "/depth_image5", + "camera_info": "/color_camera_info5", + } + + if frames is None: + frames = {"target": "panda_link0", "source": "RGBDCamera5"} + + if estimator_config is None: + estimator_config = {"strategy": strategy} + + if filter_config is None: + filter_config = { + "strategy": "dbscan", + "dbscan_eps": 0.02, + "dbscan_min_samples": 5, + } + + services = ["/grounded_sam_segment", "/grounding_dino_classify"] # Initialize ROS2 rclpy.init() @@ -256,40 +217,35 @@ def main(config: dict, test_object: str = "cube", strategy: str = None): try: # Wait for required services and topics print("Waiting for ROS2 services and topics...") - wait_for_ros2_services(connector, config["services"]) - wait_for_ros2_topics(connector, list(config["topics"].values())) + wait_for_ros2_services(connector, services) + wait_for_ros2_topics(connector, list(topics.values())) print("✅ All services and topics available") # Set up node parameters node = connector.node - node.declare_parameter("conversion_ratio", 1.0) - - # Create tool components - algo_config = config["algorithms"] - # Create gripping estimator with strategy-specific parameters - estimator_config = algo_config["estimator"].copy() - if strategy: - estimator_config["strategy"] = strategy - gripping_estimator = GrippingPointEstimator(**estimator_config) + # Declare and set ROS2 parameters for deployment configuration + parameters_to_set = [ + ("conversion_ratio", 1.0), + ("detection_tools.gripping_point.target_frame", frames["target"]), + ("detection_tools.gripping_point.source_frame", frames["source"]), + ("detection_tools.gripping_point.camera_topic", topics["camera"]), + ("detection_tools.gripping_point.depth_topic", topics["depth"]), + ("detection_tools.gripping_point.camera_info_topic", topics["camera_info"]), + ] - # Create point cloud filter - filter_config = algo_config["filter"] - point_cloud_filter = PointCloudFilter(**filter_config) + # Declare and set each parameter + for param_name, param_value in parameters_to_set: + node.declare_parameter(param_name, param_value) start_time = time.time() - # Create the tool + # Create the tool with algorithm configurations gripping_tool = GetGrippingPointTool( connector=connector, - target_frame=config["frames"]["target"], - source_frame=config["frames"]["source"], - camera_topic=config["topics"]["color_image"], - depth_topic=config["topics"]["depth_image"], - camera_info_topic=config["topics"]["camera_info"], - gripping_point_estimator=gripping_estimator, - point_cloud_filter=point_cloud_filter, - timeout_sec=15.0, + segmentation_config=PointCloudFromSegmentationConfig(), + estimator_config=GrippingPointEstimatorConfig(**estimator_config), + filter_config=PointCloudFilterConfig(**filter_config), ) print(f"elapsed time: {time.time() - start_time} seconds") @@ -316,13 +272,19 @@ def main(config: dict, test_object: str = "cube", strategy: str = None): "\nPublishing debug data to /debug_gripping_points_pointcloud and /debug_gripping_points_markerarray" ) _publish_gripping_point_debug_data( - connector, segmented_clouds, 
gripping_points, config["frames"]["target"] + connector, segmented_clouds, gripping_points, frames["target"] ) print("✅ Debug data published") annotated_image_path = f"{test_object}_{strategy}_gripping_points.jpg" save_annotated_image( - connector, gripping_points, config, annotated_image_path + connector, + gripping_points, + topics["camera"], + topics["camera_info"], + frames["source"], + frames["target"], + annotated_image_path, ) print(f"✅ Saved annotated image as '{annotated_image_path}'") @@ -341,12 +303,32 @@ def main(config: dict, test_object: str = "cube", strategy: str = None): @pytest.mark.manual def test_gripping_points_manipulation_demo(strategy): """Manual test requiring manipulation-demo app to be started.""" - config = TEST_CONFIGS["manipulation-demo"] - main(config, "cube", strategy) + main("cube", strategy) @pytest.mark.manual def test_gripping_points_maciej_demo(strategy): """Manual test requiring demo app to be started.""" - config = TEST_CONFIGS["maciej-test-demo"] - main(config, "box", strategy) + main( + test_object="box", + strategy=strategy, + topics={ + "camera": "/rgbd_camera/camera_image_color", + "depth": "/rgbd_camera/camera_image_depth", + "camera_info": "/rgbd_camera/camera_info", + }, + frames={ + "target": "egoarm_base_link", + "source": "egofront_rgbd_camera_depth_optical_frame", + }, + estimator_config={ + "strategy": strategy or "biggest_plane", + "ransac_iterations": 400, + "distance_threshold_m": 0.008, + }, + filter_config={ + "strategy": "dbscan", + "dbscan_eps": 0.02, + "dbscan_min_samples": 10, + }, + ) From 9d289d66f584b0063f17a6d80271c306b86bbce8 Mon Sep 17 00:00:00 2001 From: Julia Jia Date: Tue, 23 Sep 2025 10:58:22 -0700 Subject: [PATCH 09/13] Set isolation_forest as default for Point Cloud filtering and publish filtered pcl in test --- src/rai_core/rai/tools/ros2/detection/pcl.py | 2 +- tests/tools/ros2/test_gripping_points.py | 28 +++++++++++--------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/src/rai_core/rai/tools/ros2/detection/pcl.py b/src/rai_core/rai/tools/ros2/detection/pcl.py index 2adee2ff3..9c671561b 100644 --- a/src/rai_core/rai/tools/ros2/detection/pcl.py +++ b/src/rai_core/rai/tools/ros2/detection/pcl.py @@ -49,7 +49,7 @@ class GrippingPointEstimatorConfig(BaseModel): class PointCloudFilterConfig(BaseModel): strategy: Literal["dbscan", "kmeans_largest_cluster", "isolation_forest", "lof"] = ( - "dbscan" + "isolation_forest" ) min_points: int = 20 # DBSCAN diff --git a/tests/tools/ros2/test_gripping_points.py b/tests/tools/ros2/test_gripping_points.py index 7c1106af6..a3dad0bf6 100644 --- a/tests/tools/ros2/test_gripping_points.py +++ b/tests/tools/ros2/test_gripping_points.py @@ -203,9 +203,9 @@ def main( if filter_config is None: filter_config = { - "strategy": "dbscan", - "dbscan_eps": 0.02, - "dbscan_min_samples": 5, + "strategy": "isolation_forest", + "if_max_samples": "auto", + "if_contamination": 0.05, } services = ["/grounded_sam_segment", "/grounding_dino_classify"] @@ -240,6 +240,10 @@ def main( start_time = time.time() + print( + f"\nTesting GetGrippingPointTool with object '{test_object}', strategy '{strategy}'" + ) + # Create the tool with algorithm configurations gripping_tool = GetGrippingPointTool( connector=connector, @@ -247,14 +251,10 @@ def main( estimator_config=GrippingPointEstimatorConfig(**estimator_config), filter_config=PointCloudFilterConfig(**filter_config), ) - print(f"elapsed time: {time.time() - start_time} seconds") - - # Test the tool directly - print( - f"\nTesting 
GetGrippingPointTool with object '{test_object}', strategy '{strategy}'" - ) result = gripping_tool._run(test_object) + print(f"elapsed time: {time.time() - start_time} seconds") + gripping_points = extract_gripping_points(result) print(f"\nFound {len(gripping_points)} gripping points in target frame:") @@ -268,11 +268,13 @@ def main( segmented_clouds = gripping_tool.point_cloud_from_segmentation.run( test_object ) + filtered_clouds = gripping_tool.point_cloud_filter.run(segmented_clouds) + print( "\nPublishing debug data to /debug_gripping_points_pointcloud and /debug_gripping_points_markerarray" ) _publish_gripping_point_debug_data( - connector, segmented_clouds, gripping_points, frames["target"] + connector, filtered_clouds, gripping_points, frames["target"] ) print("✅ Debug data published") @@ -327,8 +329,8 @@ def test_gripping_points_maciej_demo(strategy): "distance_threshold_m": 0.008, }, filter_config={ - "strategy": "dbscan", - "dbscan_eps": 0.02, - "dbscan_min_samples": 10, + "strategy": "isolation_forest", + "if_max_samples": "auto", + "if_contamination": 0.05, }, ) From 58d354f829cb570d82c1d37f1c9e6b16880150c3 Mon Sep 17 00:00:00 2001 From: Julia Jia Date: Tue, 23 Sep 2025 11:02:38 -0700 Subject: [PATCH 10/13] Update CI to skip manual tests --- .github/workflows/poetry-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/poetry-test.yml b/.github/workflows/poetry-test.yml index 98ec599e5..ab8b08a4f 100644 --- a/.github/workflows/poetry-test.yml +++ b/.github/workflows/poetry-test.yml @@ -71,4 +71,4 @@ jobs: shell: bash run: | source setup_shell.sh - pytest -m "not billable" + pytest -m "not billable and not manual" From 8713bed3f76aee30b41e7173a42860482c25416f Mon Sep 17 00:00:00 2001 From: Julia Jia Date: Tue, 23 Sep 2025 11:14:29 -0700 Subject: [PATCH 11/13] Bump up rai_core minor version --- src/rai_core/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rai_core/pyproject.toml b/src/rai_core/pyproject.toml index 714ff448a..c7567590d 100644 --- a/src/rai_core/pyproject.toml +++ b/src/rai_core/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "rai_core" -version = "2.5.0" +version = "2.5.1" description = "Core functionality for RAI framework" authors = ["Maciej Majek ", "Bartłomiej Boczek ", "Kajetan Rachwał "] readme = "README.md" From e56cd0cf58d034a71244e346de977ae59817201c Mon Sep 17 00:00:00 2001 From: Julia Jia Date: Fri, 26 Sep 2025 14:32:22 -0700 Subject: [PATCH 12/13] Initial consolidation of new pipeline code to openset_vision --- examples/manipulation-demo-v2.py | 48 +++-- .../rai/tools/ros2/detection/__init__.py | 21 --- .../rai/tools/ros2/detection/tools.py | 168 ------------------ .../rai/tools/ros2/manipulation/__init__.py | 2 + .../rai/tools/ros2/manipulation/custom.py | 72 ++++++-- src/rai_core/rai/tools/timeout.py | 8 +- .../rai_open_set_vision/__init__.py | 34 +++- .../rai_open_set_vision/tools/__init__.py | 20 +++ .../tools/pcl_detection.py} | 29 +-- .../tools/pcl_detection_tools.py | 133 ++++++++++++++ .../tools/segmentation_tools.py | 20 +-- .../test_gripping_points.py | 46 ++--- .../test_pcl_detection_tools.py} | 23 ++- 13 files changed, 306 insertions(+), 318 deletions(-) delete mode 100644 src/rai_core/rai/tools/ros2/detection/__init__.py delete mode 100644 src/rai_core/rai/tools/ros2/detection/tools.py rename src/{rai_core/rai/tools/ros2/detection/pcl.py => rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/pcl_detection.py} 
(95%) create mode 100644 src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/pcl_detection_tools.py rename tests/{tools/ros2 => rai_extensions}/test_gripping_points.py (86%) rename tests/{tools/ros2/test_detection_tools.py => rai_extensions/test_pcl_detection_tools.py} (90%) diff --git a/examples/manipulation-demo-v2.py b/examples/manipulation-demo-v2.py index f384a82b3..750142b7a 100644 --- a/examples/manipulation-demo-v2.py +++ b/examples/manipulation-demo-v2.py @@ -24,17 +24,18 @@ from rai.agents.langchain.core import create_conversational_agent from rai.communication.ros2 import wait_for_ros2_services, wait_for_ros2_topics from rai.communication.ros2.connectors import ROS2Connector -from rai.tools.ros2.detection.pcl import ( - GrippingPointEstimatorConfig, - PointCloudFilterConfig, - PointCloudFromSegmentationConfig, -) -from rai.tools.ros2.detection.tools import GetGrippingPointTool from rai.tools.ros2.manipulation import ( + GetObjectGrippingPointsTool, MoveObjectFromToTool, ResetArmTool, ) from rai.tools.ros2.simple import GetROS2ImageConfiguredTool +from rai_open_set_vision import ( + GetGrippingPointTool, + GrippingPointEstimatorConfig, + PointCloudFilterConfig, + PointCloudFromSegmentationConfig, +) from rai_whoami.models import EmbodimentInfo @@ -51,21 +52,7 @@ def create_agent(): wait_for_ros2_topics(connector, required_topics) node = connector.node - - # Declare and set parameters for GetGrippingPointTool - # These also can be set in the launch file or during runtime - parameters_to_set = [ - ("conversion_ratio", 1.0), - ("detection_tools.gripping_point.target_frame", "panda_link0"), - ("detection_tools.gripping_point.source_frame", "RGBDCamera5"), - ("detection_tools.gripping_point.camera_topic", "/color_image5"), - ("detection_tools.gripping_point.depth_topic", "/depth_image5"), - ("detection_tools.gripping_point.camera_info_topic", "/color_camera_info5"), - ] - - # Declare and set each parameter (timeout_sec handled by tool internally) - for param_name, param_value in parameters_to_set: - node.declare_parameter(param_name, param_value) + node.declare_parameter("conversion_ratio", 1.0) # Configure gripping point detection algorithms segmentation_config = PointCloudFromSegmentationConfig( @@ -89,12 +76,23 @@ def create_agent(): dbscan_min_samples=10, ) + # Create the underlying GetGrippingPointTool + gripping_point_tool = GetGrippingPointTool( + connector=connector, + segmentation_config=segmentation_config, + estimator_config=estimator_config, + filter_config=filter_config, + ) + tools: List[BaseTool] = [ - GetGrippingPointTool( + GetObjectGrippingPointsTool( connector=connector, - segmentation_config=segmentation_config, - estimator_config=estimator_config, - filter_config=filter_config, + target_frame="panda_link0", + source_frame="RGBDCamera5", + camera_topic="/color_image5", + depth_topic="/depth_image5", + camera_info_topic="/color_camera_info5", + get_gripping_point_tool=gripping_point_tool, ), MoveObjectFromToTool(connector=connector, manipulator_frame="panda_link0"), ResetArmTool(connector=connector, manipulator_frame="panda_link0"), diff --git a/src/rai_core/rai/tools/ros2/detection/__init__.py b/src/rai_core/rai/tools/ros2/detection/__init__.py deleted file mode 100644 index b168382fa..000000000 --- a/src/rai_core/rai/tools/ros2/detection/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright (C) 2025 Robotec.AI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .tools import ( - GetGrippingPointTool, -) - -__all__ = [ - "GetGrippingPointTool", -] diff --git a/src/rai_core/rai/tools/ros2/detection/tools.py b/src/rai_core/rai/tools/ros2/detection/tools.py deleted file mode 100644 index 04feb7bee..000000000 --- a/src/rai_core/rai/tools/ros2/detection/tools.py +++ /dev/null @@ -1,168 +0,0 @@ -# Copyright (C) 2025 Robotec.AI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any, Optional, Type - -from pydantic import BaseModel, Field - -from rai.tools.ros2.base import BaseROS2Tool -from rai.tools.ros2.detection.pcl import ( - GrippingPointEstimator, - GrippingPointEstimatorConfig, - PointCloudFilter, - PointCloudFilterConfig, - PointCloudFromSegmentation, - PointCloudFromSegmentationConfig, -) -from rai.tools.timeout import TimeoutError, timeout - - -class GetGrippingPointToolInput(BaseModel): - object_name: str = Field( - ..., - description="The name of the object to get the gripping point of e.g. 'box', 'apple', 'screwdriver'", - ) - - -class GetGrippingPointTool(BaseROS2Tool): - name: str = "get_gripping_point" - description: str = "Get gripping points for specified object/objects. Returns 3D coordinates where a robot gripper can grasp the object." 
- - # Configuration for PCL components - segmentation_config: PointCloudFromSegmentationConfig - estimator_config: GrippingPointEstimatorConfig - filter_config: PointCloudFilterConfig - - # Auto-initialized in model_post_init from ROS2 parameters - target_frame: Optional[str] = None - source_frame: Optional[str] = None - camera_topic: Optional[str] = None - depth_topic: Optional[str] = None - camera_info_topic: Optional[str] = None - timeout_sec: Optional[float] = None - - # Components initialized in model_post_init - gripping_point_estimator: Optional[GrippingPointEstimator] = None - point_cloud_filter: Optional[PointCloudFilter] = None - point_cloud_from_segmentation: Optional[PointCloudFromSegmentation] = None - - args_schema: Type[GetGrippingPointToolInput] = GetGrippingPointToolInput - - def model_post_init(self, __context: Any) -> None: - """Initialize tool with ROS2 parameters and components.""" - self._load_parameters() - self._initialize_components() - - def _load_parameters(self) -> None: - """Load configuration from ROS2 parameters.""" - node = self.connector.node - param_prefix = "detection_tools.gripping_point" - - # Declare required parameters - required_params = [ - f"{param_prefix}.target_frame", - f"{param_prefix}.source_frame", - f"{param_prefix}.camera_topic", - f"{param_prefix}.depth_topic", - f"{param_prefix}.camera_info_topic", - ] - - for param_name in required_params: - if not node.has_parameter(param_name): - raise ValueError( - f"Required parameter '{param_name}' must be set before initializing GetGrippingPointTool" - ) - - # Optional parameter with default - node.declare_parameter(f"{param_prefix}.timeout_sec", 10.0) - - # Load parameters - self.target_frame = node.get_parameter(f"{param_prefix}.target_frame").value - self.source_frame = node.get_parameter(f"{param_prefix}.source_frame").value - self.camera_topic = node.get_parameter(f"{param_prefix}.camera_topic").value - self.depth_topic = node.get_parameter(f"{param_prefix}.depth_topic").value - self.camera_info_topic = node.get_parameter( - f"{param_prefix}.camera_info_topic" - ).value - self.timeout_sec = node.get_parameter(f"{param_prefix}.timeout_sec").value - - # Validate required parameters are not empty - if not all( - [ - self.target_frame, - self.source_frame, - self.camera_topic, - self.depth_topic, - self.camera_info_topic, - ] - ): - raise ValueError( - "Required ROS2 parameters for GetGrippingPointTool cannot be empty" - ) - - def _initialize_components(self) -> None: - """Initialize PCL components with loaded parameters.""" - self.point_cloud_from_segmentation = PointCloudFromSegmentation( - connector=self.connector, - camera_topic=self.camera_topic, - depth_topic=self.depth_topic, - camera_info_topic=self.camera_info_topic, - source_frame=self.source_frame, - target_frame=self.target_frame, - config=self.segmentation_config, - ) - self.gripping_point_estimator = GrippingPointEstimator( - config=self.estimator_config - ) - self.point_cloud_filter = PointCloudFilter(config=self.filter_config) - - def _run(self, object_name: str) -> str: - # this will be not work in agent scenario because signal need to be run in main thread, comment out for now - @timeout( - self.timeout_sec, - f"Gripping point detection for object '{object_name}' exceeded {self.timeout_sec} seconds", - ) - def _run_with_timeout(): - pcl = self.point_cloud_from_segmentation.run(object_name) - if len(pcl) == 0: - return f"No {object_name}s detected." 
- - pcl = self.point_cloud_filter.run(pcl) - gps = self.gripping_point_estimator.run(pcl) - - message = "" - if len(gps) == 0: - message += f"No gripping point found for the object {object_name}\n" - elif len(gps) == 1: - message += ( - f"The gripping point of the object {object_name} is {gps[0]}\n" - ) - else: - message += ( - f"Multiple gripping points found for the object {object_name}\n" - ) - - for i, gp in enumerate(gps): - message += ( - f"The gripping point of the object {i + 1} {object_name} is {gp}\n" - ) - - return message - - try: - return _run_with_timeout() - except TimeoutError: - return f"Timeout: Gripping point detection for object '{object_name}' exceeded {self.timeout_sec} seconds" - except Exception: - raise diff --git a/src/rai_core/rai/tools/ros2/manipulation/__init__.py b/src/rai_core/rai/tools/ros2/manipulation/__init__.py index 07f1b16b1..12bf976f1 100644 --- a/src/rai_core/rai/tools/ros2/manipulation/__init__.py +++ b/src/rai_core/rai/tools/ros2/manipulation/__init__.py @@ -20,6 +20,7 @@ ) from .custom import ( + GetObjectGrippingPointsTool, GetObjectPositionsTool, MoveObjectFromToTool, MoveObjectFromToToolInput, @@ -29,6 +30,7 @@ ) __all__ = [ + "GetObjectGrippingPointsTool", "GetObjectPositionsTool", "MoveObjectFromToTool", "MoveObjectFromToToolInput", diff --git a/src/rai_core/rai/tools/ros2/manipulation/custom.py b/src/rai_core/rai/tools/ros2/manipulation/custom.py index b2d56e8bb..e9a385c84 100644 --- a/src/rai_core/rai/tools/ros2/manipulation/custom.py +++ b/src/rai_core/rai/tools/ros2/manipulation/custom.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import Literal, Type +from typing import Literal, Type, List import numpy as np from deprecated import deprecated @@ -23,19 +23,13 @@ from rai.communication.ros2.ros_async import get_future_result from rai.tools.ros2.base import BaseROS2Tool +from rai.tools.timeout import RaiTimeoutError try: - from rai_interfaces.srv import ManipulatorMoveTo + from rai_open_set_vision import GetGrabbingPointTool, GetGrippingPointTool except ImportError: logging.warning( - "rai_interfaces is not installed, ManipulatorMoveTo tool will not work." - ) - -try: - from rai_open_set_vision.tools import GetGrabbingPointTool -except ImportError: - logging.warning( - "rai_open_set_vision is not installed, GetGrabbingPointTool will not work" + "rai_open_set_vision is not installed, GetGrabbingPointTool or GetGrippingPointTool may not work" ) @@ -260,7 +254,7 @@ class GetObjectPositionsToolInput(BaseModel): ) -@deprecated("Use GetGrippingPointTool from rai.tools.ros2.detection instead") +@deprecated("Use GetObjectGrippingPointsTool from rai_core.rai.tools.ros2.manipulation.custom instead") class GetObjectPositionsTool(BaseROS2Tool): name: str = "get_object_positions" description: str = ( @@ -312,6 +306,62 @@ def _run(self, object_name: str): return f"Centroids of detected {object_name}s in {self.target_frame} frame: [{', '.join(map(self.format_pose, mani_frame_poses))}]. Sizes of the detected objects are unknown." +class GetObjectGrippingPointsTool(BaseROS2Tool): + name: str = "get_object_gripping_points" + description: str = ( + "Retrieve the gripping points of all objects of a specified type in the target frame. " + "This tool provides accurate gripping point data but does not distinguish between different colors of the same object type. " + "While gripping point detection is reliable, please note that object classification may occasionally be inaccurate." 
+ ) + + target_frame: str + source_frame: str + camera_topic: str # rgb camera topic + depth_topic: str + camera_info_topic: str # rgb camera info topic + + get_gripping_point_tool: "GetGrippingPointTool" + + args_schema: Type[GetObjectPositionsToolInput] = GetObjectPositionsToolInput + + @staticmethod + def format_pose(pose: Pose): + return f"Centroid(x={pose.position.x:.2f}, y={pose.position.y:2f}, z={pose.position.z:2f})" + + def _run(self, object_name: str): + transform = self.connector.get_transform( + target_frame=self.target_frame, source_frame=self.source_frame + ) + + try: + # Get raw gripping points from the tool + gripping_points = self.get_gripping_point_tool._run(object_name) + except RaiTimeoutError as e: + return str(e) + except Exception as e: + return f"Error getting gripping points: {str(e)}" + + if len(gripping_points) == 0: + return f"No {object_name}s detected." + + # Transform gripping points to target frame + mani_frame_poses = [] + for gp in gripping_points: + pose = Pose(position=Point(x=gp[0], y=gp[1], z=gp[2])) + mani_frame_pose = do_transform_pose(pose, transform) + mani_frame_poses.append(mani_frame_pose) + + # Format message similar to original GetGrippingPointTool + if len(mani_frame_poses) == 1: + message = f"The gripping point of the object {object_name} is {mani_frame_poses[0].position.x:.3f}, {mani_frame_poses[0].position.y:.3f}, {mani_frame_poses[0].position.z:.3f}\n" + else: + message = f"Multiple gripping points found for the object {object_name}\n" + for i, pose in enumerate(mani_frame_poses): + message += f"The gripping point of the object {i + 1} {object_name} is {pose.position.x:.3f}, {pose.position.y:.3f}, {pose.position.z:.3f}\n" + + return message + + class ResetArmToolInput(BaseModel): pass diff --git a/src/rai_core/rai/tools/timeout.py b/src/rai_core/rai/tools/timeout.py index 13c788839..662864530 100644 --- a/src/rai_core/rai/tools/timeout.py +++ b/src/rai_core/rai/tools/timeout.py @@ -40,8 +40,8 @@ F = TypeVar("F", bound=Callable[..., Any]) -class TimeoutError(Exception): - """Raised when an operation times out.""" +class RaiTimeoutError(Exception): + """Custom timeout exception for RAI tools""" pass @@ -93,7 +93,7 @@ def wrapper(*args, **kwargs): timeout_message or f"Function '{func.__name__}' timed out after {seconds} seconds" ) - raise TimeoutError(message) + raise RaiTimeoutError(message) return wrapper @@ -139,7 +139,7 @@ def wrapper(self, *args, **kwargs): timeout_message or f"Method '{func.__name__}' of {self.__class__.__name__} timed out after {seconds} seconds" ) - raise TimeoutError(message) + raise RaiTimeoutError(message) return wrapper diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/__init__.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/__init__.py index 32ad003b2..731902a4b 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/__init__.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/__init__.py @@ -12,14 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. 
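A note on the TimeoutError -> RaiTimeoutError rename above: the old name shadowed Python's builtin exception of the same name, so an except TimeoutError at a call site was ambiguous about which class it caught. A minimal usage sketch, mirroring how the decorator is exercised by the unit tests later in this series (the 1.0 s budget and function name are illustrative only):

    import time

    from rai.tools.timeout import RaiTimeoutError, timeout

    @timeout(1.0, "demo operation exceeded 1.0 seconds")
    def slow_operation():
        time.sleep(2.0)  # sleeps past the 1.0 s budget, so the wrapper fires

    try:
        slow_operation()
    except RaiTimeoutError as exc:  # distinct from the builtin TimeoutError
        print(exc)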
+# Service names for ROS2 - defined here to avoid circular imports +GDINO_SERVICE_NAME = "grounding_dino_classify" +GDINO_NODE_NAME = "grounding_dino_node" +GSAM_SERVICE_NAME = "grounded_sam_segment" +GSAM_NODE_NAME = "grounded_sam_node" -from .agents.grounded_sam import GSAM_NODE_NAME, GSAM_SERVICE_NAME, GroundedSamAgent -from .agents.grounding_dino import ( - GDINO_NODE_NAME, - GDINO_SERVICE_NAME, - GroundingDinoAgent, +from .agents import GroundedSamAgent, GroundingDinoAgent # noqa: E402 +from .tools import GetDetectionTool, GetDistanceToObjectsTool # noqa: E402 +from .tools.pcl_detection import ( # noqa: E402 + GrippingPointEstimator, + GrippingPointEstimatorConfig, + PointCloudFilter, + PointCloudFilterConfig, + PointCloudFromSegmentation, + PointCloudFromSegmentationConfig, + depth_to_point_cloud, +) +from .tools.pcl_detection_tools import ( # noqa: E402 + GetGrippingPointTool, + GetGrippingPointToolInput, ) -from .tools import GetDetectionTool, GetDistanceToObjectsTool __all__ = [ "GDINO_NODE_NAME", @@ -28,6 +41,15 @@ "GSAM_SERVICE_NAME", "GetDetectionTool", "GetDistanceToObjectsTool", + "GetGrippingPointTool", + "GetGrippingPointToolInput", + "GrippingPointEstimator", + "GrippingPointEstimatorConfig", "GroundedSamAgent", "GroundingDinoAgent", + "PointCloudFilter", + "PointCloudFilterConfig", + "PointCloudFromSegmentation", + "PointCloudFromSegmentationConfig", + "depth_to_point_cloud", ] diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/__init__.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/__init__.py index 916b3ef45..8ec4bc755 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/__init__.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/__init__.py @@ -13,6 +13,16 @@ # limitations under the License. 
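With the package-root re-exports above and the matching tools/__init__.py re-exports in this hunk, the PCL pipeline API resolves identically from either import path. A quick sanity check, assuming rai_open_set_vision is installed alongside rai_core:

    from rai_open_set_vision import GetGrippingPointTool as from_root
    from rai_open_set_vision.tools import GetGrippingPointTool as from_tools

    # Both names are re-exports of the same class in tools.pcl_detection_tools.
    assert from_root is from_tools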
from .gdino_tools import DistanceMeasurement, GetDetectionTool, GetDistanceToObjectsTool +from .pcl_detection import ( + GrippingPointEstimator, + GrippingPointEstimatorConfig, + PointCloudFilter, + PointCloudFilterConfig, + PointCloudFromSegmentation, + PointCloudFromSegmentationConfig, + depth_to_point_cloud, +) +from .pcl_detection_tools import GetGrippingPointTool, GetGrippingPointToolInput from .segmentation_tools import GetGrabbingPointTool, GetSegmentationTool __all__ = [ @@ -20,5 +30,15 @@ "GetDetectionTool", "GetDistanceToObjectsTool", "GetGrabbingPointTool", + "GetGrippingPointTool", + "GetGrippingPointToolInput", "GetSegmentationTool", + # PCL Detection APIs + "GrippingPointEstimator", + "GrippingPointEstimatorConfig", + "PointCloudFilter", + "PointCloudFilterConfig", + "PointCloudFromSegmentation", + "PointCloudFromSegmentationConfig", + "depth_to_point_cloud", ] diff --git a/src/rai_core/rai/tools/ros2/detection/pcl.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/pcl_detection.py similarity index 95% rename from src/rai_core/rai/tools/ros2/detection/pcl.py rename to src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/pcl_detection.py index 9c671561b..a5695cb27 100644 --- a/src/rai_core/rai/tools/ros2/detection/pcl.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/pcl_detection.py @@ -18,19 +18,15 @@ import sensor_msgs.msg from numpy.typing import NDArray from pydantic import BaseModel -from rai_open_set_vision import GDINO_SERVICE_NAME -from rclpy import Future -from rclpy.exceptions import ( - ParameterNotDeclaredException, - ParameterUninitializedException, -) - from rai.communication.ros2.api import ( convert_ros_img_to_ndarray, # type: ignore[reportUnknownVariableType] ) from rai.communication.ros2.connectors import ROS2Connector from rai.communication.ros2.ros_async import get_future_result +from rclpy import Future + from rai_interfaces.srv import RAIGroundedSam, RAIGroundingDino +from rai_open_set_vision import GDINO_SERVICE_NAME class PointCloudFromSegmentationConfig(BaseModel): @@ -168,6 +164,7 @@ def __init__( source_frame: str, target_frame: str, config: PointCloudFromSegmentationConfig, + conversion_ratio: float = 0.001, ) -> None: self.connector = connector self.camera_topic = camera_topic @@ -176,6 +173,7 @@ def __init__( self.source_frame = source_frame self.target_frame = target_frame self.config = config + self.conversion_ratio = conversion_ratio # --------------------- ROS helpers --------------------- def _get_image_message(self, topic: str) -> sensor_msgs.msg.Image: @@ -286,22 +284,7 @@ def run(self, object_name: str) -> list[NDArray[np.float32]]: gdino_future = self._call_gdino_node(camera_img_msg, object_name) - logger = self.connector.node.get_logger() - try: - conversion_ratio_value = self.connector.node.get_parameter( - "conversion_ratio" - ).value # type: ignore[reportUnknownMemberType] - conversion_ratio: float - if isinstance(conversion_ratio_value, float): - conversion_ratio = conversion_ratio_value - else: - logger.error( # type: ignore[reportUnknownMemberType] - "Parameter conversion_ratio has wrong type. Using default 0.001" - ) - conversion_ratio = 0.001 - except (ParameterUninitializedException, ParameterNotDeclaredException): - logger.warning("Parameter conversion_ratio not found. 
Using default 0.001") # type: ignore[reportUnknownMemberType] - conversion_ratio = 0.001 + conversion_ratio = self.conversion_ratio gdino_resolved = get_future_result(gdino_future) if gdino_resolved is None: diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/pcl_detection_tools.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/pcl_detection_tools.py new file mode 100644 index 000000000..3ab709339 --- /dev/null +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/pcl_detection_tools.py @@ -0,0 +1,133 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, List, Optional, Type + +import numpy as np +from pydantic import BaseModel, Field +from rai.tools.ros2.base import BaseROS2Tool +from rai.tools.timeout import RaiTimeoutError, timeout + +from .pcl_detection import ( + GrippingPointEstimator, + GrippingPointEstimatorConfig, + PointCloudFilter, + PointCloudFilterConfig, + PointCloudFromSegmentation, + PointCloudFromSegmentationConfig, +) + + +class GetGrippingPointToolInput(BaseModel): + object_name: str = Field( + ..., + description="The name of the object to get the gripping point of e.g. 'box', 'apple', 'screwdriver'", + ) + timeout_sec: Optional[float] = Field( + default=None, + description="Override timeout in seconds. If not provided, uses tool's default timeout.", + ) + conversion_ratio: Optional[float] = Field( + default=None, + description="Override conversion ratio for depth to meters. If not provided, uses tool's default.", + ) + + +class GetGrippingPointTool(BaseROS2Tool): + name: str = "get_gripping_point" + description: str = "Get gripping points for specified object/objects. Returns 3D coordinates where a robot gripper can grasp the object." 
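+    # A construction sketch, mirroring the unit test in this series: all wiring
+    # is passed explicitly instead of being read from ROS2 parameters. The topic
+    # and frame values below are the manipulation-demo defaults and are shown
+    # for illustration only:
+    #
+    #   tool = GetGrippingPointTool(
+    #       connector=connector,
+    #       target_frame="panda_link0",
+    #       source_frame="RGBDCamera5",
+    #       camera_topic="/color_image5",
+    #       depth_topic="/depth_image5",
+    #       camera_info_topic="/color_camera_info5",
+    #       segmentation_config=PointCloudFromSegmentationConfig(),
+    #       estimator_config=GrippingPointEstimatorConfig(),
+    #       filter_config=PointCloudFilterConfig(),
+    #   )
+    #   gripping_points = tool._run("cube")  # List[np.ndarray] of XYZ points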
+ + # Configuration for PCL components + segmentation_config: PointCloudFromSegmentationConfig + estimator_config: GrippingPointEstimatorConfig + filter_config: PointCloudFilterConfig + + # Required parameters + target_frame: str + source_frame: str + camera_topic: str + depth_topic: str + camera_info_topic: str + timeout_sec: float = 10.0 # Default timeout + conversion_ratio: float = 0.001 # Default conversion ratio + + # Components initialized in model_post_init + gripping_point_estimator: Optional[GrippingPointEstimator] = None + point_cloud_filter: Optional[PointCloudFilter] = None + point_cloud_from_segmentation: Optional[PointCloudFromSegmentation] = None + + args_schema: Type[GetGrippingPointToolInput] = GetGrippingPointToolInput + + def model_post_init(self, __context: Any) -> None: + """Initialize tool components.""" + self._initialize_components() + + def _initialize_components(self) -> None: + """Initialize PCL components with provided parameters.""" + self.point_cloud_from_segmentation = PointCloudFromSegmentation( + connector=self.connector, + camera_topic=self.camera_topic, + depth_topic=self.depth_topic, + camera_info_topic=self.camera_info_topic, + source_frame=self.source_frame, + target_frame=self.target_frame, + config=self.segmentation_config, + conversion_ratio=self.conversion_ratio, + ) + self.gripping_point_estimator = GrippingPointEstimator( + config=self.estimator_config + ) + self.point_cloud_filter = PointCloudFilter(config=self.filter_config) + + def _run( + self, + object_name: str, + timeout_sec: Optional[float] = None, + conversion_ratio: Optional[float] = None, + ) -> List[np.ndarray]: + """Run gripping point detection and return raw gripping points.""" + + # Use runtime parameters if provided, otherwise use defaults + effective_timeout = timeout_sec if timeout_sec is not None else self.timeout_sec + effective_conversion_ratio = ( + conversion_ratio if conversion_ratio is not None else self.conversion_ratio + ) + + # Update conversion ratio if different from current + if effective_conversion_ratio != self.conversion_ratio: + self.point_cloud_from_segmentation.conversion_ratio = ( + effective_conversion_ratio + ) + + @timeout( + effective_timeout, + f"Gripping point detection for object '{object_name}' exceeded {effective_timeout} seconds", + ) + def _run_with_timeout(): + pcl = self.point_cloud_from_segmentation.run(object_name) + if len(pcl) == 0: + return [] + + pcl = self.point_cloud_filter.run(pcl) + gps = self.gripping_point_estimator.run(pcl) + return gps + + try: + return _run_with_timeout() + except RaiTimeoutError as e: + # Log the timeout but still raise it + self.connector.node.get_logger().warning(f"Timeout: {e}") + raise # Let caller decide how to handle + except Exception: + raise diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py index b168943c0..16c6fc2df 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py @@ -291,26 +291,8 @@ def _process_mask( masked_depth_image, intrinsic[0], intrinsic[1], intrinsic[2], intrinsic[3] ) + # TODO: Filter out outliers points = pcd - # publish resulting pointcloud - # TODO(juliajia): remove this after debugging - # import time - - # from geometry_msgs.msg import Point32 - # from sensor_msgs.msg import PointCloud - - # msg = PointCloud() 
- # msg.header.frame_id = "egofront_rgbd_camera_depth_optical_frame" - # msg.points = [Point32(x=p[0], y=p[1], z=p[2]) for p in points] - # pub = self.connector.node.create_publisher( - # PointCloud, "/debug/get_grabbing_point_pointcloud", 10 - # ) - # while True: - # self.connector.node.get_logger().info( - # f"publishing pointcloud to /debug/get_grabbing_point_pointcloud: {len(msg.points)} points, mean: {np.array(points.mean(axis=0))}." - # ) - # pub.publish(msg) - # time.sleep(0.1) # https://github.com/ycheng517/tabletop-handybot/blob/6d401e577e41ea86529d091b406fbfc936f37a8d/tabletop_handybot/tabletop_handybot/tabletop_handybot_node.py#L413-L424 grasp_z = points[:, 2].max() diff --git a/tests/tools/ros2/test_gripping_points.py b/tests/rai_extensions/test_gripping_points.py similarity index 86% rename from tests/tools/ros2/test_gripping_points.py rename to tests/rai_extensions/test_gripping_points.py index a3dad0bf6..d2e77e9a7 100644 --- a/tests/tools/ros2/test_gripping_points.py +++ b/tests/rai_extensions/test_gripping_points.py @@ -26,6 +26,7 @@ """ import time +from typing import List import cv2 import numpy as np @@ -35,14 +36,16 @@ from cv_bridge import CvBridge from rai.communication.ros2 import wait_for_ros2_services, wait_for_ros2_topics from rai.communication.ros2.connectors import ROS2Connector -from rai.tools.ros2.detection import GetGrippingPointTool -from rai.tools.ros2.detection.pcl import ( +from rai_open_set_vision import ( + GetGrippingPointTool, GrippingPointEstimatorConfig, PointCloudFilterConfig, PointCloudFromSegmentationConfig, - _publish_gripping_point_debug_data, ) +# import internal tool +from rai_open_set_vision.tools.pcl_detection import _publish_gripping_point_debug_data + def draw_points_on_image(image_msg, points, camera_info): """Draw points on the camera image.""" @@ -86,20 +89,9 @@ def draw_points_on_image(image_msg, points, camera_info): return cv_image -def extract_gripping_points(result: str) -> list[np.ndarray]: - """Extract gripping points from the result.""" - gripping_points = [] - lines = result.split("\n") - for line in lines: - if "gripping point" in line and "is [" in line: - # Extract coordinates from line like "is [0.39972728 0.16179778 0.04179673]" - start = line.find("[") + 1 - end = line.find("]") - if start > 0 and end > start: - coords_str = line[start:end] - coords = [float(x) for x in coords_str.split()] - gripping_points.append(np.array(coords)) - return gripping_points +def extract_gripping_points(result: List[np.ndarray]) -> list[np.ndarray]: + """Extract gripping points from the result - now returns raw points directly.""" + return result def transform_points_to_target_frame(connector, points, source_frame, target_frame): @@ -221,22 +213,9 @@ def main( wait_for_ros2_topics(connector, list(topics.values())) print("✅ All services and topics available") - # Set up node parameters + # Set up conversion ratio parameter node = connector.node - - # Declare and set ROS2 parameters for deployment configuration - parameters_to_set = [ - ("conversion_ratio", 1.0), - ("detection_tools.gripping_point.target_frame", frames["target"]), - ("detection_tools.gripping_point.source_frame", frames["source"]), - ("detection_tools.gripping_point.camera_topic", topics["camera"]), - ("detection_tools.gripping_point.depth_topic", topics["depth"]), - ("detection_tools.gripping_point.camera_info_topic", topics["camera_info"]), - ] - - # Declare and set each parameter - for param_name, param_value in parameters_to_set: - node.declare_parameter(param_name, 
param_value) + node.declare_parameter("conversion_ratio", 1.0) start_time = time.time() @@ -255,7 +234,8 @@ def main( result = gripping_tool._run(test_object) print(f"elapsed time: {time.time() - start_time} seconds") - gripping_points = extract_gripping_points(result) + # result is now a list of numpy arrays directly + gripping_points = result print(f"\nFound {len(gripping_points)} gripping points in target frame:") for i, gp in enumerate(gripping_points): diff --git a/tests/tools/ros2/test_detection_tools.py b/tests/rai_extensions/test_pcl_detection_tools.py similarity index 90% rename from tests/tools/ros2/test_detection_tools.py rename to tests/rai_extensions/test_pcl_detection_tools.py index 914abfdc9..7c35eac91 100644 --- a/tests/tools/ros2/test_detection_tools.py +++ b/tests/rai_extensions/test_pcl_detection_tools.py @@ -27,8 +27,9 @@ import numpy as np from rai.communication.ros2.connectors import ROS2Connector -from rai.tools.ros2.detection import GetGrippingPointTool -from rai.tools.ros2.detection.pcl import ( +from rai.tools.timeout import RaiTimeoutError +from rai_open_set_vision import ( + GetGrippingPointTool, GrippingPointEstimator, GrippingPointEstimatorConfig, PointCloudFilter, @@ -177,6 +178,11 @@ def test_get_gripping_point_tool_timeout(): tool = GetGrippingPointTool( connector=mock_connector, + target_frame="panda_link0", + source_frame="RGBDCamera5", + camera_topic="/color_image5", + depth_topic="/depth_image5", + camera_info_topic="/color_camera_info5", segmentation_config=PointCloudFromSegmentationConfig(), estimator_config=GrippingPointEstimatorConfig(), filter_config=PointCloudFilterConfig(), @@ -185,14 +191,14 @@ def test_get_gripping_point_tool_timeout(): tool.gripping_point_estimator = mock_estimator tool.point_cloud_filter = mock_filter tool.timeout_sec = 5.0 - tool.point_cloud_from_segmentation = mock_pcl_gen # Connect the mock + tool.point_cloud_from_segmentation = mock_pcl_gen # Test fast execution - should complete without timeout result = tool._run("test_object") - assert "No test_objects detected" in result - assert "timed out" not in result.lower() + assert result == [] # Returns empty list for no objects found + assert len(result) == 0 - # Test 2: Actual timeout behavior + # Test 2: Actual timeout behavior - should raise TimeoutError def slow_operation(obj_name): time.sleep(2.0) # Longer than timeout return [] @@ -200,5 +206,6 @@ def slow_operation(obj_name): mock_pcl_gen.run.side_effect = slow_operation tool.timeout_sec = 1.0 # Short timeout - result = tool._run("test") - assert "timed out" in result.lower() or "timeout" in result.lower() + # Expect TimeoutError to be raised + with pytest.raises(RaiTimeoutError, match="exceeded 1.0 seconds"): + tool._run("test") From 137729491d34ddb4c2f78b813fe0ac1dadf71d0b Mon Sep 17 00:00:00 2001 From: Julia Jia Date: Sat, 27 Sep 2025 17:05:13 -0700 Subject: [PATCH 13/13] Merge 3d detection pipeline code to rai_open_set_vision tools --- examples/manipulation-demo-v2.py | 85 ++++---- .../rai/tools/ros2/manipulation/__init__.py | 2 - .../rai/tools/ros2/manipulation/custom.py | 72 ++----- .../rai_open_set_vision/__init__.py | 8 +- .../rai_open_set_vision/tools/__init__.py | 9 +- .../tools/pcl_detection.py | 96 ++++++--- .../tools/pcl_detection_tools.py | 186 ++++++++++++------ tests/rai_extensions/test_gripping_points.py | 89 +++++---- .../test_pcl_detection_tools.py | 20 +- xd.py | 31 --- 10 files changed, 327 insertions(+), 271 deletions(-) delete mode 100644 xd.py diff --git a/examples/manipulation-demo-v2.py 
b/examples/manipulation-demo-v2.py
index 750142b7a..577df3a51 100644
--- a/examples/manipulation-demo-v2.py
+++ b/examples/manipulation-demo-v2.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2024 Robotec.AI
+# Copyright (C) 2025 Julia Jia
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,7 +17,6 @@
 from typing import List
 
 import rclpy
-import rclpy.qos
 from langchain_core.messages import BaseMessage, HumanMessage
 from langchain_core.tools import BaseTool
 from rai import get_llm_model
@@ -25,13 +24,12 @@
 from rai.communication.ros2 import wait_for_ros2_services, wait_for_ros2_topics
 from rai.communication.ros2.connectors import ROS2Connector
 from rai.tools.ros2.manipulation import (
-    GetObjectGrippingPointsTool,
     MoveObjectFromToTool,
     ResetArmTool,
 )
 from rai.tools.ros2.simple import GetROS2ImageConfiguredTool
 from rai_open_set_vision import (
-    GetGrippingPointTool,
+    GetObjectGrippingPointsTool,
     GrippingPointEstimatorConfig,
     PointCloudFilterConfig,
     PointCloudFromSegmentationConfig,
@@ -40,19 +38,26 @@
 from rai_whoami.models import EmbodimentInfo
 
 logger = logging.getLogger(__name__)
+param_prefix = "pcl.detection.gripping_points"
 
 
-def create_agent():
-    rclpy.init()
-    connector = ROS2Connector(executor_type="single_threaded")
+def initialize_tools(connector: ROS2Connector) -> List[BaseTool]:
+    """Initialize and configure all tools for the manipulation agent."""
+    node = connector.node
 
-    required_services = ["/grounded_sam_segment", "/grounding_dino_classify"]
-    required_topics = ["/color_image5", "/depth_image5", "/color_camera_info5"]
-    wait_for_ros2_services(connector, required_services)
-    wait_for_ros2_topics(connector, required_topics)
+    # Parameters for GetObjectGrippingPointsTool; these can also be set in the launch file or loaded from a YAML file
+    parameters_to_set = [
+        (f"{param_prefix}.target_frame", "panda_link0"),
+        (f"{param_prefix}.source_frame", "RGBDCamera5"),
+        (f"{param_prefix}.camera_topic", "/color_image5"),
+        (f"{param_prefix}.depth_topic", "/depth_image5"),
+        (f"{param_prefix}.camera_info_topic", "/color_camera_info5"),
+        (f"{param_prefix}.timeout_sec", 10.0),
+        (f"{param_prefix}.conversion_ratio", 1.0),
+    ]
 
-    node = connector.node
-    node.declare_parameter("conversion_ratio", 1.0)
+    for param_name, param_value in parameters_to_set:
+        node.declare_parameter(param_name, param_value)
 
     # Configure gripping point detection algorithms
     segmentation_config = PointCloudFromSegmentationConfig(
@@ -61,7 +66,7 @@ def create_agent():
     )
 
     estimator_config = GrippingPointEstimatorConfig(
-        strategy="biggest_plane",  # Options: "centroid", "top_plane", "biggest_plane"
+        strategy="centroid",  # Options: "centroid", "top_plane", "biggest_plane"
        top_percentile=0.05,
         plane_bin_size_m=0.01,
         ransac_iterations=200,
@@ -70,35 +75,49 @@ def create_agent():
     )
 
     filter_config = PointCloudFilterConfig(
-        strategy="dbscan",
+        strategy="isolation_forest",  # Options: "dbscan", "kmeans_largest_cluster", "isolation_forest", "lof"
+        if_max_samples="auto",
+        if_contamination=0.05,
         min_points=20,
-        dbscan_eps=0.02,
-        dbscan_min_samples=10,
     )
 
-    # Create the underlying GetGrippingPointTool
-    gripping_point_tool = GetGrippingPointTool(
-        connector=connector,
-        segmentation_config=segmentation_config,
-        estimator_config=estimator_config,
-        filter_config=filter_config,
-    )
+    manipulator_frame = node.get_parameter(f"{param_prefix}.target_frame").value
+    camera_topic = node.get_parameter(f"{param_prefix}.camera_topic").value
 
     tools:
List[BaseTool] = [ GetObjectGrippingPointsTool( connector=connector, - target_frame="panda_link0", - source_frame="RGBDCamera5", - camera_topic="/color_image5", - depth_topic="/depth_image5", - camera_info_topic="/color_camera_info5", - get_gripping_point_tool=gripping_point_tool, + segmentation_config=segmentation_config, + estimator_config=estimator_config, + filter_config=filter_config, ), - MoveObjectFromToTool(connector=connector, manipulator_frame="panda_link0"), - ResetArmTool(connector=connector, manipulator_frame="panda_link0"), - GetROS2ImageConfiguredTool(connector=connector, topic="/color_image5"), + MoveObjectFromToTool(connector=connector, manipulator_frame=manipulator_frame), + ResetArmTool(connector=connector, manipulator_frame=manipulator_frame), + GetROS2ImageConfiguredTool(connector=connector, topic=camera_topic), + ] + + return tools + + +def wait_for_ros2_services_and_topics(connector: ROS2Connector): + required_services = ["/grounded_sam_segment", "/grounding_dino_classify"] + required_topics = [ + connector.node.get_parameter(f"{param_prefix}.camera_topic").value, + connector.node.get_parameter(f"{param_prefix}.depth_topic").value, + connector.node.get_parameter(f"{param_prefix}.camera_info_topic").value, ] + wait_for_ros2_services(connector, required_services) + wait_for_ros2_topics(connector, required_topics) + + +def create_agent(): + rclpy.init() + connector = ROS2Connector(executor_type="single_threaded") + + tools = initialize_tools(connector) + wait_for_ros2_services_and_topics(connector) + llm = get_llm_model(model_type="complex_model", streaming=True) embodiment_info = EmbodimentInfo.from_file( "examples/embodiments/manipulation_embodiment.json" diff --git a/src/rai_core/rai/tools/ros2/manipulation/__init__.py b/src/rai_core/rai/tools/ros2/manipulation/__init__.py index 12bf976f1..07f1b16b1 100644 --- a/src/rai_core/rai/tools/ros2/manipulation/__init__.py +++ b/src/rai_core/rai/tools/ros2/manipulation/__init__.py @@ -20,7 +20,6 @@ ) from .custom import ( - GetObjectGrippingPointsTool, GetObjectPositionsTool, MoveObjectFromToTool, MoveObjectFromToToolInput, @@ -30,7 +29,6 @@ ) __all__ = [ - "GetObjectGrippingPointsTool", "GetObjectPositionsTool", "MoveObjectFromToTool", "MoveObjectFromToToolInput", diff --git a/src/rai_core/rai/tools/ros2/manipulation/custom.py b/src/rai_core/rai/tools/ros2/manipulation/custom.py index e9a385c84..e1ecb4d1a 100644 --- a/src/rai_core/rai/tools/ros2/manipulation/custom.py +++ b/src/rai_core/rai/tools/ros2/manipulation/custom.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import Literal, Type, List +from typing import Literal, Type import numpy as np from deprecated import deprecated @@ -23,13 +23,19 @@ from rai.communication.ros2.ros_async import get_future_result from rai.tools.ros2.base import BaseROS2Tool -from rai.tools.timeout import RaiTimeoutError try: - from rai_open_set_vision import GetGrabbingPointTool, GetGrippingPointTool + from rai_interfaces.srv import ManipulatorMoveTo except ImportError: logging.warning( - "rai_open_set_vision is not installed, GetGrabbingPointTool or GetGrippingPointTool may not work" + "rai_interfaces is not installed, ManipulatorMoveTo tool will not work." 
+ ) + +try: + from rai_open_set_vision.tools import GetGrabbingPointTool +except ImportError: + logging.warning( + "rai_open_set_vision is not installed, GetGrabbingPointTool will not work" ) @@ -254,7 +260,7 @@ class GetObjectPositionsToolInput(BaseModel): ) -@deprecated("Use GetObjectGrippingPointsTool from rai_core.rai.tools.ros2.manipulation.custom instead") +@deprecated("Use GetObjectGrippingPointsTool from rai_open_set_vision instead") class GetObjectPositionsTool(BaseROS2Tool): name: str = "get_object_positions" description: str = ( @@ -306,62 +312,6 @@ def _run(self, object_name: str): return f"Centroids of detected {object_name}s in {self.target_frame} frame: [{', '.join(map(self.format_pose, mani_frame_poses))}]. Sizes of the detected objects are unknown." -class GetObjectGrippingPointsTool(BaseROS2Tool): - name: str = "get_object_gripping_points" - description: str = ( - "Retrieve the gripping points of all objects of a specified type in the target frame. " - "This tool provides accurate gripping point data but does not distinguish between different colors of the same object type. " - "While gripping point detection is reliable, please note that object classification may occasionally be inaccurate." - ) - - target_frame: str - source_frame: str - camera_topic: str # rgb camera topic - depth_topic: str - camera_info_topic: str # rgb camera info topic - - get_gripping_point_tool: "GetGrippingPointTool" - - args_schema: Type[GetObjectPositionsToolInput] = GetObjectPositionsToolInput - - @staticmethod - def format_pose(pose: Pose): - return f"Centroid(x={pose.position.x:.2f}, y={pose.position.y:2f}, z={pose.position.z:2f})" - - def _run(self, object_name: str): - transform = self.connector.get_transform( - target_frame=self.target_frame, source_frame=self.source_frame - ) - - try: - # Get raw gripping points from the tool - gripping_points = self.get_gripping_point_tool._run(object_name) - except RaiTimeoutError as e: - return str(e) - except Exception as e: - return f"Error getting gripping points: {str(e)}" - - if len(gripping_points) == 0: - return f"No {object_name}s detected." 
- - # Transform gripping points to target frame - mani_frame_poses = [] - for gp in gripping_points: - pose = Pose(position=Point(x=gp[0], y=gp[1], z=gp[2])) - mani_frame_pose = do_transform_pose(pose, transform) - mani_frame_poses.append(mani_frame_pose) - - # Format message similar to original GetGrippingPointTool - if len(mani_frame_poses) == 1: - message = f"The gripping point of the object {object_name} is {mani_frame_poses[0].position.x:.3f}, {mani_frame_poses[0].position.y:.3f}, {mani_frame_poses[0].position.z:.3f}\n" - else: - message = f"Multiple gripping points found for the object {object_name}\n" - for i, pose in enumerate(mani_frame_poses): - message += f"The gripping point of the object {i + 1} {object_name} is {pose.position.x:.3f}, {pose.position.y:.3f}, {pose.position.z:.3f}\n" - - return message - - class ResetArmToolInput(BaseModel): pass diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/__init__.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/__init__.py index 731902a4b..54e38c64b 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/__init__.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/__init__.py @@ -30,8 +30,8 @@ depth_to_point_cloud, ) from .tools.pcl_detection_tools import ( # noqa: E402 - GetGrippingPointTool, - GetGrippingPointToolInput, + GetObjectGrippingPointsTool, + GetObjectGrippingPointsToolInput, ) __all__ = [ @@ -41,8 +41,8 @@ "GSAM_SERVICE_NAME", "GetDetectionTool", "GetDistanceToObjectsTool", - "GetGrippingPointTool", - "GetGrippingPointToolInput", + "GetObjectGrippingPointsTool", + "GetObjectGrippingPointsToolInput", "GrippingPointEstimator", "GrippingPointEstimatorConfig", "GroundedSamAgent", diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/__init__.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/__init__.py index 8ec4bc755..705a6089e 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/__init__.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/__init__.py @@ -22,7 +22,10 @@ PointCloudFromSegmentationConfig, depth_to_point_cloud, ) -from .pcl_detection_tools import GetGrippingPointTool, GetGrippingPointToolInput +from .pcl_detection_tools import ( + GetObjectGrippingPointsTool, + GetObjectGrippingPointsToolInput, +) from .segmentation_tools import GetGrabbingPointTool, GetSegmentationTool __all__ = [ @@ -30,8 +33,8 @@ "GetDetectionTool", "GetDistanceToObjectsTool", "GetGrabbingPointTool", - "GetGrippingPointTool", - "GetGrippingPointToolInput", + "GetObjectGrippingPointsTool", + "GetObjectGrippingPointsToolInput", "GetSegmentationTool", # PCL Detection APIs "GrippingPointEstimator", diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/pcl_detection.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/pcl_detection.py index a5695cb27..7333dee0a 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/pcl_detection.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/pcl_detection.py @@ -17,7 +17,7 @@ import numpy as np import sensor_msgs.msg from numpy.typing import NDArray -from pydantic import BaseModel +from pydantic import BaseModel, Field from rai.communication.ros2.api import ( convert_ros_img_to_ndarray, # type: ignore[reportUnknownVariableType] ) @@ -30,35 +30,71 @@ class PointCloudFromSegmentationConfig(BaseModel): - box_threshold: float = 0.35 - text_threshold: float = 0.45 + 
box_threshold: float = Field( + default=0.35, description="Box threshold for GDINO object detection" + ) + text_threshold: float = Field( + default=0.45, description="Text threshold for GDINO object detection" + ) class GrippingPointEstimatorConfig(BaseModel): - strategy: Literal["centroid", "top_plane", "biggest_plane"] = "centroid" - top_percentile: float = 0.05 - plane_bin_size_m: float = 0.01 - ransac_iterations: int = 200 - distance_threshold_m: float = 0.01 - min_points: int = 10 + strategy: Literal["centroid", "top_plane", "biggest_plane"] = Field( + default="centroid", + description="Strategy for estimating gripping points from point clouds", + ) + top_percentile: float = Field( + default=0.05, + description="Fraction of highest Z points to consider (0.05 = top 5%)", + ) + plane_bin_size_m: float = Field( + default=0.01, description="Bin size in meters for plane detection" + ) + ransac_iterations: int = Field( + default=200, description="Number of RANSAC iterations for plane fitting" + ) + distance_threshold_m: float = Field( + default=0.01, + description="Distance threshold in meters for RANSAC plane fitting", + ) + min_points: int = Field( + default=10, description="Minimum number of points required for processing" + ) class PointCloudFilterConfig(BaseModel): strategy: Literal["dbscan", "kmeans_largest_cluster", "isolation_forest", "lof"] = ( - "isolation_forest" + Field( + default="isolation_forest", + description="Clustering strategy for filtering point cloud outliers", + ) + ) + min_points: int = Field( + default=20, description="Minimum number of points required for filtering" ) - min_points: int = 20 # DBSCAN - dbscan_eps: float = 0.02 - dbscan_min_samples: int = 10 + dbscan_eps: float = Field( + default=0.02, description="DBSCAN epsilon parameter for neighborhood radius" + ) + dbscan_min_samples: int = Field( + default=10, description="DBSCAN minimum samples in neighborhood" + ) # KMeans - kmeans_k: int = 2 + kmeans_k: int = Field(default=2, description="Number of clusters for KMeans") # Isolation Forest - if_max_samples: int | float | Literal["auto"] = "auto" - if_contamination: float = 0.05 + if_max_samples: int | float | Literal["auto"] = Field( + default="auto", description="Maximum samples for Isolation Forest" + ) + if_contamination: float = Field( + default=0.05, description="Contamination rate for Isolation Forest" + ) # LOF - lof_n_neighbors: int = 20 - lof_contamination: float = 0.05 + lof_n_neighbors: int = Field( + default=20, description="Number of neighbors for Local Outlier Factor" + ) + lof_contamination: float = Field( + default=0.05, description="Contamination rate for Local Outlier Factor" + ) def depth_to_point_cloud( @@ -81,9 +117,9 @@ def _publish_gripping_point_debug_data( obj_points_xyz: NDArray[np.float32], gripping_points_xyz: list[NDArray[np.float32]], base_frame_id: str = "egoarm_base_link", - publish_duration: float = 10.0, + publish_duration: float = 5.0, ) -> None: - """Publish the gripping point debug data for visualization in RVIZ via point cloud and marker array. + """Publish the gripping point debug data to ROS2 topics which can be visualized in RVIZ. Args: connector: The ROS2 connector. 
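The Field descriptions above document the filter defaults; the filtering implementation itself is untouched by this hunk. For orientation only, an isolation-forest pass over an Nx3 cloud with these defaults could look like the sketch below (the scikit-learn backend is an assumption of the sketch, not something this diff shows):

    import numpy as np
    from sklearn.ensemble import IsolationForest

    def isolation_forest_inliers(
        points: np.ndarray,  # Nx3 XYZ cloud
        max_samples="auto",  # if_max_samples default
        contamination: float = 0.05,  # if_contamination default
    ) -> np.ndarray:
        # fit_predict labels inliers as 1 and outliers as -1; keep the inliers.
        labels = IsolationForest(
            max_samples=max_samples, contamination=contamination
        ).fit_predict(points)
        return points[labels == 1]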
@@ -98,6 +134,14 @@ def _publish_gripping_point_debug_data(
     from std_msgs.msg import Header
     from visualization_msgs.msg import Marker, MarkerArray
 
+    debug_gripping_points_pointcloud_topic = "/debug_gripping_points_pointcloud"
+    debug_gripping_points_markerarray_topic = "/debug_gripping_points_markerarray"
+
+    connector.node.get_logger().warning(
+        "Debug data publishing adds computational overhead and network traffic and can impact performance - not suitable for production. "
+        f"Data will be published to {debug_gripping_points_pointcloud_topic} and {debug_gripping_points_markerarray_topic} for {publish_duration} seconds."
+    )
+
     points = (
         np.concatenate(obj_points_xyz, axis=0)
         if obj_points_xyz
@@ -108,11 +152,11 @@ def _publish_gripping_point_debug_data(
     msg.header.frame_id = base_frame_id  # type: ignore[reportUnknownMemberType]
     msg.points = [Point32(x=float(p[0]), y=float(p[1]), z=float(p[2])) for p in points]  # type: ignore[reportUnknownArgumentType]
     pub = connector.node.create_publisher(  # type: ignore[reportUnknownMemberType]
-        PointCloud, "/debug_gripping_points_pointcloud", 10
+        PointCloud, debug_gripping_points_pointcloud_topic, 10
     )
 
     marker_pub = connector.node.create_publisher(  # type: ignore[reportUnknownMemberType]
-        MarkerArray, "/debug_gripping_points_markerarray", 10
+        MarkerArray, debug_gripping_points_markerarray_topic, 10
     )
     marker_array = MarkerArray()
     header = Header()
@@ -132,8 +176,6 @@ def _publish_gripping_point_debug_data(
         m.color.b = 0.0  # type: ignore[reportUnknownMemberType]
         m.color.a = 1.0  # type: ignore[reportUnknownMemberType]
 
-        # m.ns = str(i)
-
         markers.append(m)  # type: ignore[reportUnknownArgumentType]
     marker_array.markers = markers
 
@@ -163,8 +205,8 @@ def __init__(
         camera_info_topic: str,
         source_frame: str,
         target_frame: str,
-        config: PointCloudFromSegmentationConfig,
         conversion_ratio: float = 0.001,
+        config: PointCloudFromSegmentationConfig,
     ) -> None:
         self.connector = connector
         self.camera_topic = camera_topic
@@ -284,8 +326,6 @@ def run(self, object_name: str) -> list[NDArray[np.float32]]:
 
         gdino_future = self._call_gdino_node(camera_img_msg, object_name)
 
-        conversion_ratio = self.conversion_ratio
-
         gdino_resolved = get_future_result(gdino_future)
         if gdino_resolved is None:
             return []
@@ -304,7 +344,7 @@ def run(self, object_name: str) -> list[NDArray[np.float32]]:
             depth, dtype=np.float32
         )
         masked_depth_image[binary_mask] = depth[binary_mask]
-        masked_depth_image = masked_depth_image * float(conversion_ratio)
+        masked_depth_image = masked_depth_image * float(self.conversion_ratio)
 
         points_camera: NDArray[np.float32] = depth_to_point_cloud(
             masked_depth_image, fx, fy, cx, cy
diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/pcl_detection_tools.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/pcl_detection_tools.py
index 3ab709339..70bb4975a 100644
--- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/pcl_detection_tools.py
+++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/pcl_detection_tools.py
@@ -12,9 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, List, Optional, Type
+from typing import Any, Optional, Type
 
-import numpy as np
 from pydantic import BaseModel, Field
 from rai.tools.ros2.base import BaseROS2Tool
 from rai.tools.timeout import RaiTimeoutError, timeout
@@ -28,53 +27,121 @@
     PointCloudFromSegmentationConfig,
 )
 
+# Parameter prefix for ROS2 configuration
+PCL_DETECTION_PARAM_PREFIX = "pcl.detection.gripping_points"
 
-class GetGrippingPointToolInput(BaseModel):
+
+class GetObjectGrippingPointsToolInput(BaseModel):
     object_name: str = Field(
         ...,
         description="The name of the object to get the gripping point of e.g. 'box', 'apple', 'screwdriver'",
     )
-    timeout_sec: Optional[float] = Field(
-        default=None,
-        description="Override timeout in seconds. If not provided, uses tool's default timeout.",
-    )
-    conversion_ratio: Optional[float] = Field(
-        default=None,
-        description="Override conversion ratio for depth to meters. If not provided, uses tool's default.",
-    )
 
 
-class GetGrippingPointTool(BaseROS2Tool):
-    name: str = "get_gripping_point"
+class GetObjectGrippingPointsTool(BaseROS2Tool):
+    name: str = "get_object_gripping_points"
     description: str = "Get gripping points for specified object/objects. Returns 3D coordinates where a robot gripper can grasp the object."
 
     # Configuration for PCL components
-    segmentation_config: PointCloudFromSegmentationConfig
-    estimator_config: GrippingPointEstimatorConfig
-    filter_config: PointCloudFilterConfig
-
-    # Required parameters
-    target_frame: str
-    source_frame: str
-    camera_topic: str
-    depth_topic: str
-    camera_info_topic: str
-    timeout_sec: float = 10.0  # Default timeout
-    conversion_ratio: float = 0.001  # Default conversion ratio
+    segmentation_config: PointCloudFromSegmentationConfig = Field(
+        default_factory=PointCloudFromSegmentationConfig,
+        description="Configuration for point cloud segmentation from camera images",
+    )
+    estimator_config: GrippingPointEstimatorConfig = Field(
+        default_factory=GrippingPointEstimatorConfig,
+        description="Configuration for gripping point estimation strategies",
+    )
+    filter_config: PointCloudFilterConfig = Field(
+        default_factory=PointCloudFilterConfig,
+        description="Configuration for point cloud filtering and outlier removal",
+    )
+
+    # Auto-initialized in model_post_init from ROS2 parameters
+    target_frame: Optional[str] = Field(
+        default=None, description="Target coordinate frame for gripping points"
+    )
+    source_frame: Optional[str] = Field(
+        default=None, description="Source coordinate frame of camera data"
+    )
+    camera_topic: Optional[str] = Field(
+        default=None, description="ROS2 topic for camera RGB images"
+    )
+    depth_topic: Optional[str] = Field(
+        default=None, description="ROS2 topic for camera depth images"
+    )
+    camera_info_topic: Optional[str] = Field(
+        default=None, description="ROS2 topic for camera calibration info"
+    )
+    timeout_sec: Optional[float] = Field(
+        default=None, description="Timeout in seconds for gripping point detection"
+    )
+    conversion_ratio: Optional[float] = Field(
+        default=0.001, description="Conversion ratio from depth units to meters"
+    )
 
     # Components initialized in model_post_init
-    gripping_point_estimator: Optional[GrippingPointEstimator] = None
-    point_cloud_filter: Optional[PointCloudFilter] = None
-    point_cloud_from_segmentation: Optional[PointCloudFromSegmentation] = None
+    gripping_point_estimator: Optional[GrippingPointEstimator] = Field(
+        default=None, exclude=True
+    )
+    point_cloud_filter: Optional[PointCloudFilter] = Field(default=None, exclude=True)
+
+    point_cloud_from_segmentation: Optional[PointCloudFromSegmentation] = Field(
+        default=None, exclude=True
+    )
 
-    args_schema: Type[GetGrippingPointToolInput] = GetGrippingPointToolInput
+    args_schema: Type[GetObjectGrippingPointsToolInput] = (
+        GetObjectGrippingPointsToolInput
+    )
 
     def model_post_init(self, __context: Any) -> None:
-        """Initialize tool components."""
+        """Initialize tool with ROS2 parameters and components."""
+        self._load_parameters()
         self._initialize_components()
 
+    def _load_parameters(self) -> None:
+        """Load configuration from ROS2 parameters."""
+        node = self.connector.node
+        param_prefix = PCL_DETECTION_PARAM_PREFIX
+
+        # Declare required parameters
+        params = [
+            f"{param_prefix}.target_frame",
+            f"{param_prefix}.source_frame",
+            f"{param_prefix}.camera_topic",
+            f"{param_prefix}.depth_topic",
+            f"{param_prefix}.camera_info_topic",
+        ]
+
+        for param_name in params:
+            if not node.has_parameter(param_name):
+                raise ValueError(
+                    f"Required parameter '{param_name}' must be set before initializing GetObjectGrippingPointsTool"
+                )
+
+        # Load parameters
+        self.target_frame = node.get_parameter(f"{param_prefix}.target_frame").value
+        self.source_frame = node.get_parameter(f"{param_prefix}.source_frame").value
+        self.camera_topic = node.get_parameter(f"{param_prefix}.camera_topic").value
+        self.depth_topic = node.get_parameter(f"{param_prefix}.depth_topic").value
+        self.camera_info_topic = node.get_parameter(
+            f"{param_prefix}.camera_info_topic"
+        ).value
+
+        # Timeout for gripping point detection; defaults to 10.0 seconds
+        self.timeout_sec = (
+            node.get_parameter(f"{param_prefix}.timeout_sec").value
+            if node.has_parameter(f"{param_prefix}.timeout_sec")
+            else 10.0
+        )
+
+        # Conversion ratio for point cloud from segmentation; defaults to 0.001 (mm to m)
+        self.conversion_ratio = (
+            node.get_parameter(f"{param_prefix}.conversion_ratio").value
+            if node.has_parameter(f"{param_prefix}.conversion_ratio")
+            else 0.001
+        )
+
     def _initialize_components(self) -> None:
-        """Initialize PCL components with provided parameters."""
+        """Initialize PCL components with loaded parameters."""
         self.point_cloud_from_segmentation = PointCloudFromSegmentation(
             connector=self.connector,
             camera_topic=self.camera_topic,
@@ -82,52 +149,50 @@ def _initialize_components(self) -> None:
             camera_info_topic=self.camera_info_topic,
             source_frame=self.source_frame,
             target_frame=self.target_frame,
-            config=self.segmentation_config,
             conversion_ratio=self.conversion_ratio,
+            config=self.segmentation_config,
         )
         self.gripping_point_estimator = GrippingPointEstimator(
             config=self.estimator_config
         )
         self.point_cloud_filter = PointCloudFilter(config=self.filter_config)
 
-    def _run(
-        self,
-        object_name: str,
-        timeout_sec: Optional[float] = None,
-        conversion_ratio: Optional[float] = None,
-    ) -> List[np.ndarray]:
-        """Run gripping point detection and return raw gripping points."""
-
-        # Use runtime parameters if provided, otherwise use defaults
-        effective_timeout = timeout_sec if timeout_sec is not None else self.timeout_sec
-        effective_conversion_ratio = (
-            conversion_ratio if conversion_ratio is not None else self.conversion_ratio
-        )
-
-        # Update conversion ratio if different from current
-        if effective_conversion_ratio != self.conversion_ratio:
-            self.point_cloud_from_segmentation.conversion_ratio = (
-                effective_conversion_ratio
-            )
-
+    def _run(self, object_name: str) -> str:
         @timeout(
-            effective_timeout,
-            f"Gripping point detection for object '{object_name}' exceeded {effective_timeout} seconds",
+            self.timeout_sec,
+            f"Gripping point detection for object '{object_name}' exceeded {self.timeout_sec} seconds",
         )
         def _run_with_timeout():
             pcl = self.point_cloud_from_segmentation.run(object_name)
             if len(pcl) == 0:
-                return []
+                return f"No {object_name}s detected."
+
+            pcl_filtered = self.point_cloud_filter.run(pcl)
+            if len(pcl_filtered) == 0:
+                return f"No {object_name}s detected after filtering."
+
+            gripping_points = self.gripping_point_estimator.run(pcl_filtered)
+
+            message = ""
+            if len(gripping_points) == 0:
+                message += f"No gripping point found for the object {object_name}\n"
+            elif len(gripping_points) == 1:
+                message += f"The gripping point of the object {object_name} is {gripping_points[0]}\n"
+            else:
+                message += (
+                    f"Multiple gripping points found for the object {object_name}\n"
+                )
+                for i, gp in enumerate(gripping_points):
+                    message += (
+                        f"The gripping point of object {i + 1} ({object_name}) is {gp}\n"
+                    )
 
-            pcl = self.point_cloud_filter.run(pcl)
-            gps = self.gripping_point_estimator.run(pcl)
-            return gps
+            return message
 
         try:
             return _run_with_timeout()
         except RaiTimeoutError as e:
-            # Log the timeout but still raise it
             self.connector.node.get_logger().warning(f"Timeout: {e}")
-            raise  # Let caller decide how to handle
+            return f"Timeout: Gripping point detection for object '{object_name}' exceeded {self.timeout_sec} seconds"
         except Exception:
             raise
diff --git a/tests/rai_extensions/test_gripping_points.py b/tests/rai_extensions/test_gripping_points.py
index d2e77e9a7..5c41939b0 100644
--- a/tests/rai_extensions/test_gripping_points.py
+++ b/tests/rai_extensions/test_gripping_points.py
@@ -22,29 +22,24 @@
 The demo app and rivz2 need to be started before running the test.
 The test will fail if the gripping points are not found.
 Usage:
-pytest tests/tools/ros2/test_gripping_points.py::test_gripping_points_manipulation_demo -m "manual" -s -v --strategy
+pytest tests/rai_extensions/test_gripping_points.py::test_gripping_points_manipulation_demo -m "manual" -s -v --strategy
 """
 
-import time
-from typing import List
-
 import cv2
 import numpy as np
 import pytest
 import rclpy
-import rclpy.parameter
 from cv_bridge import CvBridge
 from rai.communication.ros2 import wait_for_ros2_services, wait_for_ros2_topics
 from rai.communication.ros2.connectors import ROS2Connector
-from rai_open_set_vision import (
-    GetGrippingPointTool,
+from rai_open_set_vision import GetObjectGrippingPointsTool
+from rai_open_set_vision.tools.pcl_detection import (
     GrippingPointEstimatorConfig,
     PointCloudFilterConfig,
     PointCloudFromSegmentationConfig,
+    _publish_gripping_point_debug_data,
 )
-
-# import internal tool
-from rai_open_set_vision.tools.pcl_detection import _publish_gripping_point_debug_data
+from rai_open_set_vision.tools.pcl_detection_tools import PCL_DETECTION_PARAM_PREFIX
 
 
 def draw_points_on_image(image_msg, points, camera_info):
@@ -89,9 +84,20 @@ def draw_points_on_image(image_msg, points, camera_info):
     return cv_image
 
 
-def extract_gripping_points(result: List[np.ndarray]) -> list[np.ndarray]:
-    """Extract gripping points from the result - now returns raw points directly."""
-    return result
+def extract_gripping_points(result: str) -> list[np.ndarray]:
+    """Extract gripping points from the result."""
+    gripping_points = []
+    lines = result.split("\n")
+    for line in lines:
+        if "gripping point" in line and "is [" in line:
+            # Extract coordinates from line like "is [0.39972728 0.16179778 0.04179673]"
+            start = line.find("[") + 1
+            end = line.find("]")
+            if start > 0 and end > start:
+                coords_str = line[start:end]
+                coords = [float(x) for x in coords_str.split()]
+                gripping_points.append(np.array(coords))
+    return gripping_points
 
 
 def transform_points_to_target_frame(connector, points, source_frame, target_frame):
@@ -178,6 +184,7 @@ def main(
     frames: dict = None,
     estimator_config: dict = None,
     filter_config: dict = None,
+    debug_enabled: bool = False,
 ):
     # Default configuration for manipulation-demo
     if topics is None:
@@ -204,6 +211,7 @@ def main(
 
     # Initialize ROS2
     rclpy.init()
+
     connector = ROS2Connector(executor_type="single_threaded")
 
     try:
@@ -213,51 +221,58 @@ def main(
         wait_for_ros2_topics(connector, list(topics.values()))
         print("✅ All services and topics available")
 
-        # Set up conversion ratio parameter
+        # Set up node parameters
         node = connector.node
-        node.declare_parameter("conversion_ratio", 1.0)
-        start_time = time.time()
+        param_prefix = PCL_DETECTION_PARAM_PREFIX
 
+        # Declare and set ROS2 parameters for deployment configuration
+        parameters_to_set = [
+            (f"{param_prefix}.target_frame", frames["target"]),
+            (f"{param_prefix}.source_frame", frames["source"]),
+            (f"{param_prefix}.camera_topic", topics["camera"]),
+            (f"{param_prefix}.depth_topic", topics["depth"]),
+            (f"{param_prefix}.camera_info_topic", topics["camera_info"]),
+            (f"{param_prefix}.timeout_sec", 10.0),
+            (f"{param_prefix}.conversion_ratio", 1.0),
+        ]
+
+        # Parameters must be declared before the tool reads them in _load_parameters()
+        for param_name, param_value in parameters_to_set:
+            node.declare_parameter(param_name, param_value)
 
-        print(
-            f"\nTesting GetGrippingPointTool with object '{test_object}', strategy '{strategy}'"
-        )
+        print(
+            f"\nTesting GetObjectGrippingPointsTool with object '{test_object}', strategy '{strategy}'"
+        )
 
         # Create the tool with algorithm configurations
-        gripping_tool = GetGrippingPointTool(
+        tool = GetObjectGrippingPointsTool(
             connector=connector,
             segmentation_config=PointCloudFromSegmentationConfig(),
             estimator_config=GrippingPointEstimatorConfig(**estimator_config),
             filter_config=PointCloudFilterConfig(**filter_config),
         )
 
-        result = gripping_tool._run(test_object)
-        print(f"elapsed time: {time.time() - start_time} seconds")
+        pcl = tool.point_cloud_from_segmentation.run(test_object)
+        if len(pcl) == 0:
+            print(f"No {test_object}s detected.")
+            return
+
+        pcl_filtered = tool.point_cloud_filter.run(pcl)
+        gripping_points = tool.gripping_point_estimator.run(pcl_filtered)
+        assert len(gripping_points) > 0, "No gripping points found"
 
-        # result is now a list of numpy arrays directly
-        gripping_points = result
         print(f"\nFound {len(gripping_points)} gripping points in target frame:")
         for i, gp in enumerate(gripping_points):
             print(f"  GP{i + 1}: [{gp[0]:.3f}, {gp[1]:.3f}, {gp[2]:.3f}]")
 
-        assert len(gripping_points) > 0, "No gripping points found"
-
-        if gripping_points:
-            # Call the function in pcl.py to publish the gripping point for visualization
-            segmented_clouds = gripping_tool.point_cloud_from_segmentation.run(
-                test_object
-            )
-            filtered_clouds = gripping_tool.point_cloud_filter.run(segmented_clouds)
-
-            print(
-                "\nPublishing debug data to /debug_gripping_points_pointcloud and /debug_gripping_points_markerarray"
-            )
+        if debug_enabled:
             _publish_gripping_point_debug_data(
-                connector, filtered_clouds, gripping_points, frames["target"]
+                connector,
+                pcl_filtered,
+                gripping_points,
+                frames["target"],
             )
-            print("✅ Debug data published")
 
         annotated_image_path = f"{test_object}_{strategy}_gripping_points.jpg"
         save_annotated_image(
             connector,
@@ -285,7 +300,7 @@ def main(
 @pytest.mark.manual
 def test_gripping_points_manipulation_demo(strategy):
     """Manual test requiring manipulation-demo app to be started."""
-    main("cube", strategy)
+    main("cube", strategy, debug_enabled=True)
 
 
 @pytest.mark.manual
diff --git a/tests/rai_extensions/test_pcl_detection_tools.py b/tests/rai_extensions/test_pcl_detection_tools.py
index 7c35eac91..8ceb05967 100644
--- a/tests/rai_extensions/test_pcl_detection_tools.py
+++ b/tests/rai_extensions/test_pcl_detection_tools.py
@@ -27,9 +27,8 @@
 import numpy as np
 from rai.communication.ros2.connectors import ROS2Connector
-from rai.tools.timeout import RaiTimeoutError
 from rai_open_set_vision import (
-    GetGrippingPointTool,
+    GetObjectGrippingPointsTool,
     GrippingPointEstimator,
     GrippingPointEstimatorConfig,
     PointCloudFilter,
@@ -176,13 +175,8 @@ def test_get_gripping_point_tool_timeout():
     mock_filter.run.return_value = []
     mock_estimator.run.return_value = []
 
-    tool = GetGrippingPointTool(
+    tool = GetObjectGrippingPointsTool(
         connector=mock_connector,
-        target_frame="panda_link0",
-        source_frame="RGBDCamera5",
-        camera_topic="/color_image5",
-        depth_topic="/depth_image5",
-        camera_info_topic="/color_camera_info5",
         segmentation_config=PointCloudFromSegmentationConfig(),
         estimator_config=GrippingPointEstimatorConfig(),
         filter_config=PointCloudFilterConfig(),
@@ -195,8 +189,8 @@
 
     # Test fast execution - should complete without timeout
     result = tool._run("test_object")
-    assert result == []  # Returns empty list for no objects found
-    assert len(result) == 0
+    assert "No test_objects detected" in result
+    assert "timeout" not in result.lower()
 
-    # Test 2: Actual timeout behavior - should raise TimeoutError
+    # Test 2: Actual timeout behavior - should return a timeout message
     def slow_operation(obj_name):
@@ -206,6 +200,8 @@ def slow_operation(obj_name):
     mock_pcl_gen.run.side_effect = slow_operation
     tool.timeout_sec = 1.0  # Short timeout
 
-    # Expect TimeoutError to be raised
-    with pytest.raises(RaiTimeoutError, match="exceeded 1.0 seconds"):
+    # Expect the timeout message to be returned instead of an exception
+    assert (
         tool._run("test")
+        == "Timeout: Gripping point detection for object 'test' exceeded 1.0 seconds"
+    )
diff --git a/xd.py b/xd.py
deleted file mode 100644
index fec2dba68..000000000
--- a/xd.py
+++ /dev/null
@@ -1,31 +0,0 @@
-# Copyright (C) 2025 Robotec.AI
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from rai.agents import wait_for_shutdown
-from rai.communication.ros2 import ROS2Context
-from rai_open_set_vision.agents import GroundedSamAgent, GroundingDinoAgent
-
-
-@ROS2Context()
-def main():
-    agent1 = GroundingDinoAgent()
-    agent2 = GroundedSamAgent()
-    agent1.run()
-    agent2.run()
-    wait_for_shutdown([agent1, agent2])
-
-
-if __name__ == "__main__":
-    main()
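
Reviewer note: with the constructor arguments removed, GetObjectGrippingPointsTool is now configured entirely through ROS2 parameters. A minimal usage sketch, not part of this series; the parameter names come from PCL_DETECTION_PARAM_PREFIX above, and the frame/topic values reuse the manipulation-demo values that this series removes from test_pcl_detection_tools.py (adjust them per robot):

    import rclpy
    from rai.communication.ros2.connectors import ROS2Connector
    from rai_open_set_vision import GetObjectGrippingPointsTool
    from rai_open_set_vision.tools.pcl_detection_tools import PCL_DETECTION_PARAM_PREFIX

    rclpy.init()
    connector = ROS2Connector(executor_type="single_threaded")
    node = connector.node

    # All five required parameters must be declared before the tool is built;
    # _load_parameters() raises ValueError for any that are missing.
    for name, value in [
        ("target_frame", "panda_link0"),
        ("source_frame", "RGBDCamera5"),
        ("camera_topic", "/color_image5"),
        ("depth_topic", "/depth_image5"),
        ("camera_info_topic", "/color_camera_info5"),
    ]:
        node.declare_parameter(f"{PCL_DETECTION_PARAM_PREFIX}.{name}", value)

    # timeout_sec and conversion_ratio are optional; they fall back to 10.0 and 0.001.
    tool = GetObjectGrippingPointsTool(connector=connector)
    print(tool._run("cube"))  # returns a human-readable message, not raw arrays

    rclpy.shutdown()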