diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index 6d2e78c7..f33adf6a 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -1,5 +1,9 @@ # Frequenz channels Release Notes +## New Features + +- A new class `frequenz.channels.time_scheduler.TimeScheduler` was added. It's optimized for scenarios where events get added, rescheduled or canceled frequently. + ## Bug Fixes - `FileWatcher`: Fixed `ready()` method to return False when an error occurs. Before this fix, `select()` (and other code using `ready()`) never detected the `FileWatcher` was stopped and the `select()` loop was continuously waking up to inform the receiver was ready. diff --git a/pyproject.toml b/pyproject.toml index 12072ce0..4743663e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,6 +74,7 @@ dev-pylint = [ ] dev-pytest = [ "async-solipsism == 0.7", + "time-machine == 2.15.0", "frequenz-repo-config[extra-lint-examples] == 0.10.0", "hypothesis == 6.111.2", "pytest == 8.3.2", diff --git a/src/frequenz/channels/timer_scheduler.py b/src/frequenz/channels/timer_scheduler.py new file mode 100644 index 00000000..fa51c834 --- /dev/null +++ b/src/frequenz/channels/timer_scheduler.py @@ -0,0 +1,194 @@ +# License: MIT +# Copyright © 2024 Frequenz Energy-as-a-Service GmbH + +"""Timer scheduler to schedule and events at specific times.""" + +import asyncio +import heapq +import itertools +from dataclasses import dataclass, field +from datetime import datetime, timedelta, timezone +from typing import Dict, Generic, TypeVar + +from frequenz.channels import Sender + +T = TypeVar("T") # Generic type for the object associated with events + + +@dataclass(order=True) +class ScheduledEvent(Generic[T]): + """Represents an event scheduled to be dispatched at a specific time.""" + + scheduled_time: datetime + obj: T = field(compare=False) + unique_id: int = field(compare=False, default=0) + canceled: bool = field(compare=False, default=False) + + +class TimerScheduler(Generic[T]): + """Class to schedule and dispatch events at specific times. + + Usage example: + ```python + + import asyncio + from frequenz.channels.timer_scheduler import TimerScheduler + from frequenz.channels import Broadcast + from datetime import timedelta + + async def main(): + event_channel = Broadcast[str](name="events") + sender = event_channel.new_sender() + receiver = event_channel.new_receiver() + + scheduler = TimerScheduler[str](sender) + + scheduler.set_timer(fire_in=timedelta(seconds=5), obj="event1") + scheduler.set_timer(fire_in=timedelta(seconds=10), obj="event2") + scheduler.set_timer(fire_in=timedelta(seconds=10), obj="event3") + + # Waits 5 seconds and returns "event1" + assert await receiver.receive() == "event1" + + # Remove the "event2" timer + scheduler.unset_timer("event2") + + # Reschedule "event3" to fire in 15 seconds + scheduler.set_timer(fire_in=timedelta(seconds=15), obj="event3") + + # Waits 15 more seconds and returns "event3" + assert await receiver.receive() == "event3" + """ + + def __init__(self, sender: Sender[T]) -> None: + """Initialize the TimerScheduler with the given sender. + + Parameters: + sender: The sender to dispatch the events. + """ + self._sender = sender + self._event_heap: list[ScheduledEvent[T]] = [] + self._obj_to_event: Dict[T, ScheduledEvent[T]] = {} + self._counter = itertools.count() + self._current_task: asyncio.Task[None] | None = None + self._stopped = False + + def set_timer( + self, + *, + obj: T, + fire_in: timedelta | None = None, + fire_at: datetime | None = None, + ) -> bool: + """ + Schedule a new event or reschedule an existing one. + + Args: + obj: The object associated with the event. + fire_in: Time after which the event should be dispatched. Conflicts with fire_at. + fire_at: Time at which the event should be dispatched. Conflicts with fire_in. + + Returns: + True if the event was successfully scheduled; False otherwise. + """ + now = datetime.now(timezone.utc) + + scheduled_time = (now + fire_in) if fire_in else fire_at + assert scheduled_time, "Either 'fire_in' or 'fire_at' must be provided." + + if scheduled_time < now: + return False + + # Check if the object is already scheduled + if obj in self._obj_to_event: + existing_event = self._obj_to_event[obj] + existing_event.canceled = True # Mark the existing event as canceled + + # Create a new scheduled event + unique_id = next(self._counter) + new_event = ScheduledEvent( + scheduled_time=scheduled_time, unique_id=unique_id, obj=obj + ) + heapq.heappush(self._event_heap, new_event) + self._obj_to_event[obj] = new_event + + # If the new event is the earliest, reset the waiting task + if self._event_heap[0] == new_event: + if self._current_task: + self._current_task.cancel() + self._current_task = asyncio.create_task(self._wait_and_dispatch()) + + return True + + def unset_timer(self, obj: T) -> bool: + """ + Cancel a scheduled event associated with the given object. + + Args: + obj: The object associated with the event to cancel. + + Returns: + True if the event was found and canceled; False otherwise. + """ + if obj in self._obj_to_event: + existing_event = self._obj_to_event[obj] + existing_event.canceled = True # Mark the event as canceled + del self._obj_to_event[obj] + + # If the canceled event was the next to be dispatched, reset the waiting task + if self._event_heap and self._event_heap[0].obj == obj: + if self._current_task: + self._current_task.cancel() + self._current_task = asyncio.create_task(self._wait_and_dispatch()) + + return True + + return False + + async def _wait_and_dispatch(self) -> None: + """Wait for the next event to be due and dispatch it.""" + while not self._stopped: + if not self._event_heap: + self._current_task = None + return + + next_event = self._event_heap[0] + now = datetime.now(timezone.utc) + delay = (next_event.scheduled_time - now).total_seconds() + + if delay <= 0: + # Check if the event still exists + if next_event.obj not in self._obj_to_event: + # Skip canceled events + heapq.heappop(self._event_heap) + continue + + # Event is due + heapq.heappop(self._event_heap) + del self._obj_to_event[next_event.obj] + + if next_event.canceled: + # Skip canceled events + continue + + # Dispatch the event + await self._sender.send(next_event.obj) + continue # Check for the next event + + try: + # Wait until the next event's scheduled_time + self._current_task = asyncio.create_task(asyncio.sleep(delay)) + await self._current_task + except asyncio.CancelledError: + # A new earlier event was scheduled; exit to handle it + return + + async def stop(self) -> None: + """Stop the scheduler and cancel any pending tasks.""" + self._stopped = True + if self._current_task: + self._current_task.cancel() + try: + await self._current_task + except asyncio.CancelledError: + pass diff --git a/tests/test_timer_scheduler.py b/tests/test_timer_scheduler.py new file mode 100644 index 00000000..ae8e309e --- /dev/null +++ b/tests/test_timer_scheduler.py @@ -0,0 +1,169 @@ +# License: MIT +# Copyright © 2023 Frequenz Energy-as-a-Service GmbH + +"""Test module for the TimerScheduler class.""" + +import asyncio +from datetime import datetime, timedelta + +import async_solipsism +import time_machine +from pytest import fixture + +from frequenz.channels import Broadcast +from frequenz.channels.timer_scheduler import TimerScheduler + + +@fixture +def event_loop_policy() -> async_solipsism.EventLoopPolicy: + """Return an event loop policy that uses the async solipsism event loop.""" + return async_solipsism.EventLoopPolicy() + + +async def test_set_timer() -> None: + """Test that a scheduled event is dispatched at the correct (mocked) time.""" + # Create a Broadcast channel + bcast = Broadcast[str](name="test") + + # Create a sender and receiver + sender = bcast.new_sender() + receiver = bcast.new_receiver() + + # Initialize the TimerScheduler with the sender + sched = TimerScheduler(sender) + + # List to collect received events + received_events = [] + + # Define the consumer coroutine + async def consumer() -> None: + async for event in receiver: + received_events.append(event) + + # Start the consumer as an asyncio task + consumer_task = asyncio.create_task(consumer()) + + # Freeze time at 2024-01-01 12:00:00 + with time_machine.travel(datetime(2024, 1, 1, 12, 0, 0)): + # Schedule 'event1' to fire in 10 seconds + sched.set_timer(fire_in=timedelta(seconds=10), obj="event1") + + # Advance time to 2024-01-01 12:00:10 to make 'event1' due + with time_machine.travel(datetime(2024, 1, 1, 12, 0, 10)): + # Allow some time for the event to be dispatched + await asyncio.sleep(0.1) + + # Assert that 'event1' was received + assert ( + "event1" in received_events + ), "The event 'event1' was not dispatched as expected." + + # Clean up by stopping the scheduler and cancelling the consumer task + await sched.stop() + consumer_task.cancel() + try: + await consumer_task + except asyncio.CancelledError: + pass + + +async def test_reschedule_timer() -> None: + """Test that rescheduling an event updates its dispatch time correctly.""" + # Create a Broadcast channel + bcast = Broadcast[str](name="test") + + # Create a sender and receiver + sender = bcast.new_sender() + receiver = bcast.new_receiver() + + # Initialize the TimerScheduler with the sender + sched = TimerScheduler(sender) + + # List to collect received events + received_events = [] + + # Define the consumer coroutine + async def consumer() -> None: + async for event in receiver: + received_events.append(event) + + # Start the consumer as an asyncio task + consumer_task = asyncio.create_task(consumer()) + + # Freeze time at 2024-01-01 12:00:00 + with time_machine.travel(datetime(2024, 1, 1, 12, 0, 0)): + # Schedule 'event1' to fire in 10 seconds + sched.set_timer(fire_in=timedelta(seconds=10), obj="event1") + + # Reschedule 'event1' to fire in 5 seconds + sched.set_timer(fire_in=timedelta(seconds=5), obj="event1") + + # Advance time to 2024-01-01 12:00:05 to make 'event1' due + with time_machine.travel(datetime(2024, 1, 1, 12, 0, 5)): + # Allow some time for the event to be dispatched + await asyncio.sleep(0.1) + + # Assert that 'event1' was received only once + assert ( + received_events.count("event1") == 1 + ), "The event 'event1' was dispatched multiple times." + + # Clean up by stopping the scheduler and cancelling the consumer task + await sched.stop() + consumer_task.cancel() + try: + await consumer_task + except asyncio.CancelledError: + pass + + +async def test_unset_timer() -> None: + """Test that cancelling a scheduled event prevents it from being dispatched.""" + # Create a Broadcast channel + bcast = Broadcast[str](name="test") + + # Create a sender and receiver + sender = bcast.new_sender() + receiver = bcast.new_receiver() + + # Initialize the TimerScheduler with the sender + sched = TimerScheduler(sender) + + # List to collect received events + received_events = [] + + # Define the consumer coroutine + async def consumer() -> None: + async for event in receiver: + received_events.append(event) + + # Start the consumer as an asyncio task + consumer_task = asyncio.create_task(consumer()) + + # Freeze time at 2024-01-01 12:00:00 + with time_machine.travel(datetime(2024, 1, 1, 12, 0, 0)): + # Schedule 'event1' to fire in 10 seconds + sched.set_timer(fire_in=timedelta(seconds=10), obj="event1") + + # Advance time to 2024-01-01 12:00:05 + with time_machine.travel(datetime(2024, 1, 1, 12, 0, 5)): + # Cancel 'event1' before it's due + sched.unset_timer("event1") + + # Advance time to 2024-01-01 12:00:10 to reach the original dispatch time + with time_machine.travel(datetime(2024, 1, 1, 12, 0, 10)): + # Allow some time for the dispatcher to process + await asyncio.sleep(0.1) + + # Assert that 'event1' was not received + assert ( + "event1" not in received_events + ), "The event 'event1' was dispatched despite being canceled." + + # Clean up by stopping the scheduler and cancelling the consumer task + await sched.stop() + consumer_task.cancel() + try: + await consumer_task + except asyncio.CancelledError: + pass