15
15
import threading
16
16
import time
17
17
import uuid
18
- from typing import Any , Callable , Dict , List , Literal , Optional , Tuple
18
+ from collections import OrderedDict
19
+ from typing import Any , Callable , Dict , List , Literal , Optional , Tuple , Union , cast
19
20
21
+ import numpy as np
20
22
import rclpy
21
23
import rclpy .executors
22
24
import rclpy .node
23
25
import rclpy .time
26
+ import rosidl_runtime_py .convert
27
+ from cv_bridge import CvBridge
28
+ from PIL import Image
29
+ from pydub import AudioSegment
24
30
from rclpy .duration import Duration
25
31
from rclpy .executors import MultiThreadedExecutor
26
32
from rclpy .node import Node
27
33
from rclpy .qos import QoSProfile
34
+ from sensor_msgs .msg import Image as ROS2Image
28
35
from tf2_ros import Buffer , LookupException , TransformListener , TransformStamped
29
36
37
+ import rai_interfaces .msg
30
38
from rai .communication import (
31
39
ARIConnector ,
32
40
ARIMessage ,
41
49
ROS2TopicAPI ,
42
50
TopicConfig ,
43
51
)
52
+ from rai_interfaces .msg import HRIMessage as ROS2HRIMessage_
53
+ from rai_interfaces .msg ._audio_message import (
54
+ AudioMessage as ROS2HRIMessage__Audio ,
55
+ )
44
56
45
57
46
58
class ROS2ARIMessage (ARIMessage ):
@@ -200,26 +212,95 @@ class ROS2HRIMessage(HRIMessage):
200
212
def __init__ (self , payload : HRIPayload , message_author : Literal ["ai" , "human" ]):
201
213
super ().__init__ (payload , message_author )
202
214
215
+ @classmethod
216
+ def from_ros2 (
217
+ cls , msg : rai_interfaces .msg .HRIMessage , message_author : Literal ["ai" , "human" ]
218
+ ):
219
+ cv_bridge = CvBridge ()
220
+ images = [
221
+ cv_bridge .imgmsg_to_cv2 (img_msg , "rgb8" )
222
+ for img_msg in cast (List [ROS2Image ], msg .images )
223
+ ]
224
+ pil_images = [Image .fromarray (img ) for img in images ]
225
+ audio_segments = [
226
+ AudioSegment (
227
+ data = audio_msg .audio ,
228
+ frame_rate = audio_msg .sample_rate ,
229
+ sample_width = 2 , # bytes, int16
230
+ channels = audio_msg .channels ,
231
+ )
232
+ for audio_msg in msg .audios
233
+ ]
234
+ return ROS2HRIMessage (
235
+ payload = HRIPayload (text = msg .text , images = pil_images , audios = audio_segments ),
236
+ message_author = message_author ,
237
+ )
238
+
239
+ def to_ros2_dict (self ) -> OrderedDict [str , Any ]:
240
+ cv_bridge = CvBridge ()
241
+ assert isinstance (self .payload , HRIPayload )
242
+ img_msgs = [
243
+ cv_bridge .cv2_to_imgmsg (np .array (img ), "rgb8" )
244
+ for img in self .payload .images
245
+ ]
246
+ audio_msgs = [
247
+ ROS2HRIMessage__Audio (
248
+ audio = audio .raw_data ,
249
+ sample_rate = audio .frame_rate ,
250
+ channels = audio .channels ,
251
+ )
252
+ for audio in self .payload .audios
253
+ ]
254
+
255
+ return cast (
256
+ OrderedDict [str , Any ],
257
+ rosidl_runtime_py .convert .message_to_ordereddict (
258
+ ROS2HRIMessage_ (
259
+ text = self .payload .text ,
260
+ images = img_msgs ,
261
+ audios = audio_msgs ,
262
+ )
263
+ ),
264
+ )
265
+
203
266
204
267
class ROS2HRIConnector (HRIConnector [ROS2HRIMessage ]):
205
268
def __init__ (
206
269
self ,
207
270
node_name : str = f"rai_ros2_hri_connector_{ str (uuid .uuid4 ())[- 12 :]} " ,
208
- targets : List [Tuple [str , TopicConfig ]] = [],
209
- sources : List [Tuple [str , TopicConfig ]] = [],
271
+ targets : List [Union [ str , Tuple [str , TopicConfig ] ]] = [],
272
+ sources : List [Union [ str , Tuple [str , TopicConfig ] ]] = [],
210
273
):
211
- configured_targets = [target [0 ] for target in targets ]
212
- configured_sources = [source [0 ] for source in sources ]
274
+ configured_targets = [
275
+ target [0 ] if isinstance (target , tuple ) else target for target in targets
276
+ ]
277
+ configured_sources = [
278
+ source [0 ] if isinstance (source , tuple ) else source for source in sources
279
+ ]
213
280
214
- self ._configure_publishers (targets )
215
- self ._configure_subscribers (sources )
281
+ _targets = [
282
+ target
283
+ if isinstance (target , tuple )
284
+ else (target , TopicConfig (is_subscriber = False ))
285
+ for target in targets
286
+ ]
287
+ _sources = [
288
+ source
289
+ if isinstance (source , tuple )
290
+ else (source , TopicConfig (is_subscriber = True ))
291
+ for source in sources
292
+ ]
216
293
217
- super ().__init__ (configured_targets , configured_sources )
218
294
self ._node = Node (node_name )
219
295
self ._topic_api = ConfigurableROS2TopicAPI (self ._node )
220
296
self ._service_api = ROS2ServiceAPI (self ._node )
221
297
self ._actions_api = ROS2ActionAPI (self ._node )
222
298
299
+ self ._configure_publishers (_targets )
300
+ self ._configure_subscribers (_sources )
301
+
302
+ super ().__init__ (configured_targets , configured_sources )
303
+
223
304
self ._executor = MultiThreadedExecutor ()
224
305
self ._executor .add_node (self ._node )
225
306
self ._thread = threading .Thread (target = self ._executor .spin )
@@ -236,7 +317,7 @@ def _configure_subscribers(self, sources: List[Tuple[str, TopicConfig]]):
236
317
def send_message (self , message : ROS2HRIMessage , target : str , ** kwargs ):
237
318
self ._topic_api .publish_configured (
238
319
topic = target ,
239
- msg_content = message .payload ,
320
+ msg_content = message .to_ros2_dict () ,
240
321
)
241
322
242
323
def receive_message (
@@ -249,16 +330,12 @@ def receive_message(
249
330
auto_topic_type : bool = True ,
250
331
** kwargs : Any ,
251
332
) -> ROS2HRIMessage :
252
- if msg_type != "std_msgs/msg/String" :
253
- raise ValueError ("ROS2HRIConnector only supports receiving sting messages" )
254
333
msg = self ._topic_api .receive (
255
334
topic = source ,
256
335
timeout_sec = timeout_sec ,
257
- msg_type = msg_type ,
258
336
auto_topic_type = auto_topic_type ,
259
337
)
260
- payload = HRIPayload (msg .data )
261
- return ROS2HRIMessage (payload = payload , message_author = message_author )
338
+ return ROS2HRIMessage .from_ros2 (msg , message_author )
262
339
263
340
def service_call (
264
341
self , message : ROS2HRIMessage , target : str , timeout_sec : float , ** kwargs : Any
@@ -284,3 +361,10 @@ def terminate_action(self, action_handle: str, **kwargs: Any):
284
361
raise NotImplementedError (
285
362
f"{ self .__class__ .__name__ } doesn't support action calls"
286
363
)
364
+
365
+ def shutdown (self ):
366
+ self ._executor .shutdown ()
367
+ self ._thread .join ()
368
+ self ._actions_api .shutdown ()
369
+ self ._topic_api .shutdown ()
370
+ self ._node .destroy_node ()
0 commit comments