Skip to content

Commit 9c12e6c

Browse files
committed
Support int indices in ring buffer window method
Using integer-based indices allows for selecting positions within the window without requiring knowledge of the specific timestamps. Signed-off-by: cwasicki <[email protected]>
1 parent 3cc23db commit 9c12e6c

File tree

3 files changed

+82
-20
lines changed

3 files changed

+82
-20
lines changed

src/frequenz/sdk/timeseries/_ringbuffer/buffer.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -280,8 +280,26 @@ def index_to_datetime(self, index: int) -> datetime | None:
280280
)
281281
return ref_ts + index * self._sampling_period
282282

283+
def _index_to_slice(
284+
self, start: int | None, end: int | None = None
285+
) -> tuple[int, int]:
286+
"""Project the given indices via slice onto the covered range.
287+
288+
Args:
289+
start: Start index.
290+
end: End index. Optional, defaults to None.
291+
292+
Returns:
293+
tuple of start and end indices on the range currently covered by the buffer.
294+
"""
295+
return slice(start, end).indices(self.count_covered)[:2]
296+
283297
def window(
284-
self, start: datetime, end: datetime, *, force_copy: bool = True
298+
self,
299+
start: datetime | int | None,
300+
end: datetime | int | None,
301+
*,
302+
force_copy: bool = True,
285303
) -> FloatArray:
286304
"""Request a copy or view on the data between start timestamp and end timestamp.
287305
@@ -306,17 +324,32 @@ def window(
306324
copy of the data.
307325
308326
Raises:
309-
IndexError: When requesting a window with invalid timestamps.
327+
IndexError: When start and end are not both datetime or index.
310328
311329
Returns:
312330
The requested window
313331
"""
314-
if start > end:
332+
if self.count_covered == 0:
333+
return np.array([]) if isinstance(self._buffer, np.ndarray) else []
334+
335+
# If both are indices or None convert to datetime
336+
if not isinstance(start, datetime) and not isinstance(end, datetime):
337+
start, end = self._index_to_slice(start, end)
338+
start = self.index_to_datetime(start)
339+
end = self.index_to_datetime(end)
340+
341+
# Here we should have both as datetime
342+
if not (isinstance(start, datetime) and isinstance(end, datetime)):
315343
raise IndexError(
316-
f"end parameter {end} has to predate start parameter {start}"
344+
f"start ({start}) and end ({end}) must both be either datetime or index."
317345
)
318346

319-
if start == end:
347+
# Ensure that the window is within the bounds of the buffer
348+
assert self.oldest_timestamp is not None and self.newest_timestamp is not None
349+
start = max(start, self.oldest_timestamp)
350+
end = min(end, self.newest_timestamp + self._sampling_period)
351+
352+
if start >= end:
320353
return np.array([]) if isinstance(self._buffer, np.ndarray) else []
321354

322355
start_pos = self.to_internal_index(start)

tests/timeseries/test_moving_window.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -121,13 +121,10 @@ async def test_access_window_by_ts_slice() -> None:
121121
assert np.array_equal(window[time_start:time_end], np.array([3.0, 4.0])) # type: ignore
122122
assert np.array_equal(window.window(dt(3), dt(5)), np.array([3.0, 4.0]))
123123
assert np.array_equal(window.window(dt(3), dt(3)), np.array([]))
124-
# Window only supports slicing with ascending indices within allowed range
125-
with pytest.raises(IndexError):
126-
window.window(dt(3), dt(1))
127-
with pytest.raises(IndexError):
128-
window.window(dt(3), dt(6))
129-
with pytest.raises(IndexError):
130-
window.window(dt(-1), dt(5))
124+
# Window also supports slicing with indices outside allowed range
125+
assert np.array_equal(window.window(dt(3), dt(1)), np.array([]))
126+
assert np.array_equal(window.window(dt(3), dt(6)), np.array([3, 4]))
127+
assert np.array_equal(window.window(dt(-1), dt(5)), np.array([0, 1, 2, 3, 4]))
131128

132129

133130
async def test_access_empty_window() -> None:

tests/timeseries/test_ringbuffer.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -137,12 +137,11 @@ def test_timestamp_ringbuffer_gaps(
137137
Sample(datetime.fromtimestamp(500 + size, tz=timezone.utc), Quantity(9999))
138138
)
139139

140-
# Expect exception for the same window
141-
with pytest.raises(IndexError):
142-
buffer.window(
143-
datetime.fromtimestamp(200, tz=timezone.utc),
144-
datetime.fromtimestamp(202, tz=timezone.utc),
145-
)
140+
# Allow still to request old (empty) window
141+
buffer.window(
142+
datetime.fromtimestamp(200, tz=timezone.utc),
143+
datetime.fromtimestamp(202, tz=timezone.utc),
144+
)
146145

147146
# Receive new window without exception
148147
buffer.window(
@@ -524,8 +523,8 @@ def get_orb(data: FloatArray) -> OrderedRingBuffer[FloatArray]:
524523
return buffer
525524

526525

527-
def test_window() -> None:
528-
"""Test the window function."""
526+
def test_window_datetime() -> None:
527+
"""Test the window function with datetime."""
529528
buffer = get_orb(np.array([0, None, 2, 3, 4]))
530529
win = buffer.window(dt(0), dt(3), force_copy=False)
531530
assert [0, np.nan, 2] == list(win)
@@ -543,6 +542,39 @@ def test_window() -> None:
543542
assert [] == buffer.window(dt(1), dt(1))
544543

545544

545+
def test_window_index() -> None:
546+
"""Test the window function with index."""
547+
buffer = get_orb([0.0, 1.0, 2.0, 3.0, 4.0])
548+
assert [0, 1, 2] == buffer.window(0, 3)
549+
assert [0, 1, 2, 3, 4] == buffer.window(0, 5)
550+
assert [0, 1, 2, 3, 4] == buffer.window(0, 99)
551+
assert [2, 3] == buffer.window(-3, -1)
552+
assert [2, 3, 4] == buffer.window(-3, 5)
553+
assert [0, 1, 2, 3] == buffer.window(-5, -1)
554+
assert [0, 1, 2, 3, 4] == buffer.window(-99, None)
555+
assert [0, 1, 2, 3, 4] == buffer.window(None, 99)
556+
# start >= end
557+
assert [] == buffer.window(0, 0)
558+
assert [] == buffer.window(-5, 0)
559+
assert [] == buffer.window(1, 0)
560+
assert [] == buffer.window(-1, -2)
561+
assert [] == buffer.window(-3, 0)
562+
563+
564+
def test_window_fail() -> None:
565+
"""Test the window function with invalid indices."""
566+
buffer = get_orb([0.0, 1.0, 2.0, 3.0, 4.0])
567+
# Go crazy with the indices
568+
with pytest.raises(IndexError):
569+
buffer.window(dt(1), 3)
570+
with pytest.raises(IndexError):
571+
buffer.window(1, dt(3))
572+
with pytest.raises(IndexError):
573+
buffer.window(None, dt(2))
574+
with pytest.raises(IndexError):
575+
buffer.window(dt(2), None)
576+
577+
546578
def test_wrapped_buffer_window() -> None:
547579
"""Test the wrapped buffer window function."""
548580
wbw = OrderedRingBuffer._wrapped_buffer_window # pylint: disable=protected-access

0 commit comments

Comments
 (0)