Skip to content

Commit 2f724c3

Browse files
authored
feat: BaseROS2Connector (#535)
1 parent 07d49f8 commit 2f724c3

File tree

7 files changed

+52
-54
lines changed

7 files changed

+52
-54
lines changed

src/rai_core/rai/communication/base_connector.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@
2222
Dict,
2323
Generic,
2424
Optional,
25+
Type,
2526
TypeVar,
27+
get_args,
2628
)
2729
from uuid import uuid4
2830

@@ -68,6 +70,15 @@ def __init__(self, callback_max_workers: int = 4):
6870
max_workers=self.callback_max_workers
6971
)
7072

73+
if not hasattr(self, "__orig_bases__"):
74+
self.__orig_bases__ = {}
75+
raise ConnectorException(
76+
f"Error while instantiating {str(self.__class__)}: "
77+
"Message type T derived from BaseMessage needs to be provided"
78+
" e.g. Connector[MessageType]()"
79+
)
80+
self.T_class: Type[T] = get_args(self.__orig_bases__[-1])[0]
81+
7182
def _generate_handle(self) -> str:
7283
return str(uuid4())
7384

src/rai_core/rai/communication/ros2/connectors/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@
1313
# limitations under the License.
1414

1515
from .action_mixin import ROS2ActionMixin
16-
from .connector import ROS2Connector
16+
from .base import ROS2BaseConnector
1717
from .hri_connector import ROS2HRIConnector
18+
from .ros2_connector import ROS2Connector
1819
from .service_mixin import ROS2ServiceMixin
1920

