Skip to content

Commit bceccfc

Browse files
authored
Add a wait_for_samples method to the MovingWindow (frequenz-floss#1159)
Closes frequenz-floss#967
2 parents f208b10 + 32e27cb commit bceccfc

File tree

4 files changed

+432
-30
lines changed

4 files changed

+432
-30
lines changed

RELEASE_NOTES.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
## New Features
1212

13-
<!-- Here goes the main new features and examples or instructions on how to use them -->
13+
- The `MovingWindow` now has an async `wait_for_samples` method that waits for a given number of samples to become available in the moving window and then returns.
1414

1515
## Bug Fixes
1616

src/frequenz/sdk/timeseries/_moving_window.py

Lines changed: 74 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,8 @@ def __init__( # pylint: disable=too-many-arguments
190190
align_to=align_to,
191191
)
192192

193+
self._condition_new_sample = asyncio.Condition()
194+
193195
def start(self) -> None:
194196
"""Start the MovingWindow.
195197
@@ -318,6 +320,44 @@ def window(
318320
start, end, force_copy=force_copy, fill_value=fill_value
319321
)
320322

323+
async def wait_for_samples(self, n: int) -> None:
324+
"""Wait until the next `n` samples are available in the MovingWindow.
325+
326+
This function returns after `n` new samples are available in the MovingWindow,
327+
without considering whether the new samples are valid. The validity of the
328+
samples can be verified by calling the
329+
[`count_valid`][frequenz.sdk.timeseries.MovingWindow.count_valid] method.
330+
331+
Args:
332+
n: The number of samples to wait for.
333+
334+
Raises:
335+
ValueError: If `n` is less than or equal to 0 or greater than the capacity
336+
of the MovingWindow.
337+
"""
338+
if n == 0:
339+
return
340+
if n < 0:
341+
raise ValueError("The number of samples to wait for must be 0 or greater.")
342+
if n > self.capacity:
343+
raise ValueError(
344+
"The number of samples to wait for must be less than or equal to the "
345+
+ f"capacity of the MovingWindow ({self.capacity})."
346+
)
347+
start_timestamp = (
348+
# Start from the next expected timestamp.
349+
self.newest_timestamp + self.sampling_period
350+
if self.newest_timestamp is not None
351+
else None
352+
)
353+
while True:
354+
async with self._condition_new_sample:
355+
# Every time a new sample is received, this condition gets notified and
356+
# will wake up.
357+
_ = await self._condition_new_sample.wait()
358+
if self.count_covered(since=start_timestamp) >= n:
359+
return
360+
321361
async def _run_impl(self) -> None:
322362
"""Awaits samples from the receiver and updates the underlying ring buffer.
323363
@@ -331,6 +371,9 @@ async def _run_impl(self) -> None:
331371
await self._resampler_sender.send(sample)
332372
else:
333373
self._buffer.update(sample)
374+
async with self._condition_new_sample:
375+
# Wake up all coroutines waiting for new samples.
376+
self._condition_new_sample.notify_all()
334377

335378
except asyncio.CancelledError:
336379
_logger.info("MovingWindow task has been cancelled.")
@@ -343,8 +386,10 @@ def _configure_resampler(self) -> None:
343386
assert self._resampler is not None
344387

345388
async def sink_buffer(sample: Sample[Quantity]) -> None:
346-
if sample.value is not None:
347-
self._buffer.update(sample)
389+
self._buffer.update(sample)
390+
async with self._condition_new_sample:
391+
# Wake up all coroutines waiting for new samples.
392+
self._condition_new_sample.notify_all()
348393

349394
resampler_channel = Broadcast[Sample[Quantity]](name="average")
350395
self._resampler_sender = resampler_channel.new_sender()
@@ -355,23 +400,44 @@ async def sink_buffer(sample: Sample[Quantity]) -> None:
355400
asyncio.create_task(self._resampler.resample(), name="resample")
356401
)
357402

358-
def count_valid(self) -> int:
359-
"""
360-
Count the number of valid samples in this `MovingWindow`.
403+
def count_valid(
404+
self, *, since: datetime | None = None, until: datetime | None = None
405+
) -> int:
406+
"""Count the number of valid samples in this `MovingWindow`.
407+
408+
If `since` and `until` are provided, the count is limited to the samples between
409+
(and including) the given timestamps.
410+
411+
Args:
412+
since: The timestamp from which to start counting. If `None`, the oldest
413+
timestamp of the buffer is used.
414+
until: The timestamp until (and including) which to count. If `None`, the
415+
newest timestamp of the buffer is used.
361416
362417
Returns:
363418
The number of valid samples in this `MovingWindow`.
364419
"""
365-
return self._buffer.count_valid()
420+
return self._buffer.count_valid(since=since, until=until)
366421

367-
def count_covered(self) -> int:
422+
def count_covered(
423+
self, *, since: datetime | None = None, until: datetime | None = None
424+
) -> int:
368425
"""Count the number of samples that are covered by the oldest and newest valid samples.
369426
427+
If `since` and `until` are provided, the count is limited to the samples between
428+
(and including) the given timestamps.
429+
430+
Args:
431+
since: The timestamp from which to start counting. If `None`, the oldest
432+
timestamp of the buffer is used.
433+
until: The timestamp until (and including) which to count. If `None`, the
434+
newest timestamp of the buffer is used.
435+
370436
Returns:
371437
The count of samples between the oldest and newest (inclusive) valid samples
372438
or 0 if there are is no time range covered.
373439
"""
374-
return self._buffer.count_covered()
440+
return self._buffer.count_covered(since=since, until=until)
375441

376442
@overload
377443
def __getitem__(self, key: SupportsIndex) -> float:

src/frequenz/sdk/timeseries/_ringbuffer/buffer.py

Lines changed: 59 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -651,9 +651,20 @@ def __getitem__(self, index_or_slice: SupportsIndex | slice) -> float | FloatArr
651651
"""
652652
return self._buffer.__getitem__(index_or_slice)
653653

