Skip to content

Commit d73eaa3

Browse files
committed
[MovingWindow] Add a wait_for_samples method
Signed-off-by: Sahas Subramanian <[email protected]>
1 parent 800ca03 commit d73eaa3

File tree

2 files changed

+102
-0
lines changed

2 files changed

+102
-0
lines changed

src/frequenz/sdk/timeseries/_moving_window.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,8 @@ def __init__( # pylint: disable=too-many-arguments
190190
align_to=align_to,
191191
)
192192

193+
self._condition_new_sample = asyncio.Condition()
194+
193195
def start(self) -> None:
194196
"""Start the MovingWindow.
195197
@@ -318,6 +320,34 @@ def window(
318320
start, end, force_copy=force_copy, fill_value=fill_value
319321
)
320322

323+
async def wait_for_samples(self, n: int) -> None:
324+
"""Wait until the next `n` samples are available in the MovingWindow.
325+
326+
Args:
327+
n: The number of samples to wait for.
328+
329+
Raises:
330+
ValueError: If `n` is less than or equal to 0 or greater than the capacity
331+
of the MovingWindow.
332+
"""
333+
if n <= 0:
334+
raise ValueError(
335+
"The number of samples to wait for must be greater than 0."
336+
)
337+
if n > self.capacity:
338+
raise ValueError(
339+
"The number of samples to wait for must be less than or equal to the "
340+
+ f"capacity of the MovingWindow ({self.capacity})."
341+
)
342+
start_timestamp = self.newest_timestamp
343+
if n < self.capacity:
344+
n += self.count_valid(since=start_timestamp)
345+
while True:
346+
async with self._condition_new_sample:
347+
_ = await self._condition_new_sample.wait()
348+
if self.count_valid(since=start_timestamp) >= n:
349+
return
350+
321351
async def _run_impl(self) -> None:
322352
"""Awaits samples from the receiver and updates the underlying ring buffer.
323353
@@ -330,6 +360,8 @@ async def _run_impl(self) -> None:
330360
if self._resampler and self._resampler_sender:
331361
await self._resampler_sender.send(sample)
332362
else:
363+
async with self._condition_new_sample:
364+
self._condition_new_sample.notify_all()
333365
self._buffer.update(sample)
334366

335367
except asyncio.CancelledError:
@@ -344,6 +376,8 @@ def _configure_resampler(self) -> None:
344376

345377
async def sink_buffer(sample: Sample[Quantity]) -> None:
346378
if sample.value is not None:
379+
async with self._condition_new_sample:
380+
self._condition_new_sample.notify_all()
347381
self._buffer.update(sample)
348382

349383
resampler_channel = Broadcast[Sample[Quantity]](name="average")

tests/timeseries/test_moving_window.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""Tests for the moving window."""
55

66
import asyncio
7+
import re
78
from collections.abc import Sequence
89
from datetime import datetime, timedelta, timezone
910

@@ -299,6 +300,73 @@ async def test_window_size() -> None:
299300
)
300301

301302

303+
async def test_wait_for_samples() -> None:
304+
"""Test waiting for samples in the window."""
305+
window, sender = init_moving_window(timedelta(seconds=10))
306+
async with window:
307+
task = asyncio.create_task(window.wait_for_samples(5))
308+
await asyncio.sleep(0)
309+
assert not task.done()
310+
await push_logical_meter_data(sender, range(0, 5))
311+
await asyncio.sleep(0)
312+
assert task.done()
313+
314+
task = asyncio.create_task(window.wait_for_samples(5))
315+
await asyncio.sleep(0)
316+
await push_logical_meter_data(
317+
sender, [1, 2, 3, 4], start_ts=UNIX_EPOCH + timedelta(seconds=5)
318+
)
319+
await asyncio.sleep(0)
320+
assert not task.done()
321+
322+
await push_logical_meter_data(
323+
sender, [1], start_ts=UNIX_EPOCH + timedelta(seconds=9)
324+
)
325+
await asyncio.sleep(0)
326+
assert task.done()
327+
328+
task = asyncio.create_task(window.wait_for_samples(-1))
329+
with pytest.raises(
330+
ValueError,
331+
match=re.escape(
332+
"The number of samples to wait for must be greater than 0."
333+
),
334+
):
335+
await task
336+
337+
task = asyncio.create_task(window.wait_for_samples(20))
338+
with pytest.raises(
339+
ValueError,
340+
match=re.escape(
341+
"The number of samples to wait for must be less than or equal to the "
342+
+ "capacity of the MovingWindow (10)."
343+
),
344+
):
345+
await task
346+
347+
task = asyncio.create_task(window.wait_for_samples(4))
348+
await asyncio.sleep(0)
349+
await push_logical_meter_data(
350+
sender, range(0, 10), start_ts=UNIX_EPOCH + timedelta(seconds=10)
351+
)
352+
await asyncio.sleep(0)
353+
assert task.done()
354+
355+
task = asyncio.create_task(window.wait_for_samples(10))
356+
await asyncio.sleep(0)
357+
await push_logical_meter_data(
358+
sender, range(0, 5), start_ts=UNIX_EPOCH + timedelta(seconds=20)
359+
)
360+
await asyncio.sleep(0)
361+
assert not task.done()
362+
363+
await push_logical_meter_data(
364+
sender, range(10, 15), start_ts=UNIX_EPOCH + timedelta(seconds=25)
365+
)
366+
await asyncio.sleep(0)
367+
assert task.done()
368+
369+
302370
# pylint: disable=redefined-outer-name
303371
async def test_resampling_window(fake_time: time_machine.Coordinates) -> None:
304372
"""Test resampling in MovingWindow."""

0 commit comments

Comments
 (0)