Skip to content

Commit 3fd9489

Browse files
committed
Make MovingWindow a BackgroundService
This gives the moving window a common interface and removes some bookkeeping from the implementation, as well as making tests more correct (when using it as an async context manager). Signed-off-by: Leandro Lucarella <[email protected]>
1 parent 0f985e2 commit 3fd9489

File tree

5 files changed

+192
-174
lines changed

5 files changed

+192
-174
lines changed

benchmarks/timeseries/periodic_feature_extractor.py

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from __future__ import annotations
1313

1414
import asyncio
15+
import collections.abc
16+
import contextlib
1517
import logging
1618
from datetime import datetime, timedelta, timezone
1719
from functools import partial
@@ -27,19 +29,23 @@
2729
from frequenz.sdk.timeseries._quantities import Quantity
2830

2931

30-
async def init_feature_extractor(period: int) -> PeriodicFeatureExtractor:
32+
@contextlib.asynccontextmanager
33+
async def init_feature_extractor(
34+
period: int,
35+
) -> collections.abc.AsyncIterator[PeriodicFeatureExtractor]:
3136
"""Initialize the PeriodicFeatureExtractor class."""
3237
# We only need the moving window to initialize the PeriodicFeatureExtractor class.
3338
lm_chan = Broadcast[Sample[Quantity]]("lm_net_power")
34-
moving_window = MovingWindow(
39+
async with MovingWindow(
3540
timedelta(seconds=1), lm_chan.new_receiver(), timedelta(seconds=1)
36-
)
37-
38-
await lm_chan.new_sender().send(Sample(datetime.now(tz=timezone.utc), Quantity(0)))
41+
) as moving_window:
42+
await lm_chan.new_sender().send(
43+
Sample(datetime.now(tz=timezone.utc), Quantity(0))
44+
)
3945

40-
# Initialize the PeriodicFeatureExtractor class with a period of period seconds.
41-
# This works since the sampling period is set to 1 second.
42-
return PeriodicFeatureExtractor(moving_window, timedelta(seconds=period))
46+
# Initialize the PeriodicFeatureExtractor class with a period of period seconds.
47+
# This works since the sampling period is set to 1 second.
48+
yield PeriodicFeatureExtractor(moving_window, timedelta(seconds=period))
4349

4450

