Skip to content

Commit 85dda90

Browse files
committed
[MovingWindow] Add a wait_for_samples method
Signed-off-by: Sahas Subramanian <[email protected]>
1 parent a78106a commit 85dda90

File tree

2 files changed

+198
-0
lines changed

2 files changed

+198
-0
lines changed

src/frequenz/sdk/timeseries/_moving_window.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,8 @@ def __init__( # pylint: disable=too-many-arguments
190190
align_to=align_to,
191191
)
192192

193+
self._condition_new_sample = asyncio.Condition()
194+
193195
def start(self) -> None:
194196
"""Start the MovingWindow.
195197
@@ -318,6 +320,44 @@ def window(
318320
start, end, force_copy=force_copy, fill_value=fill_value
319321
)
320322

323+
async def wait_for_samples(self, n: int) -> None:
324+
"""Wait until the next `n` samples are available in the MovingWindow.
325+
326+
This function returns after `n` new samples are available in the MovingWindow,
327+
without considering whether the new samples are valid. The validity of the
328+
samples can be verified by calling the
329+
[`count_valid`][frequenz.sdk.timeseries.MovingWindow.count_valid] method.
330+
331+
Args:
332+
n: The number of samples to wait for.
333+
334+
Raises:
335+
ValueError: If `n` is less than or equal to 0 or greater than the capacity
336+
of the MovingWindow.
337+
"""
338+
if n == 0:
339+
return
340+
if n < 0:
341+
raise ValueError("The number of samples to wait for must be 0 or greater.")
342+
if n > self.capacity:
343+
raise ValueError(
344+
"The number of samples to wait for must be less than or equal to the "
345+
+ f"capacity of the MovingWindow ({self.capacity})."
346+
)
347+
start_timestamp = (
348+
# Start from the next expected timestamp.
349+
self.newest_timestamp + self.sampling_period
350+
if self.newest_timestamp is not None
351+
else None
352+
)
353+
while True:
354+
async with self._condition_new_sample:
355+
# Every time a new sample is received, this condition gets notified and
356+
# will wake up.
357+
_ = await self._condition_new_sample.wait()
358+
if self.count_covered(since=start_timestamp) >= n:
359+
return
360+
321361
async def _run_impl(self) -> None:
322362
"""Awaits samples from the receiver and updates the underlying ring buffer.
323363
@@ -331,6 +371,9 @@ async def _run_impl(self) -> None:
331371
await self._resampler_sender.send(sample)
332372
else:
333373
self._buffer.update(sample)
374+
async with self._condition_new_sample:
375+
# Wake up all coroutines waiting for new samples.
376+
self._condition_new_sample.notify_all()
334377

335378
except asyncio.CancelledError:
336379
_logger.info("MovingWindow task has been cancelled.")
@@ -344,6 +387,9 @@ def _configure_resampler(self) -> None:
344387

345388
async def sink_buffer(sample: Sample[Quantity]) -> None:
346389
self._buffer.update(sample)
390+
async with self._condition_new_sample:
391+
# Wake up all coroutines waiting for new samples.
392+
self._condition_new_sample.notify_all()
347393

348394
resampler_channel = Broadcast[Sample[Quantity]](name="average")
349395
self._resampler_sender = resampler_channel.new_sender()

