Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions RELEASE_NOTES.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Frequenz channels Release Notes

## New Features

- A new class `frequenz.channels.time_scheduler.TimeScheduler` was added. It's optimized for scenarios where events get added, rescheduled or canceled frequently.

## Bug Fixes

- `FileWatcher`: Fixed `ready()` method to return False when an error occurs. Before this fix, `select()` (and other code using `ready()`) never detected the `FileWatcher` was stopped and the `select()` loop was continuously waking up to inform the receiver was ready.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ dev-pylint = [
]
dev-pytest = [
"async-solipsism == 0.7",
"time-machine == 2.15.0",
"frequenz-repo-config[extra-lint-examples] == 0.10.0",
"hypothesis == 6.111.2",
"pytest == 8.3.2",
Expand Down
194 changes: 194 additions & 0 deletions src/frequenz/channels/timer_scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
# License: MIT
# Copyright © 2024 Frequenz Energy-as-a-Service GmbH

"""Timer scheduler to schedule and events at specific times."""

import asyncio
import heapq
import itertools
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import Dict, Generic, TypeVar

from frequenz.channels import Sender

T = TypeVar("T") # Generic type for the object associated with events


@dataclass(order=True)
class ScheduledEvent(Generic[T]):
"""Represents an event scheduled to be dispatched at a specific time."""

scheduled_time: datetime
obj: T = field(compare=False)
unique_id: int = field(compare=False, default=0)
canceled: bool = field(compare=False, default=False)


class TimerScheduler(Generic[T]):
"""Class to schedule and dispatch events at specific times.

Usage example:
```python

import asyncio
from frequenz.channels.timer_scheduler import TimerScheduler
from frequenz.channels import Broadcast
from datetime import timedelta

async def main():
event_channel = Broadcast[str](name="events")
sender = event_channel.new_sender()
receiver = event_channel.new_receiver()

scheduler = TimerScheduler[str](sender)

scheduler.set_timer(fire_in=timedelta(seconds=5), obj="event1")
scheduler.set_timer(fire_in=timedelta(seconds=10), obj="event2")
scheduler.set_timer(fire_in=timedelta(seconds=10), obj="event3")

# Waits 5 seconds and returns "event1"
assert await receiver.receive() == "event1"

# Remove the "event2" timer
scheduler.unset_timer("event2")

# Reschedule "event3" to fire in 15 seconds
scheduler.set_timer(fire_in=timedelta(seconds=15), obj="event3")

# Waits 15 more seconds and returns "event3"
assert await receiver.receive() == "event3"
"""

def __init__(self, sender: Sender[T]) -> None:
"""Initialize the TimerScheduler with the given sender.

Parameters:
sender: The sender to dispatch the events.
"""
self._sender = sender
self._event_heap: list[ScheduledEvent[T]] = []
self._obj_to_event: Dict[T, ScheduledEvent[T]] = {}
self._counter = itertools.count()
self._current_task: asyncio.Task[None] | None = None
self._stopped = False

def set_timer(
self,
*,
obj: T,
fire_in: timedelta | None = None,
fire_at: datetime | None = None,
) -> bool:
"""
Schedule a new event or reschedule an existing one.

Args:
obj: The object associated with the event.
fire_in: Time after which the event should be dispatched. Conflicts with fire_at.
fire_at: Time at which the event should be dispatched. Conflicts with fire_in.

Returns:
True if the event was successfully scheduled; False otherwise.
"""
now = datetime.now(timezone.utc)

scheduled_time = (now + fire_in) if fire_in else fire_at
assert scheduled_time, "Either 'fire_in' or 'fire_at' must be provided."

if scheduled_time < now:
return False

# Check if the object is already scheduled
if obj in self._obj_to_event:
existing_event = self._obj_to_event[obj]
existing_event.canceled = True # Mark the existing event as canceled

# Create a new scheduled event
unique_id = next(self._counter)
new_event = ScheduledEvent(
scheduled_time=scheduled_time, unique_id=unique_id, obj=obj
)
heapq.heappush(self._event_heap, new_event)
self._obj_to_event[obj] = new_event

# If the new event is the earliest, reset the waiting task
if self._event_heap[0] == new_event:
if self._current_task:
self._current_task.cancel()
self._current_task = asyncio.create_task(self._wait_and_dispatch())

return True

def unset_timer(self, obj: T) -> bool:
"""
Cancel a scheduled event associated with the given object.

Args:
obj: The object associated with the event to cancel.

Returns:
True if the event was found and canceled; False otherwise.
"""
if obj in self._obj_to_event:
existing_event = self._obj_to_event[obj]
existing_event.canceled = True # Mark the event as canceled
del self._obj_to_event[obj]

# If the canceled event was the next to be dispatched, reset the waiting task
if self._event_heap and self._event_heap[0].obj == obj:
if self._current_task:
self._current_task.cancel()
self._current_task = asyncio.create_task(self._wait_and_dispatch())

return True

return False

async def _wait_and_dispatch(self) -> None:
"""Wait for the next event to be due and dispatch it."""
while not self._stopped:
if not self._event_heap:
self._current_task = None
return

next_event = self._event_heap[0]
now = datetime.now(timezone.utc)
delay = (next_event.scheduled_time - now).total_seconds()

if delay <= 0:
# Check if the event still exists
if next_event.obj not in self._obj_to_event:
# Skip canceled events
heapq.heappop(self._event_heap)
continue

# Event is due
heapq.heappop(self._event_heap)
del self._obj_to_event[next_event.obj]

if next_event.canceled:
# Skip canceled events
continue

# Dispatch the event
await self._sender.send(next_event.obj)
continue # Check for the next event

try:
# Wait until the next event's scheduled_time
self._current_task = asyncio.create_task(asyncio.sleep(delay))
await self._current_task
except asyncio.CancelledError:
# A new earlier event was scheduled; exit to handle it
return

async def stop(self) -> None:
"""Stop the scheduler and cancel any pending tasks."""
self._stopped = True
if self._current_task:
self._current_task.cancel()
try:
await self._current_task
except asyncio.CancelledError:
pass
169 changes: 169 additions & 0 deletions tests/test_timer_scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
# License: MIT
# Copyright © 2023 Frequenz Energy-as-a-Service GmbH

"""Test module for the TimerScheduler class."""

import asyncio
from datetime import datetime, timedelta

import async_solipsism
import time_machine
from pytest import fixture

from frequenz.channels import Broadcast
from frequenz.channels.timer_scheduler import TimerScheduler


@fixture
def event_loop_policy() -> async_solipsism.EventLoopPolicy:
"""Return an event loop policy that uses the async solipsism event loop."""
return async_solipsism.EventLoopPolicy()


async def test_set_timer() -> None:
"""Test that a scheduled event is dispatched at the correct (mocked) time."""
# Create a Broadcast channel
bcast = Broadcast[str](name="test")

# Create a sender and receiver
sender = bcast.new_sender()
receiver = bcast.new_receiver()

# Initialize the TimerScheduler with the sender
sched = TimerScheduler(sender)

# List to collect received events
received_events = []

# Define the consumer coroutine
async def consumer() -> None:
async for event in receiver:
received_events.append(event)

# Start the consumer as an asyncio task
consumer_task = asyncio.create_task(consumer())

# Freeze time at 2024-01-01 12:00:00
with time_machine.travel(datetime(2024, 1, 1, 12, 0, 0)):
# Schedule 'event1' to fire in 10 seconds
sched.set_timer(fire_in=timedelta(seconds=10), obj="event1")

# Advance time to 2024-01-01 12:00:10 to make 'event1' due
with time_machine.travel(datetime(2024, 1, 1, 12, 0, 10)):
# Allow some time for the event to be dispatched
await asyncio.sleep(0.1)

# Assert that 'event1' was received
assert (
"event1" in received_events
), "The event 'event1' was not dispatched as expected."

# Clean up by stopping the scheduler and cancelling the consumer task
await sched.stop()
consumer_task.cancel()
try:
await consumer_task
except asyncio.CancelledError:
pass


async def test_reschedule_timer() -> None:
"""Test that rescheduling an event updates its dispatch time correctly."""
# Create a Broadcast channel
bcast = Broadcast[str](name="test")

# Create a sender and receiver
sender = bcast.new_sender()
receiver = bcast.new_receiver()

# Initialize the TimerScheduler with the sender
sched = TimerScheduler(sender)

# List to collect received events
received_events = []

# Define the consumer coroutine
async def consumer() -> None:
async for event in receiver:
received_events.append(event)

# Start the consumer as an asyncio task
consumer_task = asyncio.create_task(consumer())

# Freeze time at 2024-01-01 12:00:00
with time_machine.travel(datetime(2024, 1, 1, 12, 0, 0)):
# Schedule 'event1' to fire in 10 seconds
sched.set_timer(fire_in=timedelta(seconds=10), obj="event1")

# Reschedule 'event1' to fire in 5 seconds
sched.set_timer(fire_in=timedelta(seconds=5), obj="event1")

# Advance time to 2024-01-01 12:00:05 to make 'event1' due
with time_machine.travel(datetime(2024, 1, 1, 12, 0, 5)):
# Allow some time for the event to be dispatched
await asyncio.sleep(0.1)

# Assert that 'event1' was received only once
assert (
received_events.count("event1") == 1
), "The event 'event1' was dispatched multiple times."

# Clean up by stopping the scheduler and cancelling the consumer task
await sched.stop()
consumer_task.cancel()
try:
await consumer_task
except asyncio.CancelledError:
pass


async def test_unset_timer() -> None:
"""Test that cancelling a scheduled event prevents it from being dispatched."""
# Create a Broadcast channel
bcast = Broadcast[str](name="test")

# Create a sender and receiver
sender = bcast.new_sender()
receiver = bcast.new_receiver()

# Initialize the TimerScheduler with the sender
sched = TimerScheduler(sender)

# List to collect received events
received_events = []

# Define the consumer coroutine
async def consumer() -> None:
async for event in receiver:
received_events.append(event)

# Start the consumer as an asyncio task
consumer_task = asyncio.create_task(consumer())

# Freeze time at 2024-01-01 12:00:00
with time_machine.travel(datetime(2024, 1, 1, 12, 0, 0)):
# Schedule 'event1' to fire in 10 seconds
sched.set_timer(fire_in=timedelta(seconds=10), obj="event1")

# Advance time to 2024-01-01 12:00:05
with time_machine.travel(datetime(2024, 1, 1, 12, 0, 5)):
# Cancel 'event1' before it's due
sched.unset_timer("event1")

# Advance time to 2024-01-01 12:00:10 to reach the original dispatch time
with time_machine.travel(datetime(2024, 1, 1, 12, 0, 10)):
# Allow some time for the dispatcher to process
await asyncio.sleep(0.1)

# Assert that 'event1' was not received
assert (
"event1" not in received_events
), "The event 'event1' was dispatched despite being canceled."

# Clean up by stopping the scheduler and cancelling the consumer task
await sched.stop()
consumer_task.cancel()
try:
await consumer_task
except asyncio.CancelledError:
pass