Skip to content

Commit 38ca52b

Browse files
committed
Add scheduling class optimized for frequently changing timers
Signed-off-by: Mathias L. Baumann <[email protected]>
1 parent 62c2248 commit 38ca52b

File tree

4 files changed

+364
-0
lines changed

4 files changed

+364
-0
lines changed

RELEASE_NOTES.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# Frequenz channels Release Notes
22

3+
## New Features
4+
5+
- A new class `frequenz.channels.time_scheduler.TimeScheduler` was added. It's optimized for scenarios where events get added, rescheduled or canceled frequently.
6+
37
## Bug Fixes
48

59
- `FileWatcher`: Fixed `ready()` method to return False when an error occurs. Before this fix, `select()` (and other code using `ready()`) never detected the `FileWatcher` was stopped and the `select()` loop was continuously waking up to inform the receiver was ready.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ dev-pylint = [
7474
]
7575
dev-pytest = [
7676
"async-solipsism == 0.7",
77+
"time-machine == 2.15.0",
7778
"frequenz-repo-config[extra-lint-examples] == 0.10.0",
7879
"hypothesis == 6.111.2",
7980
"pytest == 8.3.2",
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
# License: MIT
2+
# Copyright © 2024 Frequenz Energy-as-a-Service GmbH
3+
4+
"""Timer scheduler to schedule and events at specific times."""
5+
6+
import asyncio
7+
import heapq
8+
import itertools
9+
from dataclasses import dataclass, field
10+
from datetime import datetime, timedelta, timezone
11+
from typing import Dict, Generic, TypeVar
12+
13+
from frequenz.channels import Sender
14+
15+
T = TypeVar("T") # Generic type for the object associated with events
16+
17+
18+
@dataclass(order=True)
19+
class ScheduledEvent(Generic[T]):
20+
"""Represents an event scheduled to be dispatched at a specific time."""
21+
22+
scheduled_time: datetime
23+
obj: T = field(compare=False)
24+
unique_id: int = field(compare=False, default=0)
25+
canceled: bool = field(compare=False, default=False)
26+
27+
28+
class TimerScheduler(Generic[T]):
29+
"""Class to schedule and dispatch events at specific times.
30+
31+
Usage example:
32+
```python
33+
34+
import asyncio
35+
from frequenz.channels.timer_scheduler import TimerScheduler
36+
from frequenz.channels import Broadcast
37+
from datetime import timedelta
38+
39+
async def main():
40+
event_channel = Broadcast[str](name="events")
41+
sender = event_channel.new_sender()
42+
receiver = event_channel.new_receiver()
43+
44+
scheduler = TimerScheduler[str](sender)
45+
46+
scheduler.set_timer(fire_in=timedelta(seconds=5), obj="event1")
47+
scheduler.set_timer(fire_in=timedelta(seconds=10), obj="event2")
48+
scheduler.set_timer(fire_in=timedelta(seconds=10), obj="event3")
49+
50+
# Waits 5 seconds and returns "event1"
51+
assert await receiver.receive() == "event1"
52+
53+
# Remove the "event2" timer
54+
scheduler.unset_timer("event2")
55+
56+
# Reschedule "event3" to fire in 15 seconds
57+
scheduler.set_timer(fire_in=timedelta(seconds=15), obj="event3")
58+
59+
# Waits 15 more seconds and returns "event3"
60+
assert await receiver.receive() == "event3"
61+
"""
62+
63+
def __init__(self, sender: Sender[T]) -> None:
64+
"""Initialize the TimerScheduler with the given sender.
65+
66+
Parameters:
67+
sender: The sender to dispatch the events.
68+
"""
69+
self._sender = sender
70+
self._event_heap: list[ScheduledEvent[T]] = []
71+
self._obj_to_event: Dict[T, ScheduledEvent[T]] = {}
72+
self._counter = itertools.count()
73+
self._current_task: asyncio.Task[None] | None = None
74+
self._stopped = False
75+
76+
def set_timer(
77+
self,
78+
*,
79+
obj: T,
80+
fire_in: timedelta | None = None,
81+
fire_at: datetime | None = None,
82+
) -> None:
83+
"""
84+
Schedule a new event or reschedule an existing one.
85+
86+
Args:
87+
obj: The object associated with the event.
88+
fire_in: Time after which the event should be dispatched. Conflicts with fire_at.
89+
fire_at: Time at which the event should be dispatched. Conflicts with fire_in.
90+
"""
91+
now = datetime.now(timezone.utc)
92+
93+
scheduled_time = (now + fire_in) if fire_in else fire_at
94+
assert scheduled_time, "Either 'fire_in' or 'fire_at' must be provided."
95+
96+
# Check if the object is already scheduled
97+
if obj in self._obj_to_event:
98+
existing_event = self._obj_to_event[obj]
99+
existing_event.canceled = True # Mark the existing event as canceled
100+
101+
# Create a new scheduled event
102+
unique_id = next(self._counter)
103+
new_event = ScheduledEvent(
104+
scheduled_time=scheduled_time, unique_id=unique_id, obj=obj
105+
)
106+
heapq.heappush(self._event_heap, new_event)
107+
self._obj_to_event[obj] = new_event
108+
109+
# If the new event is the earliest, reset the waiting task
110+
if self._event_heap[0] == new_event:
111+
if self._current_task:
112+
self._current_task.cancel()
113+
self._current_task = asyncio.create_task(self._wait_and_dispatch())
114+
115+
def unset_timer(self, obj: T) -> bool:
116+
"""
117+
Cancel a scheduled event associated with the given object.
118+
119+
Args:
120+
obj: The object associated with the event to cancel.
121+
122+
Returns:
123+
True if the event was found and canceled; False otherwise.
124+
"""
125+
if obj in self._obj_to_event:
126+
existing_event = self._obj_to_event[obj]
127+
existing_event.canceled = True # Mark the event as canceled
128+
del self._obj_to_event[obj]
129+
130+
# If the canceled event was the next to be dispatched, reset the waiting task
131+
if self._event_heap and self._event_heap[0].obj == obj:
132+
if self._current_task:
133+
self._current_task.cancel()
134+
self._current_task = asyncio.create_task(self._wait_and_dispatch())
135+
136+
return True
137+
138+
return False
139+
140+
async def _wait_and_dispatch(self) -> None:
141+
"""Wait for the next event to be due and dispatch it."""
142+
while not self._stopped:
143+
if not self._event_heap:
144+
self._current_task = None
145+
return
146+
147+
next_event = self._event_heap[0]
148+
now = datetime.now(timezone.utc)
149+
delay = (next_event.scheduled_time - now).total_seconds()
150+
151+
if delay <= 0:
152+
# Check if the event still exists
153+
if next_event.obj not in self._obj_to_event:
154+
# Skip canceled events
155+
heapq.heappop(self._event_heap)
156+
continue
157+
158+
# Event is due
159+
heapq.heappop(self._event_heap)
160+
del self._obj_to_event[next_event.obj]
161+
162+
if next_event.canceled:
163+
# Skip canceled events
164+
continue
165+
166+
# Dispatch the event
167+
await self._sender.send(next_event.obj)
168+
continue # Check for the next event
169+
170+
try:
171+
# Wait until the next event's scheduled_time
172+
self._current_task = asyncio.create_task(asyncio.sleep(delay))
173+
await self._current_task
174+
except asyncio.CancelledError:
175+
# A new earlier event was scheduled; exit to handle it
176+
return
177+
178+
async def stop(self) -> None:
179+
"""Stop the scheduler and cancel any pending tasks."""
180+
self._stopped = True
181+
if self._current_task:
182+
self._current_task.cancel()
183+
try:
184+
await self._current_task
185+
except asyncio.CancelledError:
186+
pass

tests/test_timer_scheduler.py

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
# License: MIT
2+
# Copyright © 2023 Frequenz Energy-as-a-Service GmbH
3+
4+
"""Test module for the TimerScheduler class."""
5+
6+
import asyncio
7+
from datetime import datetime, timedelta
8+
9+
import async_solipsism
10+
import pytest
11+
import time_machine
12+
from pytest import fixture
13+
14+
from frequenz.channels import Broadcast
15+
from frequenz.channels.timer_scheduler import TimerScheduler
16+
17+
18+
@fixture
19+
def event_loop_policy() -> async_solipsism.EventLoopPolicy:
20+
"""Return an event loop policy that uses the async solipsism event loop."""
21+
return async_solipsism.EventLoopPolicy()
22+
23+
24+
@pytest.mark.asyncio
25+
async def test_set_timer():
26+
"""Test that a scheduled event is dispatched at the correct (mocked) time."""
27+
# Create a Broadcast channel
28+
bcast = Broadcast(name="test")
29+
30+
# Create a sender and receiver
31+
sender = bcast.new_sender()
32+
receiver = bcast.new_receiver()
33+
34+
# Initialize the TimerScheduler with the sender
35+
sched = TimerScheduler(sender)
36+
37+
# List to collect received events
38+
received_events = []
39+
40+
# Define the consumer coroutine
41+
async def consumer():
42+
async for event in receiver:
43+
received_events.append(event)
44+
45+
# Start the consumer as an asyncio task
46+
consumer_task = asyncio.create_task(consumer())
47+
48+
# Freeze time at 2024-01-01 12:00:00
49+
with time_machine.travel(datetime(2024, 1, 1, 12, 0, 0)):
50+
# Schedule 'event1' to fire in 10 seconds
51+
sched.set_timer(fire_in=timedelta(seconds=10), obj="event1")
52+
53+
# Advance time to 2024-01-01 12:00:10 to make 'event1' due
54+
with time_machine.travel(datetime(2024, 1, 1, 12, 0, 10)):
55+
# Allow some time for the event to be dispatched
56+
await asyncio.sleep(0.1)
57+
58+
# Assert that 'event1' was received
59+
assert (
60+
"event1" in received_events
61+
), "The event 'event1' was not dispatched as expected."
62+
63+
# Clean up by stopping the scheduler and cancelling the consumer task
64+
await sched.stop()
65+
consumer_task.cancel()
66+
try:
67+
await consumer_task
68+
except asyncio.CancelledError:
69+
pass
70+
71+
72+
@pytest.mark.asyncio
73+
async def test_reschedule_timer():
74+
"""Test that rescheduling an event updates its dispatch time correctly."""
75+
# Create a Broadcast channel
76+
bcast = Broadcast(name="test")
77+
78+
# Create a sender and receiver
79+
sender = bcast.new_sender()
80+
receiver = bcast.new_receiver()
81+
82+
# Initialize the TimerScheduler with the sender
83+
sched = TimerScheduler(sender)
84+
85+
# List to collect received events
86+
received_events = []
87+
88+
# Define the consumer coroutine
89+
async def consumer():
90+
async for event in receiver:
91+
received_events.append(event)
92+
93+
# Start the consumer as an asyncio task
94+
consumer_task = asyncio.create_task(consumer())
95+
96+
# Freeze time at 2024-01-01 12:00:00
97+
with time_machine.travel(datetime(2024, 1, 1, 12, 0, 0)):
98+
# Schedule 'event1' to fire in 10 seconds
99+
sched.set_timer(fire_in=timedelta(seconds=10), obj="event1")
100+
101+
# Reschedule 'event1' to fire in 5 seconds
102+
sched.set_timer(fire_in=timedelta(seconds=5), obj="event1")
103+
104+
# Advance time to 2024-01-01 12:00:05 to make 'event1' due
105+
with time_machine.travel(datetime(2024, 1, 1, 12, 0, 5)):
106+
# Allow some time for the event to be dispatched
107+
await asyncio.sleep(0.1)
108+
109+
# Assert that 'event1' was received only once
110+
assert (
111+
received_events.count("event1") == 1
112+
), "The event 'event1' was dispatched multiple times."
113+
114+
# Clean up by stopping the scheduler and cancelling the consumer task
115+
await sched.stop()
116+
consumer_task.cancel()
117+
try:
118+
await consumer_task
119+
except asyncio.CancelledError:
120+
pass
121+
122+
123+
@pytest.mark.asyncio
124+
async def test_unset_timer():
125+
"""Test that cancelling a scheduled event prevents it from being dispatched."""
126+
# Create a Broadcast channel
127+
bcast = Broadcast(name="test")
128+
129+
# Create a sender and receiver
130+
sender = bcast.new_sender()
131+
receiver = bcast.new_receiver()
132+
133+
# Initialize the TimerScheduler with the sender
134+
sched = TimerScheduler(sender)
135+
136+
# List to collect received events
137+
received_events = []
138+
139+
# Define the consumer coroutine
140+
async def consumer():
141+
async for event in receiver:
142+
received_events.append(event)
143+
144+
# Start the consumer as an asyncio task
145+
consumer_task = asyncio.create_task(consumer())
146+
147+
# Freeze time at 2024-01-01 12:00:00
148+
with time_machine.travel(datetime(2024, 1, 1, 12, 0, 0)):
149+
# Schedule 'event1' to fire in 10 seconds
150+
sched.set_timer(fire_in=timedelta(seconds=10), obj="event1")
151+
152+
# Advance time to 2024-01-01 12:00:05
153+
with time_machine.travel(datetime(2024, 1, 1, 12, 0, 5)):
154+
# Cancel 'event1' before it's due
155+
sched.unset_timer("event1")
156+
157+
# Advance time to 2024-01-01 12:00:10 to reach the original dispatch time
158+
with time_machine.travel(datetime(2024, 1, 1, 12, 0, 10)):
159+
# Allow some time for the dispatcher to process
160+
await asyncio.sleep(0.1)
161+
162+
# Assert that 'event1' was not received
163+
assert (
164+
"event1" not in received_events
165+
), "The event 'event1' was dispatched despite being canceled."
166+
167+
# Clean up by stopping the scheduler and cancelling the consumer task
168+
await sched.stop()
169+
consumer_task.cancel()
170+
try:
171+
await consumer_task
172+
except asyncio.CancelledError:
173+
pass

0 commit comments

Comments
 (0)