tests/timeseries/test_moving_window.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""Tests for the moving window."""
55

66
import asyncio
7+
import re
78
from collections.abc import Sequence
89
from datetime import datetime, timedelta, timezone
910

@@ -29,6 +30,7 @@ async def push_logical_meter_data(
2930
sender: Sender[Sample[Quantity]],
3031
test_seq: Sequence[float | None],
3132
start_ts: datetime = UNIX_EPOCH,
33+
fake_time: time_machine.Coordinates | None = None,
3234
) -> None:
3335
"""Push data in the passed sender to mock `LogicalMeter` behaviour.
3436
@@ -38,23 +40,29 @@ async def push_logical_meter_data(
3840
sender: Sender for pushing resampled samples to the `MovingWindow`.
3941
test_seq: The Sequence that is pushed into the `MovingWindow`.
4042
start_ts: The start timestamp of the `MovingWindow`.
43+
fake_time: The fake time object to shift the time.
4144
"""
4245
for i, j in zip(test_seq, range(0, len(test_seq))):
4346
timestamp = start_ts + timedelta(seconds=j)
4447
await sender.send(
4548
Sample(timestamp, Quantity(float(i)) if i is not None else None)
4649
)
50+
if fake_time is not None:
51+
await asyncio.sleep(1.0)
52+
fake_time.shift(1)
4753

4854
await asyncio.sleep(0.0)
4955

5056

5157
def init_moving_window(
5258
size: timedelta,
59+
resampler_config: ResamplerConfig | None = None,
5360
) -> tuple[MovingWindow, Sender[Sample[Quantity]]]:
5461
"""Initialize the moving window with given shape.
5562
5663
Args:
5764
size: The size of the `MovingWindow`
65+
resampler_config: The resampler configuration.
5866
5967
Returns:
6068
tuple[MovingWindow, Sender[Sample]]: A pair of sender and `MovingWindow`.
@@ -65,6 +73,7 @@ def init_moving_window(
6573
size=size,
6674
resampled_data_recv=lm_chan.new_receiver(),
6775
input_sampling_period=timedelta(seconds=1),
76+
resampler_config=resampler_config,
6877
)
6978
return window, lm_tx
7079

@@ -363,6 +372,149 @@ def assert_valid_and_covered_counts(
363372
)
364373

365374

375+
async def test_wait_for_samples() -> None:
376+
"""Test waiting for samples in the window."""
377+
window, sender = init_moving_window(timedelta(seconds=10))
378+
async with window:
379+
task = asyncio.create_task(window.wait_for_samples(5))
380+
await asyncio.sleep(0)
381+
assert not task.done()
382+
await push_logical_meter_data(sender, range(0, 5))
383+
await asyncio.sleep(0)
384+
# After pushing 5 values, the `wait_for_samples` task should be done.
385+
assert task.done()
386+
387+
task = asyncio.create_task(window.wait_for_samples(5))
388+
await asyncio.sleep(0)
389+
await push_logical_meter_data(
390+
sender, [1, 2, 3, 4], start_ts=UNIX_EPOCH + timedelta(seconds=5)
391+
)
392+
await asyncio.sleep(0)
393+
# The task should not be done yet, since we have only pushed 4 values.
394+
assert not task.done()
395+
396+
await push_logical_meter_data(
397+
sender, [1], start_ts=UNIX_EPOCH + timedelta(seconds=9)
398+
)
399+
await asyncio.sleep(0)
400+
# After pushing the last value, the task should be done.
401+
assert task.done()
402+
403+
task = asyncio.create_task(window.wait_for_samples(-1))
404+
with pytest.raises(
405+
ValueError,
406+
match=re.escape("The number of samples to wait for must be 0 or greater."),
407+
):
408+
await task
409+
410+
task = asyncio.create_task(window.wait_for_samples(20))
411+
with pytest.raises(
412+
ValueError,
413+
match=re.escape(
414+
"The number of samples to wait for must be less than or equal to the "
415+
+ "capacity of the MovingWindow (10)."
416+
),
417+
):
418+
await task
419+
420+
task = asyncio.create_task(window.wait_for_samples(4))
421+
await asyncio.sleep(0)
422+
await push_logical_meter_data(
423+
sender, range(0, 10), start_ts=UNIX_EPOCH + timedelta(seconds=10)
424+
)
425+
await asyncio.sleep(0)
426+
assert task.done()
427+
428+
task = asyncio.create_task(window.wait_for_samples(10))
429+
await asyncio.sleep(0)
430+
await push_logical_meter_data(
431+
sender, range(0, 5), start_ts=UNIX_EPOCH + timedelta(seconds=20)
432+
)
433+
await asyncio.sleep(0)
434+
assert not task.done()
435+
436+
await push_logical_meter_data(
437+
sender, range(10, 15), start_ts=UNIX_EPOCH + timedelta(seconds=25)
438+
)
439+
await asyncio.sleep(0)
440+
assert task.done()
441+
442+
task = asyncio.create_task(window.wait_for_samples(5))
443+
await asyncio.sleep(0)
444+
await push_logical_meter_data(
445+
sender, [1, 2, None, 4, None], start_ts=UNIX_EPOCH + timedelta(seconds=30)
446+
)
447+
await asyncio.sleep(0)
448+
# `None` values *are* counted towards the number of samples to wait for.
449+
assert task.done()
450+
451+
452+
async def test_wait_for_samples_with_resampling(
453+
fake_time: time_machine.Coordinates,
454+
) -> None:
455+
"""Test waiting for samples in a moving window with resampling."""
456+
window, sender = init_moving_window(
457+
timedelta(seconds=20), ResamplerConfig(resampling_period=timedelta(seconds=2))
458+
)
459+
async with window:
460+
task = asyncio.create_task(window.wait_for_samples(3))
461+
await asyncio.sleep(0)
462+
assert not task.done()
463+
await push_logical_meter_data(sender, range(0, 7), fake_time=fake_time)
464+
assert task.done()
465+
466+
task = asyncio.create_task(window.wait_for_samples(10))
467+
await push_logical_meter_data(
468+
sender,
469+
range(0, 11),
470+
fake_time=fake_time,
471+
start_ts=UNIX_EPOCH + timedelta(seconds=7),
472+
)
473+
assert window.count_covered() == 8
474+
assert not task.done()
475+
476+
await push_logical_meter_data(
477+
sender,
478+
range(0, 5),
479+
fake_time=fake_time,
480+
start_ts=UNIX_EPOCH + timedelta(seconds=18),
481+
)
482+
assert window.count_covered() == 10
483+
assert not task.done()
484+
485+
await push_logical_meter_data(
486+
sender,
487+
range(0, 6),
488+
fake_time=fake_time,
489+
start_ts=UNIX_EPOCH + timedelta(seconds=23),
490+
)
491+
assert window.count_covered() == 10
492+
assert window.count_valid() == 10
493+
assert task.done()
494+
495+
task = asyncio.create_task(window.wait_for_samples(5))
496+
await push_logical_meter_data(
497+
sender,
498+
[1, 2, None, None, None, None, None, None, None, None],
499+
fake_time=fake_time,
500+
start_ts=UNIX_EPOCH + timedelta(seconds=29),
501+
)
502+
assert window.count_covered() == 10
503+
assert window.count_valid() == 8
504+
assert task.done()
505+
506+
task = asyncio.create_task(window.wait_for_samples(5))
507+
await push_logical_meter_data(
508+
sender,
509+
[None, 4, None, None, None, None, None, None, None, 5],
510+
fake_time=fake_time,
511+
start_ts=UNIX_EPOCH + timedelta(seconds=39),
512+
)
513+
assert window.count_covered() == 10
514+
assert window.count_valid() == 7
515+
assert task.done()
516+
517+
366518
# pylint: disable=redefined-outer-name
367519
async def test_resampling_window(fake_time: time_machine.Coordinates) -> None:
368520
"""Test resampling in MovingWindow."""

0 commit comments

Comments
 (0)