4551
def _calculate_avg_window(
@@ -211,22 +217,22 @@ async def main() -> None:
211217

212218
# create a random ndarray with 29 days -5 seconds of data
213219
days_29_s = 29 * DAY_S
214-
feature_extractor = await init_feature_extractor(10)
215-
data = rng.standard_normal(days_29_s)
216-
run_benchmark(data, 4, feature_extractor)
217-
218-
days_29_s = 29 * DAY_S + 3
219-
data = rng.standard_normal(days_29_s)
220-
run_benchmark(data, 4, feature_extractor)
221-
222-
# create a random ndarray with 29 days +5 seconds of data
223-
data = rng.standard_normal(29 * DAY_S + 5)
224-
225-
feature_extractor = await init_feature_extractor(7 * DAY_S)
226-
# TEST one day window and 6 days distance. COPY (Case 3)
227-
run_benchmark(data, DAY_S, feature_extractor)
228-
# benchmark one day window and 6 days distance. NO COPY (Case 1)
229-
run_benchmark(data[: 28 * DAY_S], DAY_S, feature_extractor)
220+
async with init_feature_extractor(10) as feature_extractor:
221+
data = rng.standard_normal(days_29_s)
222+
run_benchmark(data, 4, feature_extractor)
223+
224+
days_29_s = 29 * DAY_S + 3
225+
data = rng.standard_normal(days_29_s)
226+
run_benchmark(data, 4, feature_extractor)
227+
228+
# create a random ndarray with 29 days +5 seconds of data
229+
data = rng.standard_normal(29 * DAY_S + 5)
230+
231+
async with init_feature_extractor(7 * DAY_S) as feature_extractor:
232+
# TEST one day window and 6 days distance. COPY (Case 3)
233+
run_benchmark(data, DAY_S, feature_extractor)
234+
# benchmark one day window and 6 days distance. NO COPY (Case 1)
235+
run_benchmark(data[: 28 * DAY_S], DAY_S, feature_extractor)
230236

231237

232238
logging.basicConfig(level=logging.DEBUG)

src/frequenz/sdk/timeseries/_moving_window.py

Lines changed: 34 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from frequenz.channels import Broadcast, Receiver, Sender
1717
from numpy.typing import ArrayLike
1818

19-
from .._internal._asyncio import cancel_and_await
19+
from ..actor._background_service import BackgroundService
2020
from ._base_types import UNIX_EPOCH, Sample
2121
from ._quantities import Quantity
2222
from ._resampling import Resampler, ResamplerConfig
@@ -25,7 +25,7 @@
2525
_logger = logging.getLogger(__name__)
2626

2727

28-
class MovingWindow:
28+
class MovingWindow(BackgroundService):
2929
"""
3030
A data window that moves with the latest datapoints of a data stream.
3131
@@ -72,22 +72,21 @@ async def run() -> None:
7272
7373
send_task = asyncio.create_task(send_mock_data(resampled_data_sender))
7474
75-
window = MovingWindow(
75+
async with MovingWindow(
7676
size=timedelta(seconds=5),
7777
resampled_data_recv=resampled_data_receiver,
7878
input_sampling_period=timedelta(seconds=1),
79-
)
79+
) as window:
80+
time_start = datetime.now(tz=timezone.utc)
81+
time_end = time_start + timedelta(seconds=5)
8082
81-
time_start = datetime.now(tz=timezone.utc)
82-
time_end = time_start + timedelta(seconds=5)
83+
# ... wait for 5 seconds until the buffer is filled
84+
await asyncio.sleep(5)
8385
84-
# ... wait for 5 seconds until the buffer is filled
85-
await asyncio.sleep(5)
86-
87-
# return an numpy array from the window
88-
array = window[time_start:time_end]
89-
# and use it to for example calculate the mean
90-
mean = array.mean()
86+
# return a numpy array from the window
87+
array = window[time_start:time_end]
88+
# and use it to, for example, calculate the mean
89+
mean = array.mean()
9190
9291
asyncio.run(run())
9392
```
@@ -112,19 +111,18 @@ async def run() -> None:
112111
113112
# create a window that stores two days of data
114113
# starting at 2023-01-01 with a sampling period of 1 second
115-
window = MovingWindow(
114+
async with MovingWindow(
116115
size=timedelta(days=2),
117116
resampled_data_recv=resampled_data_receiver,
118117
input_sampling_period=timedelta(seconds=1),
119-
)
120-
121-
# wait for one full day until the buffer is filled
122-
await asyncio.sleep(60*60*24)
118+
) as window:
119+
# wait for one full day until the buffer is filled
120+
await asyncio.sleep(60*60*24)
123121
124-
# create a polars series with one full day of data
125-
time_start = datetime(2023, 1, 1, tzinfo=timezone.utc)
126-
time_end = datetime(2023, 1, 2, tzinfo=timezone.utc)
127-
series = pl.Series("Jan_1", window[time_start:time_end])
122+
# create a polars series with one full day of data
123+
time_start = datetime(2023, 1, 1, tzinfo=timezone.utc)
124+
time_end = datetime(2023, 1, 2, tzinfo=timezone.utc)
125+
series = pl.Series("Jan_1", window[time_start:time_end])
128126
129127
asyncio.run(run())
130128
```
@@ -137,6 +135,8 @@ def __init__( # pylint: disable=too-many-arguments
137135
input_sampling_period: timedelta,
138136
resampler_config: ResamplerConfig | None = None,
139137
align_to: datetime = UNIX_EPOCH,
138+
*,
139+
name: str | None = None,
140140
) -> None:
141141
"""
142142
Initialize the MovingWindow.
@@ -154,22 +154,21 @@ def __init__( # pylint: disable=too-many-arguments
154154
align_to: A datetime object that defines a point in time to which
155155
the window is aligned, modulo the window size. For further
156156
information, consult the class level documentation.
157-
158-
Raises:
159-
asyncio.CancelledError: when the task gets cancelled.
157+
name: The name of this moving window. If `None`, `str(id(self))` will be
158+
used. This is used mostly for debugging purposes.
160159
"""
161160
assert (
162161
input_sampling_period.total_seconds() > 0
163162
), "The input sampling period should be greater than zero."
164163
assert (
165164
input_sampling_period <= size
166165
), "The input sampling period should be equal to or lower than the window size."
166+
super().__init__(name=name)
167167

168168
self._sampling_period = input_sampling_period
169169

170170
self._resampler: Resampler | None = None
171171
self._resampler_sender: Sender[Sample[Quantity]] | None = None
172-
self._resampler_task: asyncio.Task[None] | None = None
173172

174173
if resampler_config:
175174
assert (
@@ -191,12 +190,14 @@ def __init__( # pylint: disable=too-many-arguments
191190
align_to=align_to,
192191
)
193192

193+
async def start(self) -> None:
194+
"""Start the MovingWindow.
195+
196+
This method starts the MovingWindow tasks.
197+
"""
194198
if self._resampler:
195199
self._configure_resampler()
196-
197-
self._update_window_task: asyncio.Task[None] = asyncio.create_task(
198-
self._run_impl()
199-
)
200+
self._tasks.add(asyncio.create_task(self._run_impl(), name="update-window"))
200201

201202
@property
202203
def sampling_period(self) -> timedelta:
@@ -228,12 +229,6 @@ async def _run_impl(self) -> None:
228229

229230
_logger.error("Channel has been closed")
230231

231-
async def stop(self) -> None:
232-
"""Cancel the running tasks and stop the MovingWindow."""
233-
await cancel_and_await(self._update_window_task)
234-
if self._resampler_task:
235-
await cancel_and_await(self._resampler_task)
236-
237232
def _configure_resampler(self) -> None:
238233
"""Configure the components needed to run the resampler."""
239234
assert self._resampler is not None
@@ -247,7 +242,9 @@ async def sink_buffer(sample: Sample[Quantity]) -> None:
247242
self._resampler.add_timeseries(
248243
"avg", resampler_channel.new_receiver(), sink_buffer
249244
)
250-
self._resampler_task = asyncio.create_task(self._resampler.resample())
245+
self._tasks.add(
246+
asyncio.create_task(self._resampler.resample(), name="resample")
247+
)
251248

252249
def __len__(self) -> int:
253250
"""

src/frequenz/sdk/timeseries/_periodic_feature_extractor.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -82,28 +82,27 @@ class PeriodicFeatureExtractor:
8282
from frequenz.sdk import microgrid
8383
from datetime import datetime, timedelta, timezone
8484
85-
moving_window = MovingWindow(
85+
async with MovingWindow(
8686
size=timedelta(days=35),
8787
resampled_data_recv=microgrid.logical_meter().grid_power.new_receiver(),
8888
input_sampling_period=timedelta(seconds=1),
89-
)
90-
91-
feature_extractor = PeriodicFeatureExtractor(
92-
moving_window = moving_window,
93-
period=timedelta(days=7),
94-
)
89+
) as moving_window:
90+
feature_extractor = PeriodicFeatureExtractor(
91+
moving_window=moving_window,
92+
period=timedelta(days=7),
93+
)
9594
96-
now = datetime.now(timezone.utc)
95+
now = datetime.now(timezone.utc)
9796
98-
# create a daily weighted average for the next 24h
99-
avg_24h = feature_extractor.avg(
100-
now,
101-
now + timedelta(hours=24),
102-
weights=[0.1, 0.2, 0.3, 0.4]
103-
)
97+
# create a daily weighted average for the next 24h
98+
avg_24h = feature_extractor.avg(
99+
now,
100+
now + timedelta(hours=24),
101+
weights=[0.1, 0.2, 0.3, 0.4]
102+
)
104103
105-
# create a daily average for Thursday March 23 2023
106-
th_avg_24h = feature_extractor.avg(datetime(2023, 3, 23), datetime(2023, 3, 24))
104+
# create a daily average for Thursday March 23 2023
105+
th_avg_24h = feature_extractor.avg(datetime(2023, 3, 23), datetime(2023, 3, 24))
107106
```
108107
"""
109108

tests/timeseries/test_moving_window.py

Lines changed: 36 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -75,15 +75,17 @@ def init_moving_window(
7575
async def test_access_window_by_index() -> None:
7676
"""Test indexing a window by integer index"""
7777
window, sender = init_moving_window(timedelta(seconds=1))
78-
await push_logical_meter_data(sender, [1])
79-
assert np.array_equal(window[0], 1.0)
78+
async with window:
79+
await push_logical_meter_data(sender, [1])
80+
assert np.array_equal(window[0], 1.0)
8081

8182

8283
async def test_access_window_by_timestamp() -> None:
8384
"""Test indexing a window by timestamp"""
8485
window, sender = init_moving_window(timedelta(seconds=1))
85-
await push_logical_meter_data(sender, [1])
86-
assert np.array_equal(window[UNIX_EPOCH], 1.0)
86+
async with window:
87+
await push_logical_meter_data(sender, [1])
88+
assert np.array_equal(window[UNIX_EPOCH], 1.0)
8789

8890

8991
async def test_access_window_by_int_slice() -> None:
@@ -94,35 +96,39 @@ async def test_access_window_by_int_slice() -> None:
9496
since the push_logical_meter_data function starts with the same initial timestamp.
9597
"""
9698
window, sender = init_moving_window(timedelta(seconds=14))
97-
await push_logical_meter_data(sender, range(0, 5))
98-
assert np.array_equal(window[3:5], np.array([3.0, 4.0]))
99+
async with window:
100+
await push_logical_meter_data(sender, range(0, 5))
101+
assert np.array_equal(window[3:5], np.array([3.0, 4.0]))
99102

100-
data = [1, 2, 2.5, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1]
101-
await push_logical_meter_data(sender, data)
102-
assert np.array_equal(window[5:14], np.array(data[5:14]))
103+
data = [1, 2, 2.5, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1]
104+
await push_logical_meter_data(sender, data)
105+
assert np.array_equal(window[5:14], np.array(data[5:14]))
103106

104107

105108
async def test_access_window_by_ts_slice() -> None:
106109
"""Test accessing a subwindow with a timestamp slice"""
107110
window, sender = init_moving_window(timedelta(seconds=5))
108-
await push_logical_meter_data(sender, range(0, 5))
109-
time_start = UNIX_EPOCH + timedelta(seconds=3)
110-
time_end = time_start + timedelta(seconds=2)
111-
assert np.array_equal(window[time_start:time_end], np.array([3.0, 4.0])) # type: ignore
111+
async with window:
112+
await push_logical_meter_data(sender, range(0, 5))
113+
time_start = UNIX_EPOCH + timedelta(seconds=3)
114+
time_end = time_start + timedelta(seconds=2)
115+
assert np.array_equal(window[time_start:time_end], np.array([3.0, 4.0])) # type: ignore
112116

113117

114118
async def test_access_empty_window() -> None:
115119
"""Test accessing an empty window, should throw IndexError"""
116120
window, _ = init_moving_window(timedelta(seconds=5))
117-
with pytest.raises(IndexError, match=r"^The buffer is empty\.$"):
118-
_ = window[42]
121+
async with window:
122+
with pytest.raises(IndexError, match=r"^The buffer is empty\.$"):
123+
_ = window[42]
119124

120125

121126
async def test_window_size() -> None:
122127
"""Test the size of the window."""
123128
window, sender = init_moving_window(timedelta(seconds=5))
124-
await push_logical_meter_data(sender, range(0, 20))
125-
assert len(window) == 5
129+
async with window:
130+
await push_logical_meter_data(sender, range(0, 20))
131+
assert len(window) == 5
126132

127133

128134
# pylint: disable=redefined-outer-name
@@ -136,21 +142,20 @@ async def test_resampling_window(fake_time: time_machine.Coordinates) -> None:
136142
output_sampling = timedelta(seconds=2)
137143
resampler_config = ResamplerConfig(resampling_period=output_sampling)
138144

139-
window = MovingWindow(
145+
async with MovingWindow(
140146
size=window_size,
141147
resampled_data_recv=channel.new_receiver(),
142148
input_sampling_period=input_sampling,
143149
resampler_config=resampler_config,
144-
)
145-
146-
stream_values = [4.0, 8.0, 2.0, 6.0, 5.0] * 100
147-
for value in stream_values:
148-
timestamp = datetime.now(tz=timezone.utc)
149-
sample = Sample(timestamp, Quantity(float(value)))
150-
await sender.send(sample)
151-
await asyncio.sleep(0.1)
152-
fake_time.shift(0.1)
153-
154-
assert len(window) == window_size / output_sampling
155-
for value in window: # type: ignore
156-
assert 4.9 < value < 5.1
150+
) as window:
151+
stream_values = [4.0, 8.0, 2.0, 6.0, 5.0] * 100
152+
for value in stream_values:
153+
timestamp = datetime.now(tz=timezone.utc)
154+
sample = Sample(timestamp, Quantity(float(value)))
155+
await sender.send(sample)
156+
await asyncio.sleep(0.1)
157+
fake_time.shift(0.1)
158+
159+
assert len(window) == window_size / output_sampling
160+
for value in window: # type: ignore
161+
assert 4.9 < value < 5.1

0 commit comments

Comments
 (0)