diff --git a/formant_ros2_adapter/scripts/README.md b/formant_ros2_adapter/scripts/README.md new file mode 100644 index 0000000..49d287a --- /dev/null +++ b/formant_ros2_adapter/scripts/README.md @@ -0,0 +1,4 @@ +## Tests (Python) + +1. At the root folder (in formant_ros2_adapter/scripts) run the following command + - `python -m unittest ./tests/test_subscribers_batched.py` \ No newline at end of file diff --git a/formant_ros2_adapter/scripts/components/subscriber/base_ingester.py b/formant_ros2_adapter/scripts/components/subscriber/base_ingester.py new file mode 100644 index 0000000..35c15fe --- /dev/null +++ b/formant_ros2_adapter/scripts/components/subscriber/base_ingester.py @@ -0,0 +1,174 @@ +from cv_bridge import CvBridge +import cv2 +import grpc +from typing import Dict +from sensor_msgs.msg import ( + BatteryState, + CompressedImage, + Image, + LaserScan, + NavSatFix, + PointCloud2, +) +from .types import STRING_TYPES, BOOL_TYPES, NUMERIC_TYPES, OTHER_DATA_TYPES + +from formant.sdk.agent.v1 import Client +from formant.protos.model.v1.datapoint_pb2 import Datapoint +from formant.sdk.agent.v1.localization.types import ( + PointCloud as FPointCloud, + Map as FMap, + Path as FPath, + Transform as FTransform, + Goal as FGoal, + Odometry as FOdometry, + Vector3 as FVector3, + Quaternion as FQuaternion, +) + +from utils.logger import get_logger +from ros2_utils.message_utils import ( + get_ros2_type_from_string, + message_to_json, + get_message_path_value, +) + +""" +A Handle Exceptions Class would be nice +""" + + +class BaseIngester: + def __init__(self, _fclient: Client): + self._fclient = _fclient + self.cv_bridge = CvBridge() + self._logger = get_logger() + + def prepare( + self, + msg, + msg_type: type, + formant_stream: str, + topic: str, + msg_timestamp: int, + tags: Dict, + ): + msg = self._preprocess(msg, msg_type) + + if msg_type in STRING_TYPES: + msg = self._fclient.prepare_text( + stream=formant_stream, value=msg, tags=tags, timestamp=msg_timestamp + ) + + elif msg_type in BOOL_TYPES: + msg = self._fclient.prepare_bitset( + stream=formant_stream, value=msg, tags=tags, timestamp=msg_timestamp + ) + elif msg_type in NUMERIC_TYPES: + msg = self._fclient.prepare_numeric( + stream=formant_stream, value=msg, tags=tags, timestamp=msg_timestamp + ) + + elif msg_type == NavSatFix: + + msg = self._fclient.prepare_geolocation( + stream=formant_stream, + latitude=msg.latitude, + longitude=msg.longitude, + altitude=msg.altitude, + tags=tags, + timestamp=msg_timestamp, + ) + + elif msg_type == Image: + msg = self._fclient.prepare_image( + stream=formant_stream, + value=msg, + tags=tags, + timestamp=msg_timestamp, + ) + elif msg_type == CompressedImage: + msg = self._fclient.prepare_image( + stream=formant_stream, + value=msg["value"], + content_type=msg["content_type"], + tags=tags, + timestamp=msg_timestamp, + ) + + elif msg_type == BatteryState: + msg = self._fclient.prepare_battery( + stream=formant_stream, + percentage=msg.percentage, + voltage=msg.voltage, + current=msg.current, + charge=msg.charge, + tags=tags, + timestamp=msg_timestamp, + ) + + elif msg_type == LaserScan: + msg = Datapoint( + stream=formant_stream, + point_cloud=FPointCloud.from_ros_laserscan(msg).to_proto(), + tags=tags, + timestamp=msg_timestamp, + ) + + elif msg_type == PointCloud2: + msg = Datapoint( + stream=formant_stream, + point_cloud=FPointCloud.from_ros(msg).to_proto(), + tags=tags, + timestamp=msg_timestamp, + ) + + else: + msg = self._fclient.prepare_json( + stream=formant_stream, + value=msg, + tags=tags, + timestamp=msg_timestamp, + ) + return msg + + def _preprocess(self, msg, msg_type: type): + + if msg_type in STRING_TYPES: + msg = self._prepare_string(msg) + elif msg_type in BOOL_TYPES or msg_type in NUMERIC_TYPES: + msg = self._prepare_attr_data(msg) + elif msg_type == Image: + msg = self._prepare_image(msg) + + elif msg_type == CompressedImage: + msg = self._prepare_compressed_image(msg) + + elif msg_type not in OTHER_DATA_TYPES: + msg = message_to_json(msg) + + return msg + + def _prepare_string(self, msg): + msg = self._prepare_attr_data(msg) + msg = str(msg) + return msg + + def _prepare_image(self, msg): + cv_image = self.cv_bridge.imgmsg_to_cv2(msg, "bgr8") + encoded_image = cv2.imencode(".jpg", cv_image)[1].tobytes() + return encoded_image + + def _prepare_compressed_image(self, msg): + if "jpg" in msg.format or "jpeg" in msg.format: + content_type = "image/jpg" + elif "png" in msg.format: + content_type = "image/png" + else: + self._logger.warn("Image format", msg.format, "not supported") + return + return {"value": bytes(msg.data), "content_type": content_type} + + def _prepare_attr_data(self, msg): + if hasattr(msg, "data"): + msg = msg.data + return msg diff --git a/formant_ros2_adapter/scripts/components/subscriber/basic_subscriber_coodinator.py b/formant_ros2_adapter/scripts/components/subscriber/basic_subscriber_coodinator.py index a018763..ba24415 100644 --- a/formant_ros2_adapter/scripts/components/subscriber/basic_subscriber_coodinator.py +++ b/formant_ros2_adapter/scripts/components/subscriber/basic_subscriber_coodinator.py @@ -11,6 +11,7 @@ from configuration.config_schema import ConfigSchema from configuration.subscriber_config import SubscriberConfig, MessagePathConfig from .ingester import Ingester +from .batched_ingester import BatchIngester from ros2_utils.qos import QOS_PROFILES, qos_profile_system_default from ros2_utils.topic_type_provider import TopicTypeProvider from utils.logger import get_logger @@ -30,7 +31,7 @@ def __init__( self, fclient: Client, node: Node, - ingester: Ingester, + ingester: BatchIngester, topic_type_provider: TopicTypeProvider, ): self._fclient = fclient diff --git a/formant_ros2_adapter/scripts/components/subscriber/batched_ingester.py b/formant_ros2_adapter/scripts/components/subscriber/batched_ingester.py new file mode 100644 index 0000000..24c61bd --- /dev/null +++ b/formant_ros2_adapter/scripts/components/subscriber/batched_ingester.py @@ -0,0 +1,69 @@ +from .base_ingester import BaseIngester +from formant.protos.agent.v1 import agent_pb2 +from formant.protos.model.v1 import datapoint_pb2 +from formant.sdk.agent.v1 import Client +from queue import LifoQueue +from typing import Dict, List +import threading +import time + +MAX_INGEST_SIZE = 10 + + +class BatchIngester(BaseIngester): + def __init__( + self, _fclient: Client, ingest_interval: int = 0.04, num_threads: int = 2 + ): + super(BatchIngester, self).__init__(_fclient) + self._stream_queues: Dict[str, LifoQueue] = {} + self._ingest_interval = ingest_interval + self._num_threads = num_threads + self._threads: List[threading.Thread] = [] + self._terminate_flag = False + + self._start() + + def ingest( + self, + msg, + msg_type: type, + formant_stream: str, + topic: str, + msg_timestamp: int, + tags: Dict, + ): + message = self.prepare( + msg, msg_type, formant_stream, topic, msg_timestamp, tags + ) + has_stream = formant_stream in self._stream_queues + if not has_stream: + self._stream_queues[formant_stream] = LifoQueue() + + self._stream_queues[formant_stream].put(message) + + def _ingest_once(self): + + for _, queue in self._stream_queues.items(): + ingest_size = min(queue.qsize(), MAX_INGEST_SIZE) + datapoints = [queue.get() for _ in range(ingest_size)] + + self._fclient.post_data_multi(datapoints) + + def _ingest_continually(self): + while not self._terminate_flag: + self._ingest_once() + time.sleep(self._ingest_interval) + + def _start(self): + self._terminate_flag = False + for i in range(self._num_threads): + self._threads.append( + threading.Thread( + target=self._ingest_continually, + daemon=True, + ) + ) + self._threads[i].start() + + def terminate(self): + self._terminate_flag = True diff --git a/formant_ros2_adapter/scripts/components/subscriber/ingester.py b/formant_ros2_adapter/scripts/components/subscriber/ingester.py index 46794bd..7df8658 100644 --- a/formant_ros2_adapter/scripts/components/subscriber/ingester.py +++ b/formant_ros2_adapter/scripts/components/subscriber/ingester.py @@ -6,27 +6,10 @@ BatteryState, CompressedImage, Image, - Joy, LaserScan, NavSatFix, PointCloud2, ) -from std_msgs.msg import ( - Bool, - Char, - String, - Float32, - Float64, - Int8, - Int16, - Int32, - Int64, - UInt8, - UInt16, - UInt32, - UInt64, -) - from formant.sdk.agent.v1 import Client from formant.protos.model.v1.datapoint_pb2 import Datapoint from formant.sdk.agent.v1.localization.types import ( @@ -46,14 +29,11 @@ message_to_json, get_message_path_value, ) +from .base_ingester import BaseIngester +from .types import STRING_TYPES, BOOL_TYPES, NUMERIC_TYPES -class Ingester: - def __init__(self, _fclient: Client): - self._fclient = _fclient - self.cv_bridge = CvBridge() - self._logger = get_logger() - +class Ingester(BaseIngester): def ingest( self, msg, @@ -63,151 +43,12 @@ def ingest( msg_timestamp: int, tags: Dict, ): + msg = self.prepare(msg, msg_type, formant_stream, topic, msg_timestamp, tags) - # Handle the message based on its type try: - if msg_type in [str, String, Char]: - if hasattr(msg, "data"): - msg = msg.data - - self._fclient.post_text( - formant_stream, - str(msg), - tags=tags, - timestamp=msg_timestamp, - ) - - elif msg_type in [Bool, bool]: - if hasattr(msg, "data"): - msg = msg.data - - self._fclient.post_bitset( - formant_stream, - {topic: msg}, - tags=tags, - timestamp=msg_timestamp, - ) - - elif msg_type in [ - int, - float, - Float32, - Float64, - Int8, - Int16, - Int32, - Int64, - UInt8, - UInt16, - UInt32, - UInt64, - ]: - if hasattr(msg, "data"): - msg = msg.data - - self._fclient.post_numeric( - formant_stream, - msg, - tags=tags, - timestamp=msg_timestamp, - ) - - elif msg_type == NavSatFix: - # Convert NavSatFix to a Formant location - self._fclient.post_geolocation( - stream=formant_stream, - latitude=msg.latitude, - longitude=msg.longitude, - altitude=msg.altitude, - tags=tags, - timestamp=msg_timestamp, - ) - - elif msg_type == Image: - # Convert Image to a Formant image - cv_image = self.cv_bridge.imgmsg_to_cv2(msg, "bgr8") - encoded_image = cv2.imencode(".jpg", cv_image)[1].tobytes() - - self._fclient.post_image( - stream=formant_stream, - value=encoded_image, - tags=tags, - timestamp=msg_timestamp, - ) - - elif msg_type == CompressedImage: - # Post the compressed image - if "jpg" in msg.format or "jpeg" in msg.format: - content_type = "image/jpg" - elif "png" in msg.format: - content_type = "image/png" - else: - self._logger.warn("Image format", msg.format, "not supported") - return - self._fclient.post_image( - formant_stream, - value=bytes(msg.data), - content_type=content_type, - tags=tags, - timestamp=msg_timestamp, - ) - - elif msg_type == BatteryState: - self._fclient.post_battery( - formant_stream, - msg.percentage, - voltage=msg.voltage, - current=msg.current, - charge=msg.charge, - tags=tags, - timestamp=msg_timestamp, - ) - - elif msg_type == LaserScan: - # Convert LaserScan to a Formant pointcloud - try: - self._fclient.agent_stub.PostData( - Datapoint( - stream=formant_stream, - point_cloud=FPointCloud.from_ros_laserscan(msg).to_proto(), - tags=tags, - timestamp=msg_timestamp, - ) - ) - except grpc.RpcError as e: - return - except Exception as e: - self._logger.error( - "Could not ingest " + formant_stream + ": " + str(e) - ) - return - - elif msg_type == PointCloud2: - try: - self._fclient.agent_stub.PostData( - Datapoint( - stream=formant_stream, - point_cloud=FPointCloud.from_ros(msg).to_proto(), - tags=tags, - timestamp=msg_timestamp, - ) - ) - except grpc.RpcError as e: - return - except Exception as e: - self._logger.error( - "Could not ingest " + formant_stream + ": " + str(e) - ) - return - - else: - # Ingest any messages without a direct mapping to a Formant type as JSON - self._fclient.post_json( - formant_stream, - message_to_json(msg), - tags=tags, - timestamp=msg_timestamp, - ) - - except AttributeError as e: + self._fclient.post_data(msg) + except grpc.RpcError as e: + return + except Exception as e: self._logger.error("Could not ingest " + formant_stream + ": " + str(e)) + return diff --git a/formant_ros2_adapter/scripts/components/subscriber/subscriber_coordinator.py b/formant_ros2_adapter/scripts/components/subscriber/subscriber_coordinator.py index c5bea3c..c783d90 100644 --- a/formant_ros2_adapter/scripts/components/subscriber/subscriber_coordinator.py +++ b/formant_ros2_adapter/scripts/components/subscriber/subscriber_coordinator.py @@ -5,6 +5,7 @@ from .basic_subscriber_coodinator import BasicSubscriberCoordinator from configuration.config_schema import ConfigSchema from .ingester import Ingester +from .batched_ingester import BatchIngester from .localization_subscriber_coodinator import LocalizationSubscriberCoordinator from .numeric_set_subscriber_coodinator import NumericSetSubscriberCoordinator from ros2_utils.topic_type_provider import TopicTypeProvider @@ -17,7 +18,7 @@ def __init__( ): self._logger = get_logger() self._fclient = fclient - self._ingester = Ingester(self._fclient) + self._ingester = self._choose_ingester() self._node = node self._topic_type_provider = topic_type_provider self._basic_subscriber_coodinator = BasicSubscriberCoordinator( @@ -36,3 +37,11 @@ def setup_with_config(self, config: ConfigSchema): self._localization_subscriber_coordinator.setup_with_config(config) self._numeric_set_subscriber_coodinator.setup_with_config(config) self._logger.info("Set up Subscriber Coordinator") + + def _choose_ingester(self): + has_batch_ingester = hasattr(self._fclient, "post_data_multi") and callable( + self._fclient.post_data_multi + ) + if has_batch_ingester: + return BatchIngester(self._fclient) + return Ingester(self._fclient) diff --git a/formant_ros2_adapter/scripts/components/subscriber/types.py b/formant_ros2_adapter/scripts/components/subscriber/types.py new file mode 100644 index 0000000..98b222d --- /dev/null +++ b/formant_ros2_adapter/scripts/components/subscriber/types.py @@ -0,0 +1,40 @@ +from std_msgs.msg import ( + Bool, + Char, + String, + Float32, + Float64, + Int8, + Int16, + Int32, + Int64, + UInt8, + UInt16, + UInt32, + UInt64, +) +from sensor_msgs.msg import ( + BatteryState, + LaserScan, + NavSatFix, + PointCloud2, +) + +STRING_TYPES = [str, String, Char] +BOOL_TYPES = [Bool, bool] +NUMERIC_TYPES = [ + int, + float, + Float32, + Float64, + Int8, + Int16, + Int32, + Int64, + UInt8, + UInt16, + UInt32, + UInt64, +] + +OTHER_DATA_TYPES = [NavSatFix, BatteryState, LaserScan, PointCloud2] diff --git a/formant_ros2_adapter/scripts/config.json b/formant_ros2_adapter/scripts/config.json index 9e26dfe..c5b9c3e 100644 --- a/formant_ros2_adapter/scripts/config.json +++ b/formant_ros2_adapter/scripts/config.json @@ -1 +1,31 @@ -{} \ No newline at end of file +{ + "ros2_adapter_configuration": { + "subscribers": [ + { + "ros2_topic": "/my_string", + "ros2_message_type": "std_msgs/msg/String", + "formant_stream": "my.string" + }, + { + "ros2_topic": "/my_velocity", + "ros2_message_type": "geometry_msgs/msg/Twist", + "formant_stream": "my.velocity.linear", + "ros2_message_paths": [ + { + "path": "linear" + } + ] + }, + { + "ros2_topic": "/my_velocity", + "formant_stream": "my.velocity.angular", + "ros2_message_type": "geometry_msgs/msg/Twist", + "ros2_message_paths": [ + { + "path": "angular" + } + ] + } + ] + } +} \ No newline at end of file diff --git a/formant_ros2_adapter/scripts/tests/test_subscribers_batched.py b/formant_ros2_adapter/scripts/tests/test_subscribers_batched.py new file mode 100644 index 0000000..346fc1b --- /dev/null +++ b/formant_ros2_adapter/scripts/tests/test_subscribers_batched.py @@ -0,0 +1,43 @@ +import unittest +from formant.sdk.agent.v1 import Client +from components.subscriber.batched_ingester import ( + BatchIngester, +) # Replace with the actual import +from queue import Empty +import time + + +class TestBatchIngester(unittest.TestCase): + def setUp(self): + self.fclient = Client() + self.ingester = BatchIngester(self.fclient, ingest_interval=1, num_threads=1) + + def test_message_ingest(self): + self.ingester.ingest("msg", str, "stream1", "topic", 12345, {}) + + # Checking if the message is added to the correct stream queue + self.assertEqual(self.ingester._stream_queues["stream1"].qsize(), 1) + + def test_queue_size_limit(self): + for i in range(15): # Adding 15 messages + self.ingester.ingest(f"msg{i}", str, "stream1", "topic", 12345, {}) + + # Ingest once + self.ingester._ingest_once() + + # Checking the remaining queue size (should be 15 - MAX_INGEST_SIZE) + self.assertEqual(self.ingester._stream_queues["stream1"].qsize(), 5) + + def test_queue_flush(self): + for i in range(5): # Adding 5 messages + self.ingester.ingest(f"msg{i}", str, "stream1", "topic", 12345, {}) + + # Wait for the ingest interval to pass (plus a small buffer) + time.sleep(1.2) + + # The queue should be empty after one ingest interval + self.assertRaises(Empty, self.ingester._stream_queues["stream1"].get_nowait) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_subscribers/test_subscribers_batched.sh b/tests/test_subscribers/test_subscribers_batched.sh new file mode 100755 index 0000000..57eb7b4 --- /dev/null +++ b/tests/test_subscribers/test_subscribers_batched.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +DEVICE="nicolas-container" +TOPIC="/my_string" +STREAM="my.string" + +source /opt/ros/*/setup.bash + +# How to get date formant required for fctl +start_time=$(date -u +"%Y-%m-%dT%H:%M:%S") + +# Publish multiple messages in quick succession to trigger batching +number_of_messages=20 +for i in $(seq 1 $number_of_messages); do + ros2 topic pub -t 2 -w 0 $TOPIC std_msgs/msg/String "data: {key: value_$i}" & + sleep 0.01 # Sleep for 10ms between messages +done + +wait # Wait for all background jobs to finish + +end_time=$(date -u +"%Y-%m-%dT%H:%M:%S") + +# Give some time for the BatchIngester to process and send the data +# This time should be greater than the ingest_interval in BatchIngester +sleep 1 \ No newline at end of file