2021
__all__ = [
2122
"ROS2ActionMixin",
23+
"ROS2BaseConnector",
2224
"ROS2Connector",
2325
"ROS2HRIConnector",
2426
"ROS2ServiceMixin",

src/rai_core/rai/communication/ros2/connectors/connector.py renamed to src/rai_core/rai/communication/ros2/connectors/base.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import time
1717
import uuid
1818
from functools import partial
19-
from typing import Any, Callable, Dict, List, Optional, Tuple
19+
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
2020

2121
import rclpy
2222
import rclpy.executors
@@ -38,8 +38,10 @@
3838
from rai.communication.ros2.connectors.service_mixin import ROS2ServiceMixin
3939
from rai.communication.ros2.messages import ROS2Message
4040

41+
T = TypeVar("T", bound=ROS2Message)
4142

42-
class ROS2Connector(ROS2ActionMixin, ROS2ServiceMixin, BaseConnector[ROS2Message]):
43+
44+
class ROS2BaseConnector(ROS2ActionMixin, ROS2ServiceMixin, BaseConnector[T]):
4345
"""ROS2-specific implementation of the ARIConnector.
4446
4547
This connector provides functionality for ROS2 communication through topics,
@@ -110,9 +112,9 @@ def __init__(
110112
self._thread.start()
111113

112114
# cache for last received messages
113-
self.last_msg: Dict[str, ROS2Message] = {}
115+
self.last_msg: Dict[str, T] = {}
114116

115-
def last_message_callback(self, source: str, msg: ROS2Message):
117+
def last_message_callback(self, source: str, msg: T):
116118
self.last_msg[source] = msg
117119

118120
def get_topics_names_and_types(self) -> List[Tuple[str, List[str]]]:
@@ -126,7 +128,7 @@ def get_actions_names_and_types(self) -> List[Tuple[str, List[str]]]:
126128

127129
def send_message(
128130
self,
129-
message: ROS2Message,
131+
message: T,
130132
target: str,
131133
*,
132134
msg_type: str, # TODO: allow msg_type to be None, add auto topic type detection
@@ -143,19 +145,19 @@ def send_message(
143145
)
144146

145147
def general_callback_preprocessor(self, message: Any):
146-
return ROS2Message(payload=message, metadata={"msg_type": str(type(message))})
148+
return self.T_class(payload=message, metadata={"msg_type": str(type(message))})
147149

148150
def register_callback(
149151
self,
150152
source: str,
151-
callback: Callable[[ROS2Message | Any], None],
153+
callback: Callable[[T | Any], None],
152154
raw: bool = False,
153155
*,
154156
msg_type: Optional[str] = None,
155157
qos_profile: Optional[QoSProfile] = None,
156158
auto_qos_matching: bool = True,
157159
**kwargs: Any,
158-
):
160+
) -> str:
159161
exists = self._topic_api.subscriber_exists(source)
160162
if not exists:
161163
self._topic_api.create_subscriber(
@@ -165,7 +167,7 @@ def register_callback(
165167
qos_profile=qos_profile,
166168
auto_qos_matching=auto_qos_matching,
167169
)
168-
super().register_callback(source, callback, raw=raw)
170+
return super().register_callback(source, callback, raw=raw)
169171

170172
def receive_message(
171173
self,
@@ -176,7 +178,7 @@ def receive_message(
176178
qos_profile: Optional[QoSProfile] = None,
177179
auto_qos_matching: bool = True,
178180
**kwargs: Any,
179-
) -> ROS2Message:
181+
) -> T:
180182
if self._topic_api.subscriber_exists(source):
181183
# trying to hit cache first
182184
if source in self.last_msg:

src/rai_core/rai/communication/ros2/connectors/hri_connector.py

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,14 @@
1313
# limitations under the License.
1414

1515
import logging
16-
import threading
1716
import uuid
1817
from typing import Any, Callable, List, Literal, Optional, Tuple, Union
1918

20-
from rclpy.executors import MultiThreadedExecutor
21-
from rclpy.node import Node
22-
23-
from rai.communication import HRIConnector
2419
from rai.communication.ros2.api import (
2520
ConfigurableROS2TopicAPI,
26-
ROS2ActionAPI,
27-
ROS2ServiceAPI,
2821
TopicConfig,
2922
)
30-
from rai.communication.ros2.connectors.action_mixin import ROS2ActionMixin
31-
from rai.communication.ros2.connectors.service_mixin import ROS2ServiceMixin
23+
from rai.communication.ros2.connectors.base import ROS2BaseConnector
3224
from rai.communication.ros2.messages import ROS2HRIMessage
3325

3426
try:
@@ -37,7 +29,7 @@
3729
logging.warning("rai_interfaces is not installed, ROS 2 HRIMessage will not work.")
3830

3931

40-
class ROS2HRIConnector(ROS2ActionMixin, ROS2ServiceMixin, HRIConnector[ROS2HRIMessage]):
32+
class ROS2HRIConnector(ROS2BaseConnector[ROS2HRIMessage]):
4133
def __init__(
4234
self,
4335
node_name: str = f"rai_ros2_hri_connector_{str(uuid.uuid4())[-12:]}",
@@ -50,6 +42,8 @@ def __init__(
5042
configured_sources = [
5143
source[0] if isinstance(source, tuple) else source for source in sources
5244
]
45+
self.configured_targets = configured_targets
46+
self.configured_sources = configured_sources
5347

5448
_targets = [
5549
(
@@ -67,22 +61,11 @@ def __init__(
6761
)
6862
for source in sources
6963
]
70-
71-
self._node = Node(node_name)
64+
super().__init__(node_name=node_name)
7265
self._topic_api = ConfigurableROS2TopicAPI(self._node)
73-
self._service_api = ROS2ServiceAPI(self._node)
74-
self._actions_api = ROS2ActionAPI(self._node)
75-
7666
self._configure_publishers(_targets)
7767
self._configure_subscribers(_sources)
7868

79-
super().__init__(configured_targets, configured_sources)
80-
81-
self._executor = MultiThreadedExecutor()
82-
self._executor.add_node(self._node)
83-
self._thread = threading.Thread(target=self._executor.spin)
84-
self._thread.start()
85-
8669
def _configure_publishers(self, targets: List[Tuple[str, TopicConfig]]):
8770
for target in targets:
8871
self._topic_api.configure_publisher(target[0], target[1])
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright (C) 2025 Robotec.AI
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from rai.communication.ros2.connectors.base import ROS2BaseConnector
16+
from rai.communication.ros2.messages import ROS2Message
17+
18+
19+
class ROS2Connector(ROS2BaseConnector[ROS2Message]):
20+
pass

src/rai_core/rai/communication/ros2/messages.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class ROS2Message(BaseMessage):
4141
pass
4242

4343

44-
class ROS2HRIMessage(HRIMessage):
44+
class ROS2HRIMessage(HRIMessage, ROS2Message):
4545
@classmethod
4646
def from_ros2(
4747
cls,

tests/rai_sim/test_o3de_bridge.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -301,20 +301,6 @@ def test_send_message_signature(self):
301301
f"Parameter names do not match, expected: {list(expected_params.keys())}, got: {list(parameters.keys())}",
302302
)
303303

304-
for param_name, expected_type in expected_params.items():
305-
param = parameters[param_name]
306-
self.assertEqual(
307-
self.resolve_annotation(param.annotation),
308-
self.resolve_annotation(expected_type),
309-
f"Parameter '{param_name}' has incorrect type, expected: {expected_type}, got: {param.annotation}",
310-
)
311-
312-
self.assertIs(
313-
signature.return_annotation,
314-
inspect.Signature.empty,
315-
"send_message should have no return value",
316-
)
317-
318304
def test_receive_message_signature(self):
319305
signature = inspect.signature(self.connector.receive_message)
320306
parameters = signature.parameters
@@ -341,12 +327,6 @@ def test_receive_message_signature(self):
341327
f"Parameter '{param_name}' has incorrect type, expected: {expected_type}, got: {param.annotation}",
342328
)
343329

344-
self.assertIs(
345-
signature.return_annotation,
346-
ROS2Message,
347-
f"Return type is incorrect, expected: ROS2Message, got: {signature.return_annotation}",
348-
)
349-
350330
def test_get_topics_names_and_types_signature(self):
351331
signature = inspect.signature(self.connector.get_topics_names_and_types)
352332
parameters = signature.parameters

0 commit comments

Comments
 (0)