Skip to content

Commit 66fab5e

Browse files
committed
Add separate method for extracting wrapped buffers
This just moves the logic into a separate function to allow better testing of the logic. Bug found in the logic will be fixed in later commit.
1 parent ada4221 commit 66fab5e

File tree

2 files changed

+95
-13
lines changed

2 files changed

+95
-13
lines changed

src/frequenz/sdk/timeseries/_ringbuffer/buffer.py

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -248,22 +248,43 @@ def window(
248248
start_index = self.datetime_to_index(start)
249249
end_index = self.datetime_to_index(end)
250250

251+
return self._wrapped_buffer_window(
252+
self._buffer, start_index, end_index, force_copy
253+
)
254+
255+
@staticmethod
256+
def _wrapped_buffer_window(
257+
buffer: FloatArray, start_pos: int, end_pos: int, force_copy: bool = True
258+
) -> FloatArray:
259+
"""Get a wrapped window from the given buffer.
260+
261+
If start_pos == end_pos, the full wrapped buffer is returned starting at start_pos.
262+
263+
Copies can only be avoided for numpy arrays and when the window is not wrapped.
264+
Lists of floats are always copies.
265+
266+
Args:
267+
buffer: The buffer to get the window from.
268+
start_pos: The start position of the window in the buffer.
269+
end_pos: The end position of the window in the buffer (exclusive).
270+
force_copy: If True, will always create a copy of the data.
271+
272+
Returns:
273+
The requested window.
274+
"""
251275
# Requested window wraps around the ends
252-
if start_index >= end_index:
253-
if end_index > 0:
254-
if isinstance(self._buffer, list):
255-
return self._buffer[start_index:] + self._buffer[0:end_index]
256-
if isinstance(self._buffer, np.ndarray):
257-
return np.concatenate(
258-
(self._buffer[start_index:], self._buffer[0:end_index])
259-
)
260-
assert False, f"Unknown _buffer type: {type(self._buffer)}"
261-
return self._buffer[start_index:]
276+
if start_pos >= end_pos:
277+
if end_pos > 0:
278+
if isinstance(buffer, list):
279+
return buffer[start_pos:] + buffer[0:end_pos]
280+
if isinstance(buffer, np.ndarray):
281+
return np.concatenate((buffer[start_pos:], buffer[0:end_pos]))
282+
assert False, f"Unknown buffer type: {type(buffer)}"
283+
return buffer[start_pos:]
262284

263285
if force_copy:
264-
return deepcopy(self[start_index:end_index])
265-
266-
return self[start_index:end_index]
286+
return deepcopy(buffer[start_pos:end_pos])
287+
return buffer[start_pos:end_pos]
267288

268289
def is_missing(self, timestamp: datetime) -> bool:
269290
"""Check if the given timestamp falls within a gap.

tests/timeseries/test_ringbuffer.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,3 +452,64 @@ def test_window() -> None:
452452
assert [0, 1, 2] == list(win)
453453
win = buffer.window(dt(0), dt(3), force_copy=False)
454454
assert [0, 1, 2] == list(win)
455+
456+
457+
def test_wrapped_buffer_window() -> None:
458+
"""Test the wrapped buffer window function."""
459+
wbw = OrderedRingBuffer._wrapped_buffer_window # pylint: disable=protected-access
460+
461+
#
462+
# Tests for list buffer
463+
#
464+
buffer = [0.0, 1.0, 2.0, 3.0, 4.0]
465+
# start = end
466+
assert [0, 1, 2, 3, 4] == wbw(buffer, 0, 0, force_copy=False)
467+
assert [4, 0, 1, 2, 3] == wbw(buffer, 4, 4, force_copy=False)
468+
# start < end
469+
assert [0] == wbw(buffer, 0, 1, force_copy=False)
470+
assert [0, 1, 2, 3, 4] == wbw(buffer, 0, 5, force_copy=False)
471+
# start > end, end = 0
472+
assert [4] == wbw(buffer, 4, 0, force_copy=False)
473+
# start > end, end > 0
474+
assert [4, 0, 1] == wbw(buffer, 4, 2, force_copy=False)
475+
476+
# Lists are always shallow copies
477+
res_copy = wbw(buffer, 0, 5, force_copy=False)
478+
assert [0, 1, 2, 3, 4] == res_copy
479+
buffer[0] = 9
480+
assert [0, 1, 2, 3, 4] == res_copy
481+
482+
#
483+
# Tests for array buffer
484+
#
485+
buffer = np.array([0, 1, 2, 3, 4]) # type: ignore
486+
# start = end
487+
assert [0, 1, 2, 3, 4] == list(wbw(buffer, 0, 0, force_copy=False))
488+
assert [4, 0, 1, 2, 3] == list(wbw(buffer, 4, 4, force_copy=False))
489+
# start < end
490+
assert [0] == list(wbw(buffer, 0, 1, force_copy=False))
491+
assert [0, 1, 2, 3, 4] == list(wbw(buffer, 0, 5, force_copy=False))
492+
# start > end, end = 0
493+
assert [4] == list(wbw(buffer, 4, 0, force_copy=False))
494+
# start > end, end > 0
495+
assert [4, 0, 1] == list(wbw(buffer, 4, 2, force_copy=False))
496+
497+
# Get a view and a copy before modifying the buffer
498+
res1_view = wbw(buffer, 3, 5, force_copy=False)
499+
res1_copy = wbw(buffer, 3, 5, force_copy=True)
500+
res2_view = wbw(buffer, 3, 0, force_copy=False)
501+
res2_copy = wbw(buffer, 3, 0, force_copy=True)
502+
res3_copy = wbw(buffer, 4, 1, force_copy=False)
503+
assert [3, 4] == list(res1_view)
504+
assert [3, 4] == list(res1_copy)
505+
assert [3, 4] == list(res2_view)
506+
assert [3, 4] == list(res2_copy)
507+
assert [4, 0] == list(res3_copy)
508+
509+
# Modify the buffer and check that only the view is updated
510+
buffer[4] = 9
511+
assert [3, 9] == list(res1_view)
512+
assert [3, 4] == list(res1_copy)
513+
assert [3, 9] == list(res2_view)
514+
# assert [3, 4] == list(res2_copy) #Fails because of a bug
515+
assert [4, 0] == list(res3_copy)

0 commit comments

Comments
 (0)