1010import logging
1111from abc import ABC , abstractmethod
1212from collections .abc import Mapping
13+ from contextlib import closing
1314from dataclasses import dataclass , field
1415from datetime import datetime , timedelta , timezone
1516from heapq import heappop , heappush
@@ -113,13 +114,6 @@ def __init__(
113114 )
114115
115116 self ._running_state_status_tx = self ._running_state_status_channel .new_sender ()
116- self ._next_event_timer = Timer (
117- timedelta (seconds = 100 ), SkipMissedAndResync (), auto_start = False
118- )
119- """The timer to schedule the next event.
120-
121- Interval is chosen arbitrarily, as it will be reset on the first event.
122- """
123117
124118 self ._scheduled_events : list ["DispatchScheduler.QueueItem" ] = []
125119 """The scheduled events, sorted by time.
@@ -188,7 +182,7 @@ async def new_running_state_event_receiver(
188182 Raises:
189183 RuntimeError: If the dispatch service is not running.
190184 """
191- if not self ._tasks :
185+ if not self .is_running :
192186 raise RuntimeError ("Dispatch service not started" )
193187
194188 # Find all matching dispatches based on the type and collect them
@@ -230,44 +224,59 @@ async def _run(self) -> None:
230224 self ._microgrid_id ,
231225 )
232226
233- # Initial fetch
234- await self ._fetch ()
235-
236- stream = self ._client .stream (microgrid_id = self ._microgrid_id )
237-
238227 # Streaming updates
239- async for selected in select (self ._next_event_timer , stream ):
240- if selected_from (selected , self ._next_event_timer ):
241- if not self ._scheduled_events :
242- continue
243- await self ._execute_scheduled_event (
244- heappop (self ._scheduled_events ).dispatch
245- )
246- elif selected_from (selected , stream ):
247- _logger .debug ("Received dispatch event: %s" , selected .message )
248- dispatch = Dispatch (selected .message .dispatch )
249- match selected .message .event :
250- case Event .CREATED :
251- self ._dispatches [dispatch .id ] = dispatch
252- await self ._update_dispatch_schedule_and_notify (dispatch , None )
253- await self ._lifecycle_events_tx .send (Created (dispatch = dispatch ))
254- case Event .UPDATED :
255- await self ._update_dispatch_schedule_and_notify (
256- dispatch , self ._dispatches [dispatch .id ]
257- )
258- self ._dispatches [dispatch .id ] = dispatch
259- await self ._lifecycle_events_tx .send (Updated (dispatch = dispatch ))
260- case Event .DELETED :
261- self ._dispatches .pop (dispatch .id )
262- await self ._update_dispatch_schedule_and_notify (None , dispatch )
263-
264- await self ._lifecycle_events_tx .send (Deleted (dispatch = dispatch ))
265-
266- async def _execute_scheduled_event (self , dispatch : Dispatch ) -> None :
228+ with closing (
229+ Timer (timedelta (seconds = 100 ), SkipMissedAndResync (), auto_start = False )
230+ ) as next_event_timer :
231+ # Initial fetch
232+ await self ._fetch (next_event_timer )
233+ stream = self ._client .stream (microgrid_id = self ._microgrid_id )
234+
235+ async for selected in select (next_event_timer , stream ):
236+ if selected_from (selected , next_event_timer ):
237+ if not self ._scheduled_events :
238+ continue
239+ await self ._execute_scheduled_event (
240+ heappop (self ._scheduled_events ).dispatch , next_event_timer
241+ )
242+ elif selected_from (selected , stream ):
243+ _logger .debug ("Received dispatch event: %s" , selected .message )
244+ dispatch = Dispatch (selected .message .dispatch )
245+ match selected .message .event :
246+ case Event .CREATED :
247+ self ._dispatches [dispatch .id ] = dispatch
248+ await self ._update_dispatch_schedule_and_notify (
249+ dispatch , None , next_event_timer
250+ )
251+ await self ._lifecycle_events_tx .send (
252+ Created (dispatch = dispatch )
253+ )
254+ case Event .UPDATED :
255+ await self ._update_dispatch_schedule_and_notify (
256+ dispatch ,
257+ self ._dispatches [dispatch .id ],
258+ next_event_timer ,
259+ )
260+ self ._dispatches [dispatch .id ] = dispatch
261+ await self ._lifecycle_events_tx .send (
262+ Updated (dispatch = dispatch )
263+ )
264+ case Event .DELETED :
265+ self ._dispatches .pop (dispatch .id )
266+ await self ._update_dispatch_schedule_and_notify (
267+ None , dispatch , next_event_timer
268+ )
269+
270+ await self ._lifecycle_events_tx .send (
271+ Deleted (dispatch = dispatch )
272+ )
273+
274+ async def _execute_scheduled_event (self , dispatch : Dispatch , timer : Timer ) -> None :
267275 """Execute a scheduled event.
268276
269277 Args:
270278 dispatch: The dispatch to execute.
279+ timer: The timer to use for scheduling the next event.
271280 """
272281 _logger .debug ("Executing scheduled event: %s (%s)" , dispatch , dispatch .started )
273282 await self ._send_running_state_change (dispatch )
@@ -282,9 +291,9 @@ async def _execute_scheduled_event(self, dispatch: Dispatch) -> None:
282291 else :
283292 self ._schedule_start (dispatch )
284293
285- self ._update_timer ()
294+ self ._update_timer (timer )
286295
287- async def _fetch (self ) -> None :
296+ async def _fetch (self , timer : Timer ) -> None :
288297 """Fetch all relevant dispatches using list.
289298
290299 This is used for the initial fetch and for re-fetching all dispatches
@@ -305,12 +314,14 @@ async def _fetch(self) -> None:
305314 old_dispatch = old_dispatches .pop (dispatch .id , None )
306315 if not old_dispatch :
307316 _logger .debug ("New dispatch: %s" , dispatch )
308- await self ._update_dispatch_schedule_and_notify (dispatch , None )
317+ await self ._update_dispatch_schedule_and_notify (
318+ dispatch , None , timer
319+ )
309320 await self ._lifecycle_events_tx .send (Created (dispatch = dispatch ))
310321 elif dispatch .update_time != old_dispatch .update_time :
311322 _logger .debug ("Updated dispatch: %s" , dispatch )
312323 await self ._update_dispatch_schedule_and_notify (
313- dispatch , old_dispatch
324+ dispatch , old_dispatch , timer
314325 )
315326 await self ._lifecycle_events_tx .send (Updated (dispatch = dispatch ))
316327
@@ -324,7 +335,7 @@ async def _fetch(self) -> None:
324335 for dispatch in old_dispatches .values ():
325336 _logger .debug ("Deleted dispatch: %s" , dispatch )
326337 await self ._lifecycle_events_tx .send (Deleted (dispatch = dispatch ))
327- await self ._update_dispatch_schedule_and_notify (None , dispatch )
338+ await self ._update_dispatch_schedule_and_notify (None , dispatch , timer )
328339
329340 # Set deleted only here as it influences the result of dispatch.started
330341 # which is used in above in _running_state_change
@@ -334,7 +345,7 @@ async def _fetch(self) -> None:
334345 self ._initial_fetch_event .set ()
335346
336347 async def _update_dispatch_schedule_and_notify (
337- self , dispatch : Dispatch | None , old_dispatch : Dispatch | None
348+ self , dispatch : Dispatch | None , old_dispatch : Dispatch | None , timer : Timer
338349 ) -> None :
339350 """Update the schedule for a dispatch.
340351
@@ -350,6 +361,7 @@ async def _update_dispatch_schedule_and_notify(
350361 Args:
351362 dispatch: The dispatch to update the schedule for.
352363 old_dispatch: The old dispatch, if available.
364+ timer: The timer to use for scheduling the next event.
353365 """
354366 # If dispatch is None, the dispatch was deleted
355367 # and we need to cancel any existing event for it
@@ -392,13 +404,13 @@ async def _update_dispatch_schedule_and_notify(
392404 self ._schedule_start (dispatch )
393405
394406 # We modified the schedule, so we need to reset the timer
395- self ._update_timer ()
407+ self ._update_timer (timer )
396408
397- def _update_timer (self ) -> None :
409+ def _update_timer (self , timer : Timer ) -> None :
398410 """Update the timer to the next event."""
399411 if self ._scheduled_events :
400412 due_at : datetime = self ._scheduled_events [0 ].time
401- self . _next_event_timer .reset (interval = due_at - datetime .now (timezone .utc ))
413+ timer .reset (interval = due_at - datetime .now (timezone .utc ))
402414 _logger .debug ("Next event scheduled at %s" , self ._scheduled_events [0 ].time )
403415
404416 def _remove_scheduled (self , dispatch : Dispatch ) -> bool :
0 commit comments