Skip to content

Commit 2b059ff

Browse files
committed
Fix threshold subscription to prevent duplicate callbacks
Changes the ContinuousObservable threshold system from storing callbacks directly to using a signal-based subscription pattern: - Change `_thresholds` from dict to set (stores threshold values only) - Update `add_threshold()` to check for existing subscriptions before subscribing to prevent duplicate callback invocations - Separate concerns: thresholds define WHICH values to watch, observers define WHO to notify when crossed - Add ValueError when attempting to add threshold to non-ContinuousObservable This fixes an issue where registering multiple thresholds with the same callback would cause it to be called multiple times per threshold crossing.
1 parent b4a1660 commit 2b059ff

File tree

1 file changed

+21
-7
lines changed

1 file changed

+21
-7
lines changed

mesa/experimental/mesa_signals/mesa_signal.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,21 @@ def add_threshold(self, observable_name: str, threshold: float, callback: Callab
476476
"""Convenience method for adding thresholds."""
477477
obs = getattr(type(self), observable_name)
478478
if isinstance(obs, ContinuousObservable):
479-
obs._thresholds[threshold] = callback
479+
obs._thresholds.add(threshold)
480+
481+
# Check if callback is already subscribed
482+
existing_subscribers = self.subscribers.get(observable_name, {}).get(
483+
"threshold_crossed", []
484+
)
485+
already_subscribed = any(
486+
ref() == callback for ref in existing_subscribers if ref() is not None
487+
)
488+
489+
# Only subscribe if not already subscribed
490+
if not already_subscribed:
491+
self.observe(observable_name, "threshold_crossed", callback)
492+
else:
493+
raise ValueError(f"{observable_name} is not a ContinuousObservable")
480494

481495

482496
class ContinuousObservable(Observable):
@@ -487,7 +501,7 @@ def __init__(self, initial_value: float, rate_func: Callable):
487501
super().__init__(fallback_value=initial_value)
488502
self.signal_types.add("threshold_crossed")
489503
self._rate_func = rate_func
490-
self._thresholds = {} # threshold_value -> callback
504+
self._thresholds = set()
491505

492506
def __set__(self, instance: HasObservables, value):
493507
"""Set the value, ensuring we store a ContinuousState."""
@@ -624,13 +638,13 @@ def check_thresholds(self, old_value: float, new_value: float) -> list:
624638
List of (threshold_value, direction) tuples for crossed thresholds
625639
"""
626640
crossed = []
627-
for threshold_value in self.thresholds:
641+
for threshold in self.thresholds:
628642
# Crossed upward
629-
if old_value < threshold_value <= new_value:
630-
crossed.append((threshold_value, "up"))
643+
if old_value < threshold <= new_value:
644+
crossed.append((threshold, "up"))
631645
# Crossed downward
632-
elif new_value <= threshold_value < old_value:
633-
crossed.append((threshold_value, "down"))
646+
elif new_value <= threshold < old_value:
647+
crossed.append((threshold, "down"))
634648
return crossed
635649

636650

0 commit comments

Comments
 (0)