654-
def _covered_time_range(self) -> timedelta:
654+
def _covered_time_range(
655+
self, since: datetime | None = None, until: datetime | None = None
656+
) -> timedelta:
655657
"""Return the time range that is covered by the oldest and newest valid samples.
656658
659+
If `since` and `until` are provided, the time range is limited to the items
660+
between (and including) the given timestamps.
661+
662+
Args:
663+
since: The timestamp from which to start counting. If `None`, the oldest
664+
timestamp in the buffer is used.
665+
until: The timestamp until (and including) which to count. If `None`, the
666+
newest timestamp in the buffer is used.
667+
657668
Returns:
658669
The time range between the oldest and newest valid samples or 0 if
659670
there are is no time range covered.
@@ -664,45 +675,82 @@ def _covered_time_range(self) -> timedelta:
664675
assert (
665676
self.newest_timestamp is not None
666677
), "Newest timestamp cannot be None here."
667-
return self.newest_timestamp - self.oldest_timestamp + self._sampling_period
668678

669-
def count_covered(self) -> int:
679+
if since is None or since < self.oldest_timestamp:
680+
since = self.oldest_timestamp
681+
if until is None or until > self.newest_timestamp:
682+
until = self.newest_timestamp
683+
684+
if until < since:
685+
return timedelta(0)
686+
687+
return until - since + self._sampling_period
688+
689+
def count_covered(
690+
self, *, since: datetime | None = None, until: datetime | None = None
691+
) -> int:
670692
"""Count the number of samples that are covered by the oldest and newest valid samples.
671693
694+
If `since` and `until` are provided, the count is limited to the items between
695+
(and including) the given timestamps.
696+
697+
Args:
698+
since: The timestamp from which to start counting. If `None`, the oldest
699+
timestamp in the buffer is used.
700+
until: The timestamp until (and including) which to count. If `None`, the
701+
newest timestamp in the buffer is used.
702+
672703
Returns:
673704
The count of samples between the oldest and newest (inclusive) valid samples
674705
or 0 if there are is no time range covered.
675706
"""
676707
return int(
677-
self._covered_time_range().total_seconds()
708+
self._covered_time_range(since, until).total_seconds()
678709
// self._sampling_period.total_seconds()
679710
)
680711

681-
def count_valid(self) -> int:
682-
"""Count the number of valid items that this buffer currently holds.
712+
def count_valid(
713+
self, *, since: datetime | None = None, until: datetime | None = None
714+
) -> int:
715+
"""Count the number of valid items in this buffer.
716+
717+
If `since` and `until` are provided, the count is limited to the items between
718+
(and including) the given timestamps.
719+
720+
Args:
721+
since: The timestamp from which to start counting. If `None`, the oldest
722+
timestamp in the buffer is used.
723+
until: The timestamp until (and including) which to count. If `None`, the
724+
newest timestamp in the buffer is used.
683725
684726
Returns:
685727
The number of valid items in this buffer.
686728
"""
687-
if self._timestamp_newest == self._TIMESTAMP_MIN:
729+
if since is None or since < self._timestamp_oldest:
730+
since = self._timestamp_oldest
731+
if until is None or until > self._timestamp_newest:
732+
until = self._timestamp_newest
733+
734+
if until == self._TIMESTAMP_MIN or until < since:
688735
return 0
689736

690737
# Sum of all elements in the gap ranges
691738
sum_missing_entries = max(
692739
0,
693740
sum(
694741
(
695-
gap.end
742+
min(gap.end, until + self._sampling_period)
696743
# Don't look further back than oldest timestamp
697-
- max(gap.start, self._timestamp_oldest)
744+
- max(gap.start, since)
698745
)
699746
// self._sampling_period
700747
for gap in self._gaps
748+
if gap.start <= until and gap.end >= since
701749
),
702750
)
703751

704-
start_pos = self.to_internal_index(self._timestamp_oldest)
705-
end_pos = self.to_internal_index(self._timestamp_newest)
752+
start_pos = self.to_internal_index(since)
753+
end_pos = self.to_internal_index(until)
706754

707755
if end_pos < start_pos:
708756
return len(self._buffer) - start_pos + end_pos + 1 - sum_missing_entries

0 commit comments

Comments
 (0)