Skip to content

Commit 7eefaf3

Browse files
committed
Add time-source fallback mechanism for ContinuousObservable
Implement flexible time retrieval in ContinuousObservable to work with different Mesa time management approaches. The _get_time() method tries multiple sources in priority order: 1. model.simulator.time (DEVS/continuous models) 2. model.time (if explicitly set) 3. model.steps (fallback for discrete models) This workaround enables ContinuousObservable to function correctly regardless of whether a model uses discrete event simulation, custom time tracking, or standard step-based progression. Obviously this should be fixed structurally. See #2228
1 parent 5597615 commit 7eefaf3

File tree

1 file changed

+35
-8
lines changed

1 file changed

+35
-8
lines changed

mesa/experimental/mesa_signals/mesa_signal.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -498,16 +498,18 @@ def __get__(self, instance: HasObservables, owner):
498498
state = getattr(instance, self.private_name, None)
499499
if state is None:
500500
# First access - initialize
501+
# Use simulator time if available, otherwise fall back to steps
502+
current_time = self._get_time(instance)
501503
state = ContinuousState(
502504
value=self.fallback_value,
503-
last_update=instance.model.time,
505+
last_update=current_time,
504506
rate_func=self._rate_func,
505507
thresholds=self._thresholds,
506508
)
507509
setattr(instance, self.private_name, state)
508510

509511
# Calculate new value based on time
510-
current_time = instance.model.time
512+
current_time = self._get_time(instance)
511513
elapsed = current_time - state.last_update
512514

513515
if elapsed > 0:
@@ -542,26 +544,51 @@ def __get__(self, instance: HasObservables, owner):
542544

543545
return state.value
544546

547+
# TODO: A universal truth for time should be implemented structurally in Mesa. See https://github.com/projectmesa/mesa/discussions/2228
548+
def _get_time(self, instance):
549+
"""Get current time from model, trying multiple sources."""
550+
model = instance.model
551+
552+
# Try simulator time first (for DEVS models)
553+
if hasattr(model, "simulator") and hasattr(model.simulator, "time"):
554+
return model.simulator.time
555+
556+
# Fall back to model.time if it exists
557+
if hasattr(model, "time"):
558+
return model.time
559+
560+
# Last resort: use steps as a proxy for time
561+
return float(model.steps)
562+
545563

546564
class ContinuousState:
547565
"""Internal state tracker for continuous observables."""
548566

549567
__slots__ = ["last_update", "rate_func", "thresholds", "value"]
550568

551-
def __init__(self, value, last_update, rate_func, thresholds):
569+
def __init__(
570+
self, value: float, last_update: float, rate_func: Callable, thresholds: dict
571+
):
552572
self.value = value
553573
self.last_update = last_update
554574
self.rate_func = rate_func
555575
self.thresholds = thresholds
556576

557-
def calculate(self, elapsed: float, instance) -> float:
558-
"""Calculate new value based on elapsed time."""
577+
def calculate(self, elapsed: float, instance: Any) -> float:
578+
"""Calculate new value based on elapsed time.
579+
580+
Uses simple linear integration for now. Could be extended
581+
to support more complex integration methods.
582+
"""
559583
rate = self.rate_func(self.value, elapsed, instance)
560-
# Simple linear integration for now
561584
return self.value + (rate * elapsed)
562585

563-
def check_thresholds(self, old_value, new_value):
564-
"""Check if any thresholds were crossed."""
586+
def check_thresholds(self, old_value: float, new_value: float) -> list:
587+
"""Check if any thresholds were crossed.
588+
589+
Returns:
590+
List of (threshold_value, direction) tuples for crossed thresholds
591+
"""
565592
crossed = []
566593
for threshold_value in self.thresholds:
567594
# Crossed upward

0 commit comments

Comments
 (0)