Skip to content

Commit b4a1660

Browse files
committed
implement __set__ for ContinuousObservable to properly initialize state
The ContinuousObservable descriptor was inheriting Observable's __set__ method, which stored raw numeric values instead of ContinuousState objects. This caused AttributeError when __get__ tried to access state.last_update on a float/numpy.float64. The new __set__ method creates a ContinuousState wrapper on first assignment and updates the existing state on subsequent assignments. This ensures the private attribute always contains a properly structured state object with value, last_update, rate_func, and thresholds attributes. Fixes initialization of continuous observables in agent __init__ methods where energy and other time-varying properties are set.
1 parent 7eefaf3 commit b4a1660

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

mesa/experimental/mesa_signals/mesa_signal.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,40 @@ def __init__(self, initial_value: float, rate_func: Callable):
489489
self._rate_func = rate_func
490490
self._thresholds = {} # threshold_value -> callback
491491

492+
def __set__(self, instance: HasObservables, value):
493+
"""Set the value, ensuring we store a ContinuousState."""
494+
# Get or create state
495+
state = getattr(instance, self.private_name, None)
496+
497+
if state is None:
498+
# First time - create ContinuousState
499+
state = ContinuousState(
500+
value=float(value),
501+
last_update=self._get_time(instance),
502+
rate_func=self._rate_func,
503+
thresholds=self._thresholds,
504+
)
505+
setattr(instance, self.private_name, state)
506+
else:
507+
# Update existing - just change the value and reset timestamp
508+
old_value = state.value
509+
state.value = float(value)
510+
state.last_update = self._get_time(instance)
511+
512+
# Notify changes
513+
instance.notify(self.public_name, old_value, state.value, "change")
514+
515+
# Check thresholds
516+
for threshold, direction in state.check_thresholds(old_value, state.value):
517+
instance.notify(
518+
self.public_name,
519+
old_value,
520+
state.value,
521+
"threshold_crossed",
522+
threshold=threshold,
523+
direction=direction,
524+
)
525+
492526
def __get__(self, instance: HasObservables, owner):
493527
"""Lazy evaluation - compute current value based on elapsed time."""
494528
if instance is None:

0 commit comments

Comments
 (0)