Skip to content

Commit 07d49f8

Browse files
authored
refactor: gdino and gsam as agents (#501)
1 parent 28eb194 commit 07d49f8

File tree

11 files changed

+294
-323
lines changed

11 files changed

+294
-323
lines changed

src/rai_bringup/launch/openset.launch.py

Lines changed: 5 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,36 +14,16 @@
1414

1515

1616
from launch import LaunchDescription
17-
from launch.actions import IncludeLaunchDescription
18-
from launch.launch_description_sources import AnyLaunchDescriptionSource
19-
from launch.substitutions import PathJoinSubstitution
20-
from launch_ros.substitutions import FindPackageShare
17+
from launch.actions import ExecuteProcess
2118

2219

2320
def generate_launch_description():
2421
return LaunchDescription(
2522
[
26-
IncludeLaunchDescription(
27-
AnyLaunchDescriptionSource(
28-
PathJoinSubstitution(
29-
[
30-
FindPackageShare("rai_open_set_vision"),
31-
"launch",
32-
"gdino_launch.xml",
33-
]
34-
)
35-
)
36-
),
37-
IncludeLaunchDescription(
38-
AnyLaunchDescriptionSource(
39-
PathJoinSubstitution(
40-
[
41-
FindPackageShare("rai_open_set_vision"),
42-
"launch",
43-
"gsam_launch.xml",
44-
]
45-
)
46-
)
23+
ExecuteProcess(
24+
cmd=["python", "run_vision_agents.py"],
25+
cwd="src/rai_extensions/rai_open_set_vision/scripts",
26+
output="screen",
4727
),
4828
]
4929
)

src/rai_core/rai/agents/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from rai.agents.base import BaseAgent
1516
from rai.agents.conversational_agent import create_conversational_agent
1617
from rai.agents.react_agent import ReActAgent
1718
from rai.agents.runner import AgentRunner, wait_for_shutdown
@@ -20,6 +21,7 @@
2021

2122
__all__ = [
2223
"AgentRunner",
24+
"BaseAgent",
2325
"ReActAgent",
2426
"ToolRunner",
2527
"create_conversational_agent",

src/rai_core/rai/tools/ros2/generic/topics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,11 +135,11 @@ def _run(
135135
msg_type = type(message.payload)
136136
if msg_type == Image:
137137
image = CvBridge().imgmsg_to_cv2( # type: ignore
138-
message.payload, desired_encoding="rgb8"
138+
message.payload, desired_encoding="bgr8"
139139
)
140140
elif msg_type == CompressedImage:
141141
image = CvBridge().compressed_imgmsg_to_cv2( # type: ignore
142-
message.payload, desired_encoding="rgb8"
142+
message.payload, desired_encoding="bgr8"
143143
)
144144
else:
145145
raise ValueError(

src/rai_extensions/rai_open_set_vision/rai_open_set_vision/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,21 @@
1313
# limitations under the License.
1414

1515

16-
from .services.grounding_dino import GDINO_NODE_NAME, GDINO_SERVICE_NAME
16+
from .agents.grounded_sam import GSAM_NODE_NAME, GSAM_SERVICE_NAME, GroundedSamAgent
17+
from .agents.grounding_dino import (
18+
GDINO_NODE_NAME,
19+
GDINO_SERVICE_NAME,
20+
GroundingDinoAgent,
21+
)
1722
from .tools import GetDetectionTool, GetDistanceToObjectsTool
1823

1924
__all__ = [
2025
"GDINO_NODE_NAME",
2126
"GDINO_SERVICE_NAME",
27+
"GSAM_NODE_NAME",
28+
"GSAM_SERVICE_NAME",
2229
"GetDetectionTool",
2330
"GetDistanceToObjectsTool",
31+
"GroundedSamAgent",
32+
"GroundingDinoAgent",
2433
]

src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/__init__.py renamed to src/rai_extensions/rai_open_set_vision/rai_open_set_vision/agents/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,13 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
15+
from .base_vision_agent import BaseVisionAgent
16+
from .grounded_sam import GroundedSamAgent
17+
from .grounding_dino import GroundingDinoAgent
18+
19+
__all__ = [
20+
"BaseVisionAgent",
21+
"GroundedSamAgent",
22+
"GroundingDinoAgent",
23+
]
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright (C) 2025 Robotec.AI
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import os
17+
import subprocess
18+
from pathlib import Path
19+
20+
from rai.agents import BaseAgent
21+
from rai.communication.ros2 import ROS2Connector
22+
23+
24+
class BaseVisionAgent(BaseAgent):
25+
WEIGHTS_URL: str = ""
26+
WEIGHTS_FILENAME: str = ""
27+
28+
def __init__(
29+
self,
30+
weights_path: str | Path = Path.home() / Path(".cache/rai/"),
31+
ros2_name: str = "",
32+
):
33+
super().__init__()
34+
self._weights_path = Path(weights_path)
35+
os.makedirs(self._weights_path, exist_ok=True)
36+
self._init_weight_path()
37+
self.weight_path = self._weights_path
38+
self.ros2_connector = ROS2Connector(ros2_name)
39+
40+
def _init_weight_path(self):
41+
try:
42+
if self.WEIGHTS_FILENAME == "":
43+
raise ValueError("WEIGHTS_FILENAME is not set")
44+
45+
install_path = (
46+
self._weights_path / "vision" / "weights" / self.WEIGHTS_FILENAME
47+
)
48+
# make sure the file exists
49+
if install_path.exists():
50+
self._weights_path = install_path
51+
else:
52+
self._download_weights(install_path)
53+
self._weights_path = install_path
54+
55+
except Exception:
56+
self.logger.error("Could not find package path")
57+
raise Exception("Could not find package path")
58+
59+
def _download_weights(self, path: Path):
60+
try:
61+
os.makedirs(path.parent, exist_ok=True)
62+
subprocess.run(
63+
[
64+
"wget",
65+
self.WEIGHTS_URL,
66+
"-O",
67+
path,
68+
"--progress=dot:giga",
69+
]
70+
)
71+
except Exception:
72+
self.logger.error("Could not download weights")
73+
raise Exception("Could not download weights")
74+
75+
def _remove_weights(self, path: Path):
76+
if path.exists():
77+
os.remove(path)
78+
79+
def stop(self):
80+
self.ros2_connector.shutdown()
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright (C) 2024 Robotec.AI
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
from pathlib import Path
17+
18+
import numpy as np
19+
from cv_bridge import CvBridge
20+
21+
from rai_interfaces.srv import RAIGroundedSam
22+
from rai_open_set_vision.agents.base_vision_agent import BaseVisionAgent
23+
from rai_open_set_vision.vision_markup.segmenter import GDSegmenter
24+
25+
GSAM_NODE_NAME = "grounded_sam"
26+
GSAM_SERVICE_NAME = "grounded_sam_segment"
27+
28+
29+
class GroundedSamAgent(BaseVisionAgent):
30+
WEIGHTS_URL = (
31+
"https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt"
32+
)
33+
WEIGHTS_FILENAME = "sam2_hiera_large.pt"
34+
35+
def __init__(
36+
self,
37+
weights_path: str | Path = Path.home() / Path(".cache/rai"),
38+
ros2_name: str = GSAM_NODE_NAME,
39+
):
40+
super().__init__(weights_path, ros2_name)
41+
try:
42+
self._segmenter = GDSegmenter(self._weights_path)
43+
except Exception:
44+
self.logger.error(
45+
"Could not load model. The weights might be corrupted. Redownloading..."
46+
)
47+
self._remove_weights(self.weight_path)
48+
self._init_weight_path()
49+
self.segmenter = GDSegmenter(self.weight_path)
50+
51+
def run(self):
52+
self.ros2_connector.create_service(
53+
service_name=GSAM_SERVICE_NAME,
54+
on_request=self._segment_callback,
55+
service_type="rai_interfaces/srv/RAIGroundedSam",
56+
)
57+
58+
def _segment_callback(self, request, response: RAIGroundedSam.Response):
59+
received_boxes = []
60+
for detection in request.detections.detections:
61+
received_boxes.append(detection.bbox)
62+
63+
image = request.source_img
64+
65+
assert self._segmenter is not None
66+
masks = self._segmenter.get_segmentation(image, received_boxes)
67+
bridge = CvBridge()
68+
img_arr = []
69+
for mask in masks:
70+
if len(mask.shape) > 2: # Check if the mask has multiple channels
71+
mask = np.squeeze(mask)
72+
arr = (mask * 255).astype(np.uint8) # Convert binary 0/1 to 0/255
73+
img_arr.append(bridge.cv2_to_imgmsg(arr, encoding="mono8"))
74+
75+
response.masks = img_arr
76+
return response
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright (C) 2024 Robotec.AI
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
from pathlib import Path
17+
18+
from rai_interfaces.msg import RAIDetectionArray
19+
from rai_open_set_vision.agents.base_vision_agent import BaseVisionAgent
20+
from rai_open_set_vision.vision_markup.boxer import GDBoxer
21+
22+
GDINO_NODE_NAME = "grounding_dino"
23+
GDINO_SERVICE_NAME = "grounding_dino_classify"
24+
25+
26+
class GroundingDinoAgent(BaseVisionAgent):
27+
WEIGHTS_URL = "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth"
28+
WEIGHTS_FILENAME = "groundingdino_swint_ogc.pth"
29+
30+
def __init__(
31+
self,
32+
weights_path: str | Path = Path.home() / Path(".cache/rai"),
33+
ros2_name: str = GDINO_NODE_NAME,
34+
):
35+
super().__init__(weights_path, ros2_name)
36+
try:
37+
self._boxer = GDBoxer(self._weights_path)
38+
except Exception:
39+
self.logger.error(
40+
"Could not load model. The weights might be corrupted. Redownloading..."
41+
)
42+
self._remove_weights(self.weight_path)
43+
self._init_weight_path()
44+
self.segmenter = GDBoxer(self.weight_path)
45+
46+
def run(self):
47+
self.ros2_connector.create_service(
48+
GDINO_SERVICE_NAME,
49+
self._classify_callback,
50+
service_type="rai_interfaces/srv/RAIGroundingDino",
51+
)
52+
53+
def _classify_callback(self, request, response: RAIDetectionArray):
54+
self.logger.info(
55+
f"Request received: {request.classes}, {request.box_threshold}, {request.text_threshold}"
56+
)
57+
58+
class_array = request.classes.split(",")
59+
class_array = [class_name.strip() for class_name in class_array]
60+
class_dict = {class_name: i for i, class_name in enumerate(class_array)}
61+
62+
boxes = self._boxer.get_boxes(
63+
request.source_img,
64+
class_array,
65+
request.box_threshold,
66+
request.text_threshold,
67+
)
68+
69+
ts = self.ros2_connector._node.get_clock().now().to_msg()
70+
response.detections.detections = [ # type: ignore
71+
box.to_detection_msg(class_dict, ts) # type: ignore
72+
for box in boxes
73+
]
74+
response.detections.header.stamp = ts # type: ignore
75+
response.detections.detection_classes = class_array # type: ignore
76+
77+
return response

0 commit comments

Comments
 (0)