Skip to content

Commit 47d7090

Browse files
authored
feat: add ros2 hri connector (#410)
1 parent 6bae47f commit 47d7090

File tree

4 files changed

+293
-9
lines changed

4 files changed

+293
-9
lines changed

src/rai/rai/communication/hri_connector.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import base64
1616
from dataclasses import dataclass, field
1717
from io import BytesIO
18-
from typing import Generic, Literal, Sequence, TypeVar, get_args
18+
from typing import Any, Dict, Generic, Literal, Optional, Sequence, TypeVar, get_args
1919

2020
from langchain_core.messages import AIMessage
2121
from langchain_core.messages import BaseMessage as LangchainBaseMessage
@@ -54,9 +54,11 @@ class HRIMessage(BaseMessage):
5454
def __init__(
5555
self,
5656
payload: HRIPayload,
57-
message_author: Literal["ai", "human"],
57+
metadata: Optional[Dict[str, Any]] = None,
58+
message_author: Literal["ai", "human"] = "ai",
5859
**kwargs,
5960
):
61+
super().__init__(payload, metadata)
6062
self.message_author = message_author
6163
self.text = payload.text
6264
self.images = payload.images

src/rai/rai/communication/ros2/api.py

Lines changed: 96 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import time
1818
import uuid
1919
from concurrent.futures import ThreadPoolExecutor
20+
from dataclasses import dataclass
2021
from functools import partial
2122
from typing import (
2223
Annotated,
@@ -105,9 +106,14 @@ def adapt_requests_to_offers(publisher_info: List[TopicEndpointInfo]) -> QoSProf
105106
return request_qos
106107

107108

108-
def build_ros2_msg(msg_type: str, msg_args: Dict[str, Any]) -> object:
109-
"""Build a ROS2 message instance from type string and content dictionary."""
110-
msg_cls = import_message_from_str(msg_type)
109+
def build_ros2_msg(
110+
msg_type: str | type[rclpy.node.MsgType], msg_args: Dict[str, Any]
111+
) -> object:
112+
"""Build a ROS2 message instance from string or MsgType and content dictionary."""
113+
if isinstance(msg_type, str):
114+
msg_cls = import_message_from_str(msg_type)
115+
else:
116+
msg_cls = msg_type
111117
msg = msg_cls()
112118
rosidl_runtime_py.set_message.set_message_fields(msg, msg_args)
113119
return msg
@@ -311,6 +317,93 @@ def shutdown(self) -> None:
311317
publisher.destroy()
312318

313319

320+
@dataclass
321+
class TopicConfig:
322+
name: str
323+
msg_type: str
324+
auto_qos_matching: bool = True
325+
qos_profile: Optional[QoSProfile] = None
326+
is_subscriber: bool = False
327+
subscriber_callback: Optional[Callable[[Any], None]] = None
328+
329+
def __post_init__(self):
330+
if not self.auto_qos_matching and self.qos_profile is None:
331+
raise ValueError(
332+
"Either 'auto_qos_matching' must be True or 'qos_profile' must be set."
333+
)
334+
335+
336+
class ConfigurableROS2TopicAPI(ROS2TopicAPI):
337+
338+
def __init__(self, node: rclpy.node.Node):
339+
super().__init__(node)
340+
self._subscribtions: dict[str, rclpy.node.Subscription] = {}
341+
342+
def configure_publisher(self, topic: str, config: TopicConfig):
343+
if config.is_subscriber:
344+
raise ValueError(
345+
"Can't reconfigure publisher with subscriber config! Set config.is_subscriber to False"
346+
)
347+
qos_profile = self._resolve_qos_profile(
348+
topic, config.auto_qos_matching, config.qos_profile, for_publisher=True
349+
)
350+
if topic in self._publishers:
351+
flag = self._node.destroy_publisher(self._publishers[topic].handle)
352+
if not flag:
353+
raise ValueError(f"Failed to reconfigure existing publisher to {topic}")
354+
355+
self._publishers[topic] = self._node.create_publisher(
356+
import_message_from_str(config.msg_type),
357+
topic=topic,
358+
qos_profile=qos_profile,
359+
)
360+
361+
def configure_subscriber(
362+
self,
363+
topic: str,
364+
config: TopicConfig,
365+
):
366+
if not config.is_subscriber:
367+
raise ValueError(
368+
"Can't reconfigure subscriber with publisher config! Set config.is_subscriber to True"
369+
)
370+
qos_profile = self._resolve_qos_profile(
371+
topic, config.auto_qos_matching, config.qos_profile, for_publisher=False
372+
)
373+
if topic in self._subscribtions:
374+
flag = self._node.destroy_subscription(self._subscribtions[topic])
375+
if not flag:
376+
raise ValueError(
377+
f"Failed to reconfigure existing subscriber to {topic}"
378+
)
379+
380+
assert config.subscriber_callback is not None
381+
self._subscribtions[topic] = self._node.create_subscription(
382+
msg_type=import_message_from_str(config.msg_type),
383+
topic=topic,
384+
callback=config.subscriber_callback,
385+
qos_profile=qos_profile,
386+
)
387+
388+
def publish_configured(self, topic: str, msg_content: dict[str, Any]) -> None:
389+
"""Publish a message to a ROS2 topic.
390+
391+
Args:
392+
topic: Name of the topic to publish to
393+
msg_content: Dictionary containing the message content
394+
395+
Raises:
396+
ValueError: If topic has not been configured for publishing
397+
"""
398+
try:
399+
publisher = self._publishers[topic]
400+
except Exception as e:
401+
raise ValueError(f"{topic} has not been configured for publishing") from e
402+
msg_type = publisher.msg_type
403+
msg = build_ros2_msg(msg_type, msg_content) # type: ignore
404+
publisher.publish(msg)
405+
406+
314407
class ROS2ServiceAPI:
315408
"""Handles ROS2 service operations including calling services."""
316409

src/rai/rai/communication/ros2/connectors.py

Lines changed: 105 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import threading
1616
import time
1717
import uuid
18-
from typing import Any, Callable, Dict, List, Optional, Tuple
18+
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
1919

2020
import rclpy
2121
import rclpy.executors
@@ -27,8 +27,20 @@
2727
from rclpy.qos import QoSProfile
2828
from tf2_ros import Buffer, LookupException, TransformListener, TransformStamped
2929

30-
from rai.communication.ari_connector import ARIConnector, ARIMessage
31-
from rai.communication.ros2.api import ROS2ActionAPI, ROS2ServiceAPI, ROS2TopicAPI
30+
from rai.communication import (
31+
ARIConnector,
32+
ARIMessage,
33+
HRIConnector,
34+
HRIMessage,
35+
HRIPayload,
36+
)
37+
from rai.communication.ros2.api import (
38+
ConfigurableROS2TopicAPI,
39+
ROS2ActionAPI,
40+
ROS2ServiceAPI,
41+
ROS2TopicAPI,
42+
TopicConfig,
43+
)
3244

3345

3446
class ROS2ARIMessage(ARIMessage):
@@ -183,3 +195,93 @@ def shutdown(self):
183195
self._actions_api.shutdown()
184196
self._topic_api.shutdown()
185197
self._node.destroy_node()
198+
199+
200+
class ROS2HRIMessage(HRIMessage):
201+
def __init__(self, payload: HRIPayload, message_author: Literal["ai", "human"]):
202+
super().__init__(payload, message_author)
203+
204+
205+
class ROS2HRIConnector(HRIConnector[ROS2HRIMessage]):
206+
def __init__(
207+
self,
208+
node_name: str = f"rai_ros2_hri_connector_{str(uuid.uuid4())[-12:]}",
209+
targets: List[Tuple[str, TopicConfig]] = [],
210+
sources: List[Tuple[str, TopicConfig]] = [],
211+
):
212+
configured_targets = [target[0] for target in targets]
213+
configured_sources = [source[0] for source in sources]
214+
215+
self._configure_publishers(targets)
216+
self._configure_subscribers(sources)
217+
218+
super().__init__(configured_targets, configured_sources)
219+
self._node = Node(node_name)
220+
self._topic_api = ConfigurableROS2TopicAPI(self._node)
221+
self._service_api = ROS2ServiceAPI(self._node)
222+
self._actions_api = ROS2ActionAPI(self._node)
223+
224+
self._executor = MultiThreadedExecutor()
225+
self._executor.add_node(self._node)
226+
self._thread = threading.Thread(target=self._executor.spin)
227+
self._thread.start()
228+
229+
def _configure_publishers(self, targets: List[Tuple[str, TopicConfig]]):
230+
for target in targets:
231+
self._topic_api.configure_publisher(target[0], target[1])
232+
233+
def _configure_subscribers(self, sources: List[Tuple[str, TopicConfig]]):
234+
for source in sources:
235+
self._topic_api.configure_subscriber(source[0], source[1])
236+
237+
def send_message(self, message: ROS2HRIMessage, target: str, **kwargs):
238+
self._topic_api.publish_configured(
239+
topic=target,
240+
msg_content=message.payload,
241+
)
242+
243+
def receive_message(
244+
self,
245+
source: str,
246+
timeout_sec: float = 1.0,
247+
*,
248+
message_author: Literal["human", "ai"],
249+
msg_type: Optional[str] = None,
250+
auto_topic_type: bool = True,
251+
**kwargs: Any,
252+
) -> ROS2HRIMessage:
253+
if msg_type != "std_msgs/msg/String":
254+
raise ValueError("ROS2HRIConnector only supports receiving sting messages")
255+
msg = self._topic_api.receive(
256+
topic=source,
257+
timeout_sec=timeout_sec,
258+
msg_type=msg_type,
259+
auto_topic_type=auto_topic_type,
260+
)
261+
payload = HRIPayload(msg.data)
262+
return ROS2HRIMessage(payload=payload, message_author=message_author)
263+
264+
def service_call(
265+
self, message: ROS2HRIMessage, target: str, timeout_sec: float, **kwargs: Any
266+
) -> ROS2HRIMessage:
267+
raise NotImplementedError(
268+
f"{self.__class__.__name__} doesn't support service calls"
269+
)
270+
271+
def start_action(
272+
self,
273+
action_data: Optional[ROS2HRIMessage],
274+
target: str,
275+
on_feedback: Callable,
276+
on_done: Callable,
277+
timeout_sec: float,
278+
**kwargs: Any,
279+
) -> str:
280+
raise NotImplementedError(
281+
f"{self.__class__.__name__} doesn't support action calls"
282+
)
283+
284+
def terminate_action(self, action_handle: str, **kwargs: Any):
285+
raise NotImplementedError(
286+
f"{self.__class__.__name__} doesn't support action calls"
287+
)

tests/communication/ros2/test_api.py

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,13 @@
2121
from rclpy.executors import MultiThreadedExecutor
2222
from rclpy.node import Node
2323

24-
from rai.communication.ros2.api import ROS2ActionAPI, ROS2ServiceAPI, ROS2TopicAPI
24+
from rai.communication.ros2.api import (
25+
ConfigurableROS2TopicAPI,
26+
ROS2ActionAPI,
27+
ROS2ServiceAPI,
28+
ROS2TopicAPI,
29+
TopicConfig,
30+
)
2531

2632
from .helpers import ActionServer_ as ActionServer
2733
from .helpers import (
@@ -59,6 +65,87 @@ def test_ros2_single_message_publish(
5965
shutdown_executors_and_threads(executors, threads)
6066

6167

68+
def test_ros2_configure_publisher(ros_setup: None, request: pytest.FixtureRequest):
69+
topic_name = f"{request.node.originalname}_topic" # type: ignore
70+
node_name = f"{request.node.originalname}_node" # type: ignore
71+
node = Node(node_name)
72+
executors, threads = multi_threaded_spinner([node])
73+
try:
74+
topic_api = ConfigurableROS2TopicAPI(node)
75+
cfg = TopicConfig(topic_name, "std_msgs/msg/String")
76+
topic_api.configure_publisher(topic_name, cfg)
77+
assert topic_api._publishers[topic_name] is not None
78+
finally:
79+
shutdown_executors_and_threads(executors, threads)
80+
81+
82+
def test_ros2_configre_subscriber(ros_setup, request: pytest.FixtureRequest):
83+
topic_name = f"{request.node.originalname}_topic" # type: ignore
84+
node_name = f"{request.node.originalname}_node" # type: ignore
85+
node = Node(node_name)
86+
executors, threads = multi_threaded_spinner([node])
87+
try:
88+
topic_api = ConfigurableROS2TopicAPI(node)
89+
cfg = TopicConfig(
90+
topic_name,
91+
"std_msgs/msg/String",
92+
is_subscriber=True,
93+
subscriber_callback=lambda _: None,
94+
)
95+
topic_api.configure_subscriber(topic_name, cfg)
96+
assert topic_api._subscribtions[topic_name] is not None
97+
finally:
98+
shutdown_executors_and_threads(executors, threads)
99+
100+
101+
def test_ros2_single_message_publish_configured(
102+
ros_setup: None, request: pytest.FixtureRequest
103+
) -> None:
104+
topic_name = f"{request.node.originalname}_topic" # type: ignore
105+
node_name = f"{request.node.originalname}_node" # type: ignore
106+
message_receiver = MessageReceiver(topic_name)
107+
node = Node(node_name)
108+
executors, threads = multi_threaded_spinner([message_receiver, node])
109+
110+
try:
111+
topic_api = ConfigurableROS2TopicAPI(node)
112+
cfg = TopicConfig(
113+
topic_name,
114+
"std_msgs/msg/String",
115+
is_subscriber=False,
116+
)
117+
topic_api.configure_publisher(topic_name, cfg)
118+
topic_api.publish_configured(
119+
topic_name,
120+
{"data": "Hello, ROS2!"},
121+
)
122+
time.sleep(1)
123+
assert len(message_receiver.received_messages) == 1
124+
assert message_receiver.received_messages[0].data == "Hello, ROS2!"
125+
finally:
126+
shutdown_executors_and_threads(executors, threads)
127+
128+
129+
def test_ros2_single_message_publish_configured_no_config(
130+
ros_setup: None, request: pytest.FixtureRequest
131+
) -> None:
132+
topic_name = f"{request.node.originalname}_topic" # type: ignore
133+
node_name = f"{request.node.originalname}_node" # type: ignore
134+
message_receiver = MessageReceiver(topic_name)
135+
node = Node(node_name)
136+
executors, threads = multi_threaded_spinner([message_receiver, node])
137+
138+
try:
139+
topic_api = ConfigurableROS2TopicAPI(node)
140+
with pytest.raises(ValueError):
141+
topic_api.publish_configured(
142+
topic_name,
143+
{"data": "Hello, ROS2!"},
144+
)
145+
finally:
146+
shutdown_executors_and_threads(executors, threads)
147+
148+
62149
def test_ros2_single_message_publish_wrong_msg_type(
63150
ros_setup: None, request: pytest.FixtureRequest
64151
) -> None:

0 commit comments

Comments
 (0)