Skip to content

Commit ea865e1

Browse files
committed
Support configuration for detection pipeline
1 parent 9958ca7 commit ea865e1

File tree

5 files changed

+391
-241
lines changed

5 files changed

+391
-241
lines changed

examples/manipulation-demo-v2.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# Copyright (C) 2024 Robotec.AI
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language goveself.rning permissions and
13+
# limitations under the License.
14+
15+
16+
import logging
17+
from typing import List
18+
19+
import rclpy
20+
import rclpy.qos
21+
from langchain_core.messages import BaseMessage, HumanMessage
22+
from langchain_core.tools import BaseTool
23+
from rai import get_llm_model
24+
from rai.agents.langchain.core import create_conversational_agent
25+
from rai.communication.ros2 import wait_for_ros2_services, wait_for_ros2_topics
26+
from rai.communication.ros2.connectors import ROS2Connector
27+
from rai.tools.ros2.detection.pcl import (
28+
GrippingPointEstimatorConfig,
29+
PointCloudFilterConfig,
30+
PointCloudFromSegmentationConfig,
31+
)
32+
from rai.tools.ros2.detection.tools import GetGrippingPointTool
33+
from rai.tools.ros2.manipulation import (
34+
MoveObjectFromToTool,
35+
ResetArmTool,
36+
)
37+
from rai.tools.ros2.simple import GetROS2ImageConfiguredTool
38+
39+
from rai_whoami.models import EmbodimentInfo
40+
41+
logger = logging.getLogger(__name__)
42+
43+
44+
def create_agent():
45+
rclpy.init()
46+
connector = ROS2Connector(executor_type="single_threaded")
47+
48+
required_services = ["/grounded_sam_segment", "/grounding_dino_classify"]
49+
required_topics = ["/color_image5", "/depth_image5", "/color_camera_info5"]
50+
wait_for_ros2_services(connector, required_services)
51+
wait_for_ros2_topics(connector, required_topics)
52+
53+
node = connector.node
54+
55+
# Declare and set parameters for GetGrippingPointTool
56+
# These also can be set in the launch file or during runtime
57+
parameters_to_set = [
58+
("conversion_ratio", 1.0),
59+
("detection_tools.gripping_point.target_frame", "panda_link0"),
60+
("detection_tools.gripping_point.source_frame", "RGBDCamera5"),
61+
("detection_tools.gripping_point.camera_topic", "/color_image5"),
62+
("detection_tools.gripping_point.depth_topic", "/depth_image5"),
63+
("detection_tools.gripping_point.camera_info_topic", "/color_camera_info5"),
64+
]
65+
66+
# Declare and set each parameter (timeout_sec handled by tool internally)
67+
for param_name, param_value in parameters_to_set:
68+
node.declare_parameter(param_name, param_value)
69+
70+
# Configure gripping point detection algorithms
71+
segmentation_config = PointCloudFromSegmentationConfig(
72+
box_threshold=0.35,
73+
text_threshold=0.45,
74+
)
75+
76+
estimator_config = GrippingPointEstimatorConfig(
77+
strategy="biggest_plane", # Options: "centroid", "top_plane", "biggest_plane"
78+
top_percentile=0.05,
79+
plane_bin_size_m=0.01,
80+
ransac_iterations=200,
81+
distance_threshold_m=0.01,
82+
min_points=10,
83+
)
84+
85+
filter_config = PointCloudFilterConfig(
86+
strategy="dbscan",
87+
min_points=20,
88+
dbscan_eps=0.02,
89+
dbscan_min_samples=10,
90+
)
91+
92+
tools: List[BaseTool] = [
93+
GetGrippingPointTool(
94+
connector=connector,
95+
segmentation_config=segmentation_config,
96+
estimator_config=estimator_config,
97+
filter_config=filter_config,
98+
),
99+
MoveObjectFromToTool(connector=connector, manipulator_frame="panda_link0"),
100+
ResetArmTool(connector=connector, manipulator_frame="panda_link0"),
101+
GetROS2ImageConfiguredTool(connector=connector, topic="/color_image5"),
102+
]
103+
104+
llm = get_llm_model(model_type="complex_model", streaming=True)
105+
embodiment_info = EmbodimentInfo.from_file(
106+
"examples/embodiments/manipulation_embodiment.json"
107+
)
108+
agent = create_conversational_agent(
109+
llm=llm,
110+
tools=tools,
111+
system_prompt=embodiment_info.to_langchain(),
112+
)
113+
return agent
114+
115+
116+
def main():
117+
agent = create_agent()
118+
messages: List[BaseMessage] = []
119+
120+
while True:
121+
prompt = input("Enter a prompt: ")
122+
messages.append(HumanMessage(content=prompt))
123+
output = agent.invoke({"messages": messages})
124+
output["messages"][-1].pretty_print()
125+
126+
127+
if __name__ == "__main__":
128+
main()

0 commit comments

Comments
 (0)