Skip to content

Commit 197f61e

Browse files
authored
feat: unified ROS2 hri message (#416)
Further work: #418
1 parent ce4885c commit 197f61e

File tree

8 files changed

+231
-26
lines changed

8 files changed

+231
-26
lines changed

src/rai_core/rai/communication/ros2/api.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -319,8 +319,7 @@ def shutdown(self) -> None:
319319

320320
@dataclass
321321
class TopicConfig:
322-
name: str
323-
msg_type: str
322+
msg_type: str = "rai_interfaces/msg/HRIMessage"
324323
auto_qos_matching: bool = True
325324
qos_profile: Optional[QoSProfile] = None
326325
is_subscriber: bool = False

src/rai_core/rai/communication/ros2/connectors.py

Lines changed: 98 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,26 @@
1515
import threading
1616
import time
1717
import uuid
18-
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
18+
from collections import OrderedDict
19+
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union, cast
1920

21+
import numpy as np
2022
import rclpy
2123
import rclpy.executors
2224
import rclpy.node
2325
import rclpy.time
26+
import rosidl_runtime_py.convert
27+
from cv_bridge import CvBridge
28+
from PIL import Image
29+
from pydub import AudioSegment
2430
from rclpy.duration import Duration
2531
from rclpy.executors import MultiThreadedExecutor
2632
from rclpy.node import Node
2733
from rclpy.qos import QoSProfile
34+
from sensor_msgs.msg import Image as ROS2Image
2835
from tf2_ros import Buffer, LookupException, TransformListener, TransformStamped
2936

37+
import rai_interfaces.msg
3038
from rai.communication import (
3139
ARIConnector,
3240
ARIMessage,
@@ -41,6 +49,10 @@
4149
ROS2TopicAPI,
4250
TopicConfig,
4351
)
52+
from rai_interfaces.msg import HRIMessage as ROS2HRIMessage_
53+
from rai_interfaces.msg._audio_message import (
54+
AudioMessage as ROS2HRIMessage__Audio,
55+
)
4456

4557

4658
class ROS2ARIMessage(ARIMessage):
@@ -200,26 +212,95 @@ class ROS2HRIMessage(HRIMessage):
200212
def __init__(self, payload: HRIPayload, message_author: Literal["ai", "human"]):
201213
super().__init__(payload, message_author)
202214

215+
@classmethod
216+
def from_ros2(
217+
cls, msg: rai_interfaces.msg.HRIMessage, message_author: Literal["ai", "human"]
218+
):
219+
cv_bridge = CvBridge()
220+
images = [
221+
cv_bridge.imgmsg_to_cv2(img_msg, "rgb8")
222+
for img_msg in cast(List[ROS2Image], msg.images)
223+
]
224+
pil_images = [Image.fromarray(img) for img in images]
225+
audio_segments = [
226+
AudioSegment(
227+
data=audio_msg.audio,
228+
frame_rate=audio_msg.sample_rate,
229+
sample_width=2, # bytes, int16
230+
channels=audio_msg.channels,
231+
)
232+
for audio_msg in msg.audios
233+
]
234+
return ROS2HRIMessage(
235+
payload=HRIPayload(text=msg.text, images=pil_images, audios=audio_segments),
236+
message_author=message_author,
237+
)
238+
239+
def to_ros2_dict(self) -> OrderedDict[str, Any]:
240+
cv_bridge = CvBridge()
241+
assert isinstance(self.payload, HRIPayload)
242+
img_msgs = [
243+
cv_bridge.cv2_to_imgmsg(np.array(img), "rgb8")
244+
for img in self.payload.images
245+
]
246+
audio_msgs = [
247+
ROS2HRIMessage__Audio(
248+
audio=audio.raw_data,
249+
sample_rate=audio.frame_rate,
250+
channels=audio.channels,
251+
)
252+
for audio in self.payload.audios
253+
]
254+
255+
return cast(
256+
OrderedDict[str, Any],
257+
rosidl_runtime_py.convert.message_to_ordereddict(
258+
ROS2HRIMessage_(
259+
text=self.payload.text,
260+
images=img_msgs,
261+
audios=audio_msgs,
262+
)
263+
),
264+
)
265+
203266

204267
class ROS2HRIConnector(HRIConnector[ROS2HRIMessage]):
205268
def __init__(
206269
self,
207270
node_name: str = f"rai_ros2_hri_connector_{str(uuid.uuid4())[-12:]}",
208-
targets: List[Tuple[str, TopicConfig]] = [],
209-
sources: List[Tuple[str, TopicConfig]] = [],
271+
targets: List[Union[str, Tuple[str, TopicConfig]]] = [],
272+
sources: List[Union[str, Tuple[str, TopicConfig]]] = [],
210273
):
211-
configured_targets = [target[0] for target in targets]
212-
configured_sources = [source[0] for source in sources]
274+
configured_targets = [
275+
target[0] if isinstance(target, tuple) else target for target in targets
276+
]
277+
configured_sources = [
278+
source[0] if isinstance(source, tuple) else source for source in sources
279+
]
213280

214-
self._configure_publishers(targets)
215-
self._configure_subscribers(sources)
281+
_targets = [
282+
target
283+
if isinstance(target, tuple)
284+
else (target, TopicConfig(is_subscriber=False))
285+
for target in targets
286+
]
287+
_sources = [
288+
source
289+
if isinstance(source, tuple)
290+
else (source, TopicConfig(is_subscriber=True))
291+
for source in sources
292+
]
216293

217-
super().__init__(configured_targets, configured_sources)
218294
self._node = Node(node_name)
219295
self._topic_api = ConfigurableROS2TopicAPI(self._node)
220296
self._service_api = ROS2ServiceAPI(self._node)
221297
self._actions_api = ROS2ActionAPI(self._node)
222298

299+
self._configure_publishers(_targets)
300+
self._configure_subscribers(_sources)
301+
302+
super().__init__(configured_targets, configured_sources)
303+
223304
self._executor = MultiThreadedExecutor()
224305
self._executor.add_node(self._node)
225306
self._thread = threading.Thread(target=self._executor.spin)
@@ -236,7 +317,7 @@ def _configure_subscribers(self, sources: List[Tuple[str, TopicConfig]]):
236317
def send_message(self, message: ROS2HRIMessage, target: str, **kwargs):
237318
self._topic_api.publish_configured(
238319
topic=target,
239-
msg_content=message.payload,
320+
msg_content=message.to_ros2_dict(),
240321
)
241322

242323
def receive_message(
@@ -249,16 +330,12 @@ def receive_message(
249330
auto_topic_type: bool = True,
250331
**kwargs: Any,
251332
) -> ROS2HRIMessage:
252-
if msg_type != "std_msgs/msg/String":
253-
raise ValueError("ROS2HRIConnector only supports receiving sting messages")
254333
msg = self._topic_api.receive(
255334
topic=source,
256335
timeout_sec=timeout_sec,
257-
msg_type=msg_type,
258336
auto_topic_type=auto_topic_type,
259337
)
260-
payload = HRIPayload(msg.data)
261-
return ROS2HRIMessage(payload=payload, message_author=message_author)
338+
return ROS2HRIMessage.from_ros2(msg, message_author)
262339

263340
def service_call(
264341
self, message: ROS2HRIMessage, target: str, timeout_sec: float, **kwargs: Any
@@ -284,3 +361,10 @@ def terminate_action(self, action_handle: str, **kwargs: Any):
284361
raise NotImplementedError(
285362
f"{self.__class__.__name__} doesn't support action calls"
286363
)
364+
365+
def shutdown(self):
366+
self._executor.shutdown()
367+
self._thread.join()
368+
self._actions_api.shutdown()
369+
self._topic_api.shutdown()
370+
self._node.destroy_node()

src/rai_interfaces/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ rosidl_generate_interfaces(${PROJECT_NAME}
2121
"srv/VectorStoreRetrieval.srv"
2222
"srv/StringList.srv"
2323
"msg/RAIDetectionArray.msg"
24+
"msg/AudioMessage.msg"
25+
"msg/HRIMessage.msg"
2426
"srv/RAIGroundingDino.srv"
2527
"srv/RAIGroundedSam.srv"
2628
"action/Task.action"
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#
2+
# Copyright (C) 2024 Robotec.AI
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
17+
int16[] audio
18+
uint16 sample_rate
19+
uint16 channels
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#
2+
# Copyright (C) 2024 Robotec.AI
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
17+
std_msgs/Header header
18+
string text
19+
sensor_msgs/Image[] images
20+
rai_interfaces/AudioMessage[] audios

tests/communication/ros2/helpers.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
import numpy as np
2020
import pytest
2121
import rclpy
22+
from cv_bridge import CvBridge
2223
from nav2_msgs.action import NavigateToPose
24+
from pydub import AudioSegment
2325
from rclpy.action import ActionServer, CancelResponse, GoalResponse
2426
from rclpy.action.server import ServerGoalHandle
2527
from rclpy.callback_groups import ReentrantCallbackGroup
@@ -30,6 +32,36 @@
3032
from std_srvs.srv import SetBool
3133
from tf2_ros import TransformBroadcaster, TransformStamped
3234

35+
from rai_interfaces.msg import HRIMessage
36+
37+
38+
class HRIMessagePublisher(Node):
39+
def __init__(self, topic: str):
40+
super().__init__("test_hri_message_publisher")
41+
self.publisher = self.create_publisher(HRIMessage, topic, 10)
42+
self.timer = self.create_timer(0.1, self.publish_message)
43+
self.cv_bridge = CvBridge()
44+
45+
def publish_message(self) -> None:
46+
msg = HRIMessage()
47+
image = self.cv_bridge.cv2_to_imgmsg(np.zeros((100, 100, 3), dtype=np.uint8))
48+
msg.images = [image]
49+
msg.audios = [AudioSegment.silent(duration=1000)]
50+
msg.text = "Hello, HRI!"
51+
self.publisher.publish(msg)
52+
53+
54+
class HRIMessageSubscriber(Node):
55+
def __init__(self, topic: str):
56+
super().__init__("test_hri_message_subscriber")
57+
self.subscription = self.create_subscription(
58+
HRIMessage, topic, self.handle_test_message, 10
59+
)
60+
self.received_messages: List[HRIMessage] = []
61+
62+
def handle_test_message(self, msg: HRIMessage) -> None:
63+
self.received_messages.append(msg)
64+
3365

3466
class ServiceServer(Node):
3567
def __init__(self, service_name: str):

tests/communication/ros2/test_api.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
from .helpers import ActionServer_ as ActionServer
3232
from .helpers import (
33+
HRIMessageSubscriber,
3334
MessagePublisher,
3435
MessageReceiver,
3536
ServiceServer,
@@ -71,23 +72,21 @@ def test_ros2_configure_publisher(ros_setup: None, request: pytest.FixtureReques
7172
executors, threads = multi_threaded_spinner([node])
7273
try:
7374
topic_api = ConfigurableROS2TopicAPI(node)
74-
cfg = TopicConfig(topic_name, "std_msgs/msg/String")
75+
cfg = TopicConfig()
7576
topic_api.configure_publisher(topic_name, cfg)
7677
assert topic_api._publishers[topic_name] is not None
7778
finally:
7879
shutdown_executors_and_threads(executors, threads)
7980

8081

81-
def test_ros2_configre_subscriber(ros_setup, request: pytest.FixtureRequest):
82+
def test_ros2_configure_subscriber(ros_setup, request: pytest.FixtureRequest):
8283
topic_name = f"{request.node.originalname}_topic" # type: ignore
8384
node_name = f"{request.node.originalname}_node" # type: ignore
8485
node = Node(node_name)
8586
executors, threads = multi_threaded_spinner([node])
8687
try:
8788
topic_api = ConfigurableROS2TopicAPI(node)
8889
cfg = TopicConfig(
89-
topic_name,
90-
"std_msgs/msg/String",
9190
is_subscriber=True,
9291
subscriber_callback=lambda _: None,
9392
)
@@ -102,25 +101,23 @@ def test_ros2_single_message_publish_configured(
102101
) -> None:
103102
topic_name = f"{request.node.originalname}_topic" # type: ignore
104103
node_name = f"{request.node.originalname}_node" # type: ignore
105-
message_receiver = MessageReceiver(topic_name)
104+
message_receiver = HRIMessageSubscriber(topic_name)
106105
node = Node(node_name)
107106
executors, threads = multi_threaded_spinner([message_receiver, node])
108107

109108
try:
110109
topic_api = ConfigurableROS2TopicAPI(node)
111110
cfg = TopicConfig(
112-
topic_name,
113-
"std_msgs/msg/String",
114111
is_subscriber=False,
115112
)
116113
topic_api.configure_publisher(topic_name, cfg)
117114
topic_api.publish_configured(
118115
topic_name,
119-
{"data": "Hello, ROS2!"},
116+
{"text": "Hello, ROS2!"},
120117
)
121118
time.sleep(1)
122119
assert len(message_receiver.received_messages) == 1
123-
assert message_receiver.received_messages[0].data == "Hello, ROS2!"
120+
assert message_receiver.received_messages[0].text == "Hello, ROS2!"
124121
finally:
125122
shutdown_executors_and_threads(executors, threads)
126123

0 commit comments

Comments
 (0)