Skip to content

Commit f298c1a

Browse files
committed
Add ModelSystemListeners and EventHandlers
1 parent 0c909e2 commit f298c1a

File tree

1 file changed

+205
-0
lines changed

1 file changed

+205
-0
lines changed
Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
from abc import ABC
2+
from typing import TYPE_CHECKING, Dict, Union
3+
from sys import gettrace
4+
from utils import log
5+
import numpy as np
6+
from pathlib import Path
7+
import importlib.util
8+
9+
from demand.trips import DemandModel
10+
if TYPE_CHECKING:
11+
from datahandling.zonedata import ZoneData
12+
from modelsystem import ModelSystem
13+
from datatypes.purpose import TourPurpose
14+
from datatypes.demand import Demand
15+
from assignment.departure_time import DepartureTimeModel
16+
from assignment.abstract_assignment import Period
17+
import pandas as pd
18+
19+
class ModelSystemEventListener(ABC):
20+
21+
def __init__(self):
22+
pass
23+
24+
def on_zone_data_loaded(self, base_data: 'ZoneData', forecast_data: 'ZoneData') -> None:
25+
"""
26+
Event handler that is called when zone data is loaded.
27+
28+
Args:
29+
data (ZoneData): The loaded zone data.
30+
"""
31+
pass
32+
33+
def on_model_system_initialized(self, model_system: 'ModelSystem') -> None:
34+
"""
35+
Event handler that is called when the model system is initialized.
36+
37+
Args:
38+
model_system (ModelSystem): The model system.
39+
"""
40+
pass
41+
42+
def on_iteration_started(self, iteration: Union[int, str], previous_impedance: Dict[str, Dict[str, np.ndarray]]) -> None:
43+
"""
44+
Event handler that is called when an iteration is started.
45+
46+
Args:
47+
iteration (int | str): The iteration number.
48+
"""
49+
pass
50+
51+
def on_car_density_updated(self, iteration: Union[int, str], prediction: 'pd.Series' ) -> None:
52+
"""
53+
Event handler that is called when car density is updated.
54+
55+
Args:
56+
iteration (int | str): The iteration number.
57+
prediction (pandas.Series): The updated car density prediction.
58+
"""
59+
pass
60+
61+
def on_base_demand_assigned(self, impedance: Dict[str, Dict[str, np.ndarray]]) -> None:
62+
"""
63+
Event handler that is called when base demand has been assigned.
64+
65+
Args:
66+
impedance (dict): The impedance matrices.
67+
"""
68+
pass
69+
70+
def on_population_segments_created(self, dm: DemandModel) -> None:
71+
"""
72+
Event handler that is called when population segments have been created.
73+
74+
Args:
75+
dm (DemandModel): The demand model.
76+
"""
77+
pass
78+
79+
def on_demand_model_tours_generated(self, dm: 'DemandModel') -> None:
80+
"""
81+
Event handler that is called when demand model tours have been generated.
82+
83+
Args:
84+
dm (DemandModel): The demand model.
85+
"""
86+
pass
87+
88+
def on_purpose_demand_calculated(self, purpose: 'TourPurpose', demand: Dict[str, 'Demand']) -> None:
89+
"""
90+
Event handler that is called when purpose demand has been calculated.
91+
92+
Args:
93+
dm (DemandModel): The demand model.
94+
"""
95+
pass
96+
97+
def on_internal_demand_added(self, dtm: 'DepartureTimeModel') -> None:
98+
"""
99+
Event handler that is called when internal demand has been calculated.
100+
101+
Args:
102+
dtm (DepartureTimeModel): The departure time model.
103+
"""
104+
pass
105+
106+
def on_external_demand_calculated(self, demand: Dict[str, 'Demand']) -> None:
107+
"""
108+
Event handler that is called when external demand has been calculated.
109+
110+
Args:
111+
dtm (DepartureTimeModel): The departure time model.
112+
"""
113+
pass
114+
115+
def on_demand_calculated(self, iteration: Union[int, str], dtm: 'DepartureTimeModel') -> None:
116+
"""
117+
Event handler that is called when all demands has been added to the DTM.
118+
119+
Args:
120+
iteration (int | str): The iteration number.
121+
dtm (DepartureTimeModel): The departure time model.
122+
"""
123+
pass
124+
125+
def on_time_period_assigned(self, iteration: Union[int, str], ap: 'Period', impedance: Dict[str, Dict[str, np.ndarray]]) -> None:
126+
"""
127+
Event handler that is called when time period has been assigned.
128+
129+
Args:
130+
iteration (int | str): The iteration number.
131+
ap (Period): The assignment period.
132+
impedance (dict): The impedance matrices.
133+
"""
134+
pass
135+
136+
def on_iteration_complete(self, iteration: Union[int, str], impedance: Dict[str, Dict[str, np.ndarray]], gap: Dict[str, float]) -> None:
137+
"""
138+
Event handler that is called when an iteration is complete.
139+
140+
Args:
141+
iteration (int | str): The iteration number.
142+
"""
143+
pass
144+
145+
def on_emme_assignment_complete(self) -> None:
146+
pass
147+
148+
149+
class EventHandler(ModelSystemEventListener):
150+
"""Event handler that calls all equivalent methods in all other ModelSystemEventListener classes."""
151+
def __init__(self):
152+
"""Initialize the EventHandler.
153+
154+
Args:
155+
model_system (ModelSystem): ModelSystem instance.
156+
"""
157+
super().__init__()
158+
self.listeners = []
159+
self._create_methods()
160+
161+
def register_listener(self, listener: ModelSystemEventListener):
162+
self.listeners.append(listener)
163+
164+
def load_listeners(self, listener_path: Path):
165+
"""Load all listeners from a given path.
166+
167+
Args:
168+
listener_path (str): The path to the listeners.
169+
"""
170+
for file_path in listener_path.glob("*.py"):
171+
if file_path.name != "__init__.py":
172+
module_name = file_path.stem
173+
spec = importlib.util.spec_from_file_location(module_name, file_path)
174+
module = importlib.util.module_from_spec(spec)
175+
spec.loader.exec_module(module)
176+
for attr_name in dir(module):
177+
attr = getattr(module, attr_name)
178+
if isinstance(attr, type) and issubclass(attr, ModelSystemEventListener) and attr is not ModelSystemEventListener:
179+
self.register_listener(attr())
180+
log.info(f"Loaded listener {attr.__name__} from {file_path}")
181+
182+
183+
def _create_methods(self):
184+
"""Create methods that call all equivalent methods in all other ModelSystemEventListener classes.
185+
Methods area automatically created for all methods that start with "on_" in all ModelSystemEventListener classes.
186+
"""
187+
for method_name in dir(ModelSystemEventListener):
188+
if method_name.startswith("on_") and callable(getattr(ModelSystemEventListener, method_name)):
189+
setattr(self, method_name, self._create_method(method_name))
190+
191+
def _create_method(self, method_name):
192+
"""Create a method that calls all equivalent methods in all other ModelSystemEventListener classes.
193+
194+
Args:
195+
method_name (str): name of the method to create.
196+
"""
197+
def method(*args, **kwargs):
198+
for listener in self.listeners:
199+
try:
200+
getattr(listener, method_name)(*args, **kwargs)
201+
except Exception as e:
202+
if gettrace() is not None:
203+
raise e
204+
log.error(f"Error in {listener.__class__.__name__}.{method_name}: {e}")
205+
return method

0 commit comments

Comments
 (0)