diff --git a/src/lgdo/types/array.py b/src/lgdo/types/array.py index 2dc3627b..cdf1436d 100644 --- a/src/lgdo/types/array.py +++ b/src/lgdo/types/array.py @@ -6,7 +6,7 @@ from __future__ import annotations import logging -from collections.abc import Iterator +from collections.abc import Collection, Iterator from typing import Any import awkward as ak @@ -126,19 +126,27 @@ def trim_capacity(self) -> None: "Set capacity to be minimum needed to support Array size" self.reserve_capacity(np.prod(self.shape)) - def resize(self, new_size: int, trim=False) -> None: + def resize(self, new_size: int | Collection[int], trim=False) -> None: """Set size of Array in rows. Only change capacity if it must be increased to accommodate new rows; in this case double capacity. - If trim is True, capacity will be set to match size.""" + If trim is True, capacity will be set to match size. If new_size + is an int, do not change size of inner dimensions. - self._size = new_size + If new_size is a collection, internal memory will be re-allocated, so + this should be done only rarely!""" - if trim and new_size != self.get_capacity: - self.reserve_capacity(new_size) + if isinstance(new_size, Collection): + self._size = new_size[0] + self._nda.resize(new_size) + else: + self._size = new_size + + if trim and new_size != self.get_capacity: + self.reserve_capacity(new_size) - # If capacity is not big enough, set to next power of 2 big enough - if new_size > self.get_capacity(): - self.reserve_capacity(int(2 ** (np.ceil(np.log2(new_size))))) + # If capacity is not big enough, set to next power of 2 big enough + if new_size > self.get_capacity(): + self.reserve_capacity(int(2 ** (np.ceil(np.log2(new_size))))) def append(self, value: np.ndarray) -> None: "Append value to end of array (with copy)" diff --git a/src/lgdo/types/waveformtable.py b/src/lgdo/types/waveformtable.py index 4fffeeb4..ffc3b5ca 100644 --- a/src/lgdo/types/waveformtable.py +++ b/src/lgdo/types/waveformtable.py @@ -112,12 +112,10 @@ def __init__( if not isinstance(t0, Array): shape = (size,) t0_dtype = t0.dtype if hasattr(t0, "dtype") else np.float32 - nda = ( - t0 if isinstance(t0, np.ndarray) else np.full(shape, t0, dtype=t0_dtype) - ) - if nda.shape != shape: - nda.resize(shape, refcheck=True) - t0 = Array(nda=nda) + if isinstance(t0, np.ndarray): + t0 = Array(nda=t0, shape=shape, dtype=t0_dtype) + else: + t0 = Array(fill_val=t0, shape=shape, dtype=t0_dtype) if t0_units is not None: t0.attrs["units"] = f"{t0_units}" @@ -125,12 +123,11 @@ def __init__( if not isinstance(dt, Array): shape = (size,) dt_dtype = dt.dtype if hasattr(dt, "dtype") else np.float32 - nda = ( - dt if isinstance(dt, np.ndarray) else np.full(shape, dt, dtype=dt_dtype) - ) - if nda.shape != shape: - nda.resize(shape, refcheck=True) - dt = Array(nda=nda) + if isinstance(dt, np.ndarray): + dt = Array(nda=dt, shape=shape, dtype=dt_dtype) + else: + dt = Array(fill_val=dt, shape=shape, dtype=dt_dtype) + if dt_units is not None: dt.attrs["units"] = f"{dt_units}" @@ -174,14 +171,15 @@ def __init__( if hasattr(values, "dtype") else np.dtype(np.float64) ) - nda = ( - values - if isinstance(values, np.ndarray) - else np.zeros(shape, dtype=dtype) - ) - if nda.shape != shape: - nda.resize(shape, refcheck=True) - values = ArrayOfEqualSizedArrays(dims=(1, 1), nda=nda) + if isinstance(values, np.ndarray): + values = ArrayOfEqualSizedArrays( + dims=(1, 1), nda=values, shape=shape, dtype=dtype + ) + else: + values = ArrayOfEqualSizedArrays( + dims=(1, 1), fill_val=0, shape=shape, dtype=dtype + ) + if values_units is not None: values.attrs["units"] = f"{values_units}" @@ -215,7 +213,7 @@ def wf_len(self, wf_len) -> None: return shape = self.values.nda.shape shape = (shape[0], wf_len) - self.values.nda.resize(shape, refcheck=True) + self.values.resize(shape) def resize_wf_len(self, new_len: int) -> None: """Alias for `wf_len.setter`, for when we want to make it clear in diff --git a/tests/types/test_waveformtable.py b/tests/types/test_waveformtable.py index f81e40ce..a9df1566 100644 --- a/tests/types/test_waveformtable.py +++ b/tests/types/test_waveformtable.py @@ -25,6 +25,15 @@ def test_init(): assert (wft.values.nda == np.zeros(shape=(10, 1000))).all() assert wft.values.nda.dtype == np.float64 + wft = WaveformTable( + size=10, dt=np.zeros(5), t0=np.zeros(5), values=np.zeros((5, 50)) + ) + assert (wft.t0.nda == np.zeros(10)).all() + assert (wft.dt.nda == np.zeros(10)).all() + assert isinstance(wft.values, lgdo.ArrayOfEqualSizedArrays) + assert (wft.values.nda == np.zeros(shape=(10, 50))).all() + assert wft.values.nda.dtype == np.float64 + wft = WaveformTable( values=lgdo.ArrayOfEqualSizedArrays(shape=(10, 1000), fill_val=69) ) @@ -85,3 +94,7 @@ def test_init(): wft = WaveformTable(t0=[1, 1, 1], dt=[2, 2, 2], wf_len=1000, dtype=np.float32) assert wft.values.nda.dtype == np.float32 + + wft = WaveformTable(10, wf_len=20) + wft.wf_len = 30 + assert wft.wf_len == 30