Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions src/lgdo/types/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from __future__ import annotations

import logging
from collections.abc import Iterator
from collections.abc import Collection, Iterator
from typing import Any

import awkward as ak
Expand Down Expand Up @@ -126,19 +126,27 @@ def trim_capacity(self) -> None:
"Set capacity to be minimum needed to support Array size"
self.reserve_capacity(np.prod(self.shape))

def resize(self, new_size: int, trim=False) -> None:
def resize(self, new_size: int | Collection[int], trim=False) -> None:
"""Set size of Array in rows. Only change capacity if it must be
increased to accommodate new rows; in this case double capacity.
If trim is True, capacity will be set to match size."""
If trim is True, capacity will be set to match size. If new_size
is an int, do not change size of inner dimensions.

self._size = new_size
If new_size is a collection, internal memory will be re-allocated, so
this should be done only rarely!"""

if trim and new_size != self.get_capacity:
self.reserve_capacity(new_size)
if isinstance(new_size, Collection):
self._size = new_size[0]
self._nda.resize(new_size)
else:
self._size = new_size

if trim and new_size != self.get_capacity:
self.reserve_capacity(new_size)

# If capacity is not big enough, set to next power of 2 big enough
if new_size > self.get_capacity():
self.reserve_capacity(int(2 ** (np.ceil(np.log2(new_size)))))
# If capacity is not big enough, set to next power of 2 big enough
if new_size > self.get_capacity():
self.reserve_capacity(int(2 ** (np.ceil(np.log2(new_size)))))

def append(self, value: np.ndarray) -> None:
"Append value to end of array (with copy)"
Expand Down
40 changes: 19 additions & 21 deletions src/lgdo/types/waveformtable.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,25 +112,22 @@ def __init__(
if not isinstance(t0, Array):
shape = (size,)
t0_dtype = t0.dtype if hasattr(t0, "dtype") else np.float32
nda = (
t0 if isinstance(t0, np.ndarray) else np.full(shape, t0, dtype=t0_dtype)
)
if nda.shape != shape:
nda.resize(shape, refcheck=True)
t0 = Array(nda=nda)
if isinstance(t0, np.ndarray):
t0 = Array(nda=t0, shape=shape, dtype=t0_dtype)
else:
t0 = Array(fill_val=t0, shape=shape, dtype=t0_dtype)

if t0_units is not None:
t0.attrs["units"] = f"{t0_units}"

if not isinstance(dt, Array):
shape = (size,)
dt_dtype = dt.dtype if hasattr(dt, "dtype") else np.float32
nda = (
dt if isinstance(dt, np.ndarray) else np.full(shape, dt, dtype=dt_dtype)
)
if nda.shape != shape:
nda.resize(shape, refcheck=True)
dt = Array(nda=nda)
if isinstance(dt, np.ndarray):
dt = Array(nda=dt, shape=shape, dtype=dt_dtype)
else:
dt = Array(fill_val=dt, shape=shape, dtype=dt_dtype)

if dt_units is not None:
dt.attrs["units"] = f"{dt_units}"

Expand Down Expand Up @@ -174,14 +171,15 @@ def __init__(
if hasattr(values, "dtype")
else np.dtype(np.float64)
)
nda = (
values
if isinstance(values, np.ndarray)
else np.zeros(shape, dtype=dtype)
)
if nda.shape != shape:
nda.resize(shape, refcheck=True)
values = ArrayOfEqualSizedArrays(dims=(1, 1), nda=nda)
if isinstance(values, np.ndarray):
values = ArrayOfEqualSizedArrays(
dims=(1, 1), nda=values, shape=shape, dtype=dtype
)
else:
values = ArrayOfEqualSizedArrays(
dims=(1, 1), fill_val=0, shape=shape, dtype=dtype
)

if values_units is not None:
values.attrs["units"] = f"{values_units}"

Expand Down Expand Up @@ -215,7 +213,7 @@ def wf_len(self, wf_len) -> None:
return
shape = self.values.nda.shape
shape = (shape[0], wf_len)
self.values.nda.resize(shape, refcheck=True)
self.values.resize(shape)

def resize_wf_len(self, new_len: int) -> None:
"""Alias for `wf_len.setter`, for when we want to make it clear in
Expand Down
13 changes: 13 additions & 0 deletions tests/types/test_waveformtable.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@ def test_init():
assert (wft.values.nda == np.zeros(shape=(10, 1000))).all()
assert wft.values.nda.dtype == np.float64

wft = WaveformTable(
size=10, dt=np.zeros(5), t0=np.zeros(5), values=np.zeros((5, 50))
)
assert (wft.t0.nda == np.zeros(10)).all()
assert (wft.dt.nda == np.zeros(10)).all()
assert isinstance(wft.values, lgdo.ArrayOfEqualSizedArrays)
assert (wft.values.nda == np.zeros(shape=(10, 50))).all()
assert wft.values.nda.dtype == np.float64

wft = WaveformTable(
values=lgdo.ArrayOfEqualSizedArrays(shape=(10, 1000), fill_val=69)
)
Expand Down Expand Up @@ -85,3 +94,7 @@ def test_init():

wft = WaveformTable(t0=[1, 1, 1], dt=[2, 2, 2], wf_len=1000, dtype=np.float32)
assert wft.values.nda.dtype == np.float32

wft = WaveformTable(10, wf_len=20)
wft.wf_len = 30
assert wft.wf_len == 30