Skip to content

Commit 9d289d6

Browse files
committed
Set isolation_forest as default for Point Cloud filtering and publish filtered pcl in test
1 parent ea865e1 commit 9d289d6

File tree

2 files changed

+16
-14
lines changed

2 files changed

+16
-14
lines changed

src/rai_core/rai/tools/ros2/detection/pcl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class GrippingPointEstimatorConfig(BaseModel):
4949

5050
class PointCloudFilterConfig(BaseModel):
5151
strategy: Literal["dbscan", "kmeans_largest_cluster", "isolation_forest", "lof"] = (
52-
"dbscan"
52+
"isolation_forest"
5353
)
5454
min_points: int = 20
5555
# DBSCAN

tests/tools/ros2/test_gripping_points.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -203,9 +203,9 @@ def main(
203203

204204
if filter_config is None:
205205
filter_config = {
206-
"strategy": "dbscan",
207-
"dbscan_eps": 0.02,
208-
"dbscan_min_samples": 5,
206+
"strategy": "isolation_forest",
207+
"if_max_samples": "auto",
208+
"if_contamination": 0.05,
209209
}
210210

211211
services = ["/grounded_sam_segment", "/grounding_dino_classify"]
@@ -240,21 +240,21 @@ def main(
240240

241241
start_time = time.time()
242242

243+
print(
244+
f"\nTesting GetGrippingPointTool with object '{test_object}', strategy '{strategy}'"
245+
)
246+
243247
# Create the tool with algorithm configurations
244248
gripping_tool = GetGrippingPointTool(
245249
connector=connector,
246250
segmentation_config=PointCloudFromSegmentationConfig(),
247251
estimator_config=GrippingPointEstimatorConfig(**estimator_config),
248252
filter_config=PointCloudFilterConfig(**filter_config),
249253
)
250-
print(f"elapsed time: {time.time() - start_time} seconds")
251-
252-
# Test the tool directly
253-
print(
254-
f"\nTesting GetGrippingPointTool with object '{test_object}', strategy '{strategy}'"
255-
)
256254

257255
result = gripping_tool._run(test_object)
256+
print(f"elapsed time: {time.time() - start_time} seconds")
257+
258258
gripping_points = extract_gripping_points(result)
259259
print(f"\nFound {len(gripping_points)} gripping points in target frame:")
260260

@@ -268,11 +268,13 @@ def main(
268268
segmented_clouds = gripping_tool.point_cloud_from_segmentation.run(
269269
test_object
270270
)
271+
filtered_clouds = gripping_tool.point_cloud_filter.run(segmented_clouds)
272+
271273
print(
272274
"\nPublishing debug data to /debug_gripping_points_pointcloud and /debug_gripping_points_markerarray"
273275
)
274276
_publish_gripping_point_debug_data(
275-
connector, segmented_clouds, gripping_points, frames["target"]
277+
connector, filtered_clouds, gripping_points, frames["target"]
276278
)
277279
print("✅ Debug data published")
278280

@@ -327,8 +329,8 @@ def test_gripping_points_maciej_demo(strategy):
327329
"distance_threshold_m": 0.008,
328330
},
329331
filter_config={
330-
"strategy": "dbscan",
331-
"dbscan_eps": 0.02,
332-
"dbscan_min_samples": 10,
332+
"strategy": "isolation_forest",
333+
"if_max_samples": "auto",
334+
"if_contamination": 0.05,
333335
},
334336
)

0 commit comments

Comments
 (0)