Skip to content

Commit 20bd659

Browse files
committed
Add separate method for extracting wrapped buffers
This just moves the logic into a separate function to allow better testing of the logic. Bug found in the logic will be fixed in later commit. Signed-off-by: cwasicki <[email protected]>
1 parent f621807 commit 20bd659

File tree

2 files changed

+95
-13
lines changed

2 files changed

+95
-13
lines changed

src/frequenz/sdk/timeseries/_ringbuffer/buffer.py

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -295,22 +295,43 @@ def window(
295295
start_index = self.datetime_to_index(start)
296296
end_index = self.datetime_to_index(end)
297297

298+
return self._wrapped_buffer_window(
299+
self._buffer, start_index, end_index, force_copy
300+
)
301+
302+
@staticmethod
303+
def _wrapped_buffer_window(
304+
buffer: FloatArray, start_pos: int, end_pos: int, force_copy: bool = True
305+
) -> FloatArray:
306+
"""Get a wrapped window from the given buffer.
307+
308+
If start_pos == end_pos, the full wrapped buffer is returned starting at start_pos.
309+
310+
Copies can only be avoided for numpy arrays and when the window is not wrapped.
311+
Lists of floats are always copies.
312+
313+
Args:
314+
buffer: The buffer to get the window from.
315+
start_pos: The start position of the window in the buffer.
316+
end_pos: The end position of the window in the buffer (exclusive).
317+
force_copy: If True, will always create a copy of the data.
318+
319+
Returns:
320+
The requested window.
321+
"""
298322
# Requested window wraps around the ends
299-
if start_index >= end_index:
300-
if end_index > 0:
301-
if isinstance(self._buffer, list):
302-
return self._buffer[start_index:] + self._buffer[0:end_index]
303-
if isinstance(self._buffer, np.ndarray):
304-
return np.concatenate(
305-
(self._buffer[start_index:], self._buffer[0:end_index])
306-
)
307-
assert False, f"Unknown _buffer type: {type(self._buffer)}"
308-
return self._buffer[start_index:]
323+
if start_pos >= end_pos:
324+
if end_pos > 0:
325+
if isinstance(buffer, list):
326+
return buffer[start_pos:] + buffer[0:end_pos]
327+
if isinstance(buffer, np.ndarray):
328+
return np.concatenate((buffer[start_pos:], buffer[0:end_pos]))
329+
assert False, f"Unknown buffer type: {type(buffer)}"
330+
return buffer[start_pos:]
309331

310332
if force_copy:
311-
return deepcopy(self[start_index:end_index])
312-
313-
return self[start_index:end_index]
333+
return deepcopy(buffer[start_pos:end_pos])
334+
return buffer[start_pos:end_pos]
314335

315336
def is_missing(self, timestamp: datetime) -> bool:
316337
"""Check if the given timestamp falls within a gap.

tests/timeseries/test_ringbuffer.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -522,3 +522,64 @@ def test_window() -> None:
522522
assert [0, 1, 2] == list(win)
523523
win = buffer.window(dt(0), dt(3), force_copy=False)
524524
assert [0, 1, 2] == list(win)
525+
526+
527+
def test_wrapped_buffer_window() -> None:
528+
"""Test the wrapped buffer window function."""
529+
wbw = OrderedRingBuffer._wrapped_buffer_window # pylint: disable=protected-access
530+
531+
#
532+
# Tests for list buffer
533+
#
534+
buffer = [0.0, 1.0, 2.0, 3.0, 4.0]
535+
# start = end
536+
assert [0, 1, 2, 3, 4] == wbw(buffer, 0, 0, force_copy=False)
537+
assert [4, 0, 1, 2, 3] == wbw(buffer, 4, 4, force_copy=False)
538+
# start < end
539+
assert [0] == wbw(buffer, 0, 1, force_copy=False)
540+
assert [0, 1, 2, 3, 4] == wbw(buffer, 0, 5, force_copy=False)
541+
# start > end, end = 0
542+
assert [4] == wbw(buffer, 4, 0, force_copy=False)
543+
# start > end, end > 0
544+
assert [4, 0, 1] == wbw(buffer, 4, 2, force_copy=False)
545+
546+
# Lists are always shallow copies
547+
res_copy = wbw(buffer, 0, 5, force_copy=False)
548+
assert [0, 1, 2, 3, 4] == res_copy
549+
buffer[0] = 9
550+
assert [0, 1, 2, 3, 4] == res_copy
551+
552+
#
553+
# Tests for array buffer
554+
#
555+
buffer = np.array([0, 1, 2, 3, 4]) # type: ignore
556+
# start = end
557+
assert [0, 1, 2, 3, 4] == list(wbw(buffer, 0, 0, force_copy=False))
558+
assert [4, 0, 1, 2, 3] == list(wbw(buffer, 4, 4, force_copy=False))
559+
# start < end
560+
assert [0] == list(wbw(buffer, 0, 1, force_copy=False))
561+
assert [0, 1, 2, 3, 4] == list(wbw(buffer, 0, 5, force_copy=False))
562+
# start > end, end = 0
563+
assert [4] == list(wbw(buffer, 4, 0, force_copy=False))
564+
# start > end, end > 0
565+
assert [4, 0, 1] == list(wbw(buffer, 4, 2, force_copy=False))
566+
567+
# Get a view and a copy before modifying the buffer
568+
res1_view = wbw(buffer, 3, 5, force_copy=False)
569+
res1_copy = wbw(buffer, 3, 5, force_copy=True)
570+
res2_view = wbw(buffer, 3, 0, force_copy=False)
571+
res2_copy = wbw(buffer, 3, 0, force_copy=True)
572+
res3_copy = wbw(buffer, 4, 1, force_copy=False)
573+
assert [3, 4] == list(res1_view)
574+
assert [3, 4] == list(res1_copy)
575+
assert [3, 4] == list(res2_view)
576+
assert [3, 4] == list(res2_copy)
577+
assert [4, 0] == list(res3_copy)
578+
579+
# Modify the buffer and check that only the view is updated
580+
buffer[4] = 9
581+
assert [3, 9] == list(res1_view)
582+
assert [3, 4] == list(res1_copy)
583+
assert [3, 9] == list(res2_view)
584+
# assert [3, 4] == list(res2_copy) #Fails because of a bug
585+
assert [4, 0] == list(res3_copy)

0 commit comments

Comments
 (0)