Skip to content

Commit a68bd90

Browse files
made scheduling + comp runtime topic refactoring changes
1 parent d339394 commit a68bd90

File tree

2 files changed

+69
-14
lines changed

2 files changed

+69
-14
lines changed

formant_ros2_adapter/scripts/components/subscriber/basic_subscriber_coodinator.py

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from configuration.config_schema import ConfigSchema
1212
from configuration.subscriber_config import SubscriberConfig, MessagePathConfig
1313
from .ingester import Ingester
14+
from .decorators import lock_decorator, handle_value_error
1415
from ros2_utils.qos import QOS_PROFILES, qos_profile_system_default
1516
from ros2_utils.topic_type_provider import TopicTypeProvider
1617
from utils.logger import get_logger
@@ -23,6 +24,7 @@
2324
FORMANT_OVERRIDE_TIMESTAMP = (
2425
os.getenv("FORMANT_OVERRIDE_TIMESTAMP", "").lower() == "true"
2526
)
27+
SCHEDULE_SUBSCRIPTIONS_INTERVAL = 1.0
2628

2729

2830
class BasicSubscriberCoordinator:
@@ -38,22 +40,44 @@ def __init__(
3840
self._ingester = ingester
3941
self._topic_type_provider = topic_type_provider
4042
self._subscriptions: Dict[str, List[Subscription]] = {}
43+
self._queued_topics: Dict = {}
4144
self._logger = get_logger()
4245
self._config_lock = threading.RLock()
46+
self._subscribe_lock = threading.Lock()
4347

48+
@lock_decorator("_config_lock")
4449
def setup_with_config(self, config: ConfigSchema):
45-
with self._config_lock:
46-
self._config = config
47-
self._cleanup()
48-
if self._config.subscribers:
49-
for subscriber_config in self._config.subscribers:
50-
try:
51-
self._setup_subscription_for_config(subscriber_config)
52-
except ValueError as value_error:
53-
self._logger.warn(value_error)
54-
continue
55-
56-
def _setup_subscription_for_config(self, subscriber_config: SubscriberConfig):
50+
self._config = config
51+
self._cleanup()
52+
53+
has_subscribers = self._config.subscribers
54+
if not has_subscribers:
55+
return
56+
57+
self._queued_topics = {
58+
subscriber.topic: subscriber for subscriber in self._config.subscribers
59+
}
60+
61+
self._schedule_subscriptions()
62+
63+
def _schedule_subscriptions(self):
64+
t = threading.Timer(SCHEDULE_SUBSCRIPTIONS_INTERVAL, self._setup_subscribers)
65+
t.daemon = True
66+
t.start()
67+
68+
@lock_decorator("_subscriber_lock")
69+
def _setup_subscribers(self):
70+
active_topics = set(self._get_active_topics())
71+
72+
backlog_topics = active_topics.intersection(self._queued_topics.keys)
73+
74+
for topic in backlog_topics:
75+
topic_subscriber = self._queued_topics[topic]
76+
self._setup_subscription(topic_subscriber)
77+
_ = self._queued_topics.pop(topic)
78+
79+
@handle_value_error
80+
def _setup_subscription(self, subscriber_config: SubscriberConfig):
5781
topic = subscriber_config.topic
5882
qos_profile = QOS_PROFILES.get(
5983
subscriber_config.qos_profile, qos_profile_system_default
@@ -62,8 +86,6 @@ def _setup_subscription_for_config(self, subscriber_config: SubscriberConfig):
6286
ros2_type = subscriber_config.message_type
6387
if ros2_type is None:
6488
ros2_type = self._topic_type_provider.get_type_for_topic(topic)
65-
if ros2_type is None:
66-
raise ValueError("No ROS2 type found for %s" % topic)
6789

6890
self._logger.debug(
6991
"Setting up subscription %s, %s, %s"
@@ -82,6 +104,13 @@ def _setup_subscription_for_config(self, subscriber_config: SubscriberConfig):
82104
self._subscriptions[topic] = []
83105
self._subscriptions[topic].append(new_subscriber)
84106

107+
def _get_active_topics(self):
108+
return [
109+
topic_name
110+
for topic_name, _ in self._node.get_topic_names_and_types()
111+
if self._node.get_publishers_info_by_topic(topic_name)
112+
]
113+
85114
def _handle_message(
86115
self,
87116
msg,
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import functools
2+
import threading
3+
4+
5+
def lock_decorator(lock_var):
6+
def decorator(func):
7+
@functools.wraps(func)
8+
def wrapped(self, *args, **kwargs):
9+
lock = getattr(self, lock_var)
10+
with lock:
11+
return func(self, *args, **kwargs)
12+
13+
return wrapped
14+
15+
return decorator
16+
17+
18+
def handle_value_error(func):
19+
def wrapper(*args, **kwargs):
20+
try:
21+
return func(*args, **kwargs)
22+
except ValueError as value_error:
23+
args[0]._logger.warn(value_error)
24+
return None
25+
26+
return wrapper

0 commit comments

Comments
 (0)