Skip to content

Commit f635d90

Browse files
committed
fix: Make ContinuousObservable thresholds instance-specific
BREAKING: Thresholds are now stored per-instance instead of per-class. Previously, ContinuousObservable stored thresholds at the class level in a shared `_thresholds` set. This caused all instances to check the same threshold values, preventing agents from having individual thresholds based on their specific parameters (e.g., a wolf with 100 starting energy couldn't have a different critical threshold than one with 50 starting energy). Changes: - Move threshold storage from descriptor to ContinuousState instances - Change threshold structure from set to dict: {threshold_value: set(callbacks)} - Update add_threshold() to access and modify instance-level state - Ensure threshold initialization happens before adding thresholds - Add direction parameter to all threshold_crossed signal emissions This allows each agent to maintain its own set of thresholds, enabling instance-specific reactive behaviors like `self.add_threshold("energy", self.starting_energy * 0.25, callback)`.
1 parent 81f52ad commit f635d90

File tree

1 file changed

+29
-19
lines changed

1 file changed

+29
-19
lines changed

mesa/experimental/mesa_signals/mesa_signal.py

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -475,22 +475,35 @@ def _mesa_notify(self, signal: AttributeDict):
475475
def add_threshold(self, observable_name: str, threshold: float, callback: Callable):
476476
"""Convenience method for adding thresholds."""
477477
obs = getattr(type(self), observable_name)
478-
if isinstance(obs, ContinuousObservable):
479-
obs._thresholds.add(threshold)
478+
if not isinstance(obs, ContinuousObservable):
479+
raise ValueError(f"{observable_name} is not a ContinuousObservable")
480480

481-
# Check if callback is already subscribed
482-
existing_subscribers = self.subscribers.get(observable_name, {}).get(
483-
"threshold_crossed", []
484-
)
485-
already_subscribed = any(
486-
ref() == callback for ref in existing_subscribers if ref() is not None
487-
)
481+
# Get the instance's ContinuousState
482+
state = getattr(self, obs.private_name, None)
483+
if state is None:
484+
# State not yet created - will be created on first access/set
485+
# We need to ensure the observable is initialized first
486+
_ = getattr(self, observable_name) # Trigger initialization
487+
state = getattr(self, obs.private_name)
488+
489+
# Add threshold to the instance's state
490+
if threshold not in state.thresholds:
491+
state.thresholds[threshold] = set()
492+
493+
# Add callback to this threshold's callback set
494+
state.thresholds[threshold].add(callback)
495+
496+
# Subscribe to the threshold_crossed signal
497+
# Check if callback is already subscribed to avoid duplicates
498+
existing_subscribers = self.subscribers.get(observable_name, {}).get(
499+
"threshold_crossed", []
500+
)
501+
already_subscribed = any(
502+
ref() == callback for ref in existing_subscribers if ref() is not None
503+
)
488504

489-
# Only subscribe if not already subscribed
490-
if not already_subscribed:
491-
self.observe(observable_name, "threshold_crossed", callback)
492-
else:
493-
raise ValueError(f"{observable_name} is not a ContinuousObservable")
505+
if not already_subscribed:
506+
self.observe(observable_name, "threshold_crossed", callback)
494507

495508

496509
class ContinuousObservable(Observable):
@@ -501,7 +514,6 @@ def __init__(self, initial_value: float, rate_func: Callable):
501514
super().__init__(fallback_value=initial_value)
502515
self.signal_types.add("threshold_crossed")
503516
self._rate_func = rate_func
504-
self._thresholds = set()
505517

506518
def __set__(self, instance: HasObservables, value):
507519
"""Set the value, ensuring we store a ContinuousState."""
@@ -514,7 +526,6 @@ def __set__(self, instance: HasObservables, value):
514526
value=float(value),
515527
last_update=self._get_time(instance),
516528
rate_func=self._rate_func,
517-
thresholds=self._thresholds,
518529
)
519530
setattr(instance, self.private_name, state)
520531
else:
@@ -552,7 +563,6 @@ def __get__(self, instance: HasObservables, owner):
552563
value=self.fallback_value,
553564
last_update=current_time,
554565
rate_func=self._rate_func,
555-
thresholds=self._thresholds,
556566
)
557567
setattr(instance, self.private_name, state)
558568

@@ -615,12 +625,12 @@ class ContinuousState:
615625
__slots__ = ["last_update", "rate_func", "thresholds", "value"]
616626

617627
def __init__(
618-
self, value: float, last_update: float, rate_func: Callable, thresholds: dict
628+
self, value: float, last_update: float, rate_func: Callable
619629
):
620630
self.value = value
621631
self.last_update = last_update
622632
self.rate_func = rate_func
623-
self.thresholds = thresholds
633+
self.thresholds = {} # {threshold_value: set(callbacks)}
624634

625635
def calculate(self, elapsed: float, instance: Any) -> float:
626636
"""Calculate new value based on elapsed time.

0 commit comments

Comments
 (0)