Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from configuration.config_schema import ConfigSchema
from configuration.subscriber_config import SubscriberConfig, MessagePathConfig
from .ingester import Ingester
from .decorators import lock_decorator, handle_value_error
from ros2_utils.qos import QOS_PROFILES, qos_profile_system_default
from ros2_utils.topic_type_provider import TopicTypeProvider
from utils.logger import get_logger
Expand All @@ -23,6 +24,7 @@
FORMANT_OVERRIDE_TIMESTAMP = (
os.getenv("FORMANT_OVERRIDE_TIMESTAMP", "").lower() == "true"
)
SCHEDULE_SUBSCRIPTIONS_INTERVAL = 1.0


class BasicSubscriberCoordinator:
Expand All @@ -38,22 +40,44 @@ def __init__(
self._ingester = ingester
self._topic_type_provider = topic_type_provider
self._subscriptions: Dict[str, List[Subscription]] = {}
self._queued_topics: Dict = {}
self._logger = get_logger()
self._config_lock = threading.RLock()
self._subscribe_lock = threading.Lock()

@lock_decorator("_config_lock")
def setup_with_config(self, config: ConfigSchema):
with self._config_lock:
self._config = config
self._cleanup()
if self._config.subscribers:
for subscriber_config in self._config.subscribers:
try:
self._setup_subscription_for_config(subscriber_config)
except ValueError as value_error:
self._logger.warn(value_error)
continue

def _setup_subscription_for_config(self, subscriber_config: SubscriberConfig):
self._config = config
self._cleanup()

has_subscribers = self._config.subscribers
if not has_subscribers:
return

self._queued_topics = {
subscriber.topic: subscriber for subscriber in self._config.subscribers
}

self._schedule_subscriptions()

def _schedule_subscriptions(self):
t = threading.Timer(SCHEDULE_SUBSCRIPTIONS_INTERVAL, self._setup_subscribers)
t.daemon = True
t.start()

@lock_decorator("_subscriber_lock")
def _setup_subscribers(self):
active_topics = set(self._get_active_topics())

backlog_topics = active_topics.intersection(self._queued_topics.keys)

for topic in backlog_topics:
topic_subscriber = self._queued_topics[topic]
self._setup_subscription(topic_subscriber)
_ = self._queued_topics.pop(topic)

@handle_value_error
def _setup_subscription(self, subscriber_config: SubscriberConfig):
topic = subscriber_config.topic
qos_profile = QOS_PROFILES.get(
subscriber_config.qos_profile, qos_profile_system_default
Expand All @@ -62,8 +86,6 @@ def _setup_subscription_for_config(self, subscriber_config: SubscriberConfig):
ros2_type = subscriber_config.message_type
if ros2_type is None:
ros2_type = self._topic_type_provider.get_type_for_topic(topic)
if ros2_type is None:
raise ValueError("No ROS2 type found for %s" % topic)

self._logger.debug(
"Setting up subscription %s, %s, %s"
Expand All @@ -82,6 +104,13 @@ def _setup_subscription_for_config(self, subscriber_config: SubscriberConfig):
self._subscriptions[topic] = []
self._subscriptions[topic].append(new_subscriber)

def _get_active_topics(self):
return [
topic_name
for topic_name, _ in self._node.get_topic_names_and_types()
if self._node.get_publishers_info_by_topic(topic_name)
]

def _handle_message(
self,
msg,
Expand Down
26 changes: 26 additions & 0 deletions formant_ros2_adapter/scripts/components/subscriber/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import functools
import threading


def lock_decorator(lock_var):
def decorator(func):
@functools.wraps(func)
def wrapped(self, *args, **kwargs):
lock = getattr(self, lock_var)
with lock:
return func(self, *args, **kwargs)

return wrapped

return decorator


def handle_value_error(func):
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except ValueError as value_error:
args[0]._logger.warn(value_error)
return None

return wrapper