Skip to content

Commit a007051

Browse files
committed
2 parents 25365ef + e50b355 commit a007051

File tree

3 files changed

+49
-30
lines changed

3 files changed

+49
-30
lines changed

src/lgdo/types/array.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from __future__ import annotations
77

88
import logging
9-
from collections.abc import Iterator
9+
from collections.abc import Collection, Iterator
1010
from typing import Any
1111

1212
import awkward as ak
@@ -126,19 +126,27 @@ def trim_capacity(self) -> None:
126126
"Set capacity to be minimum needed to support Array size"
127127
self.reserve_capacity(np.prod(self.shape))
128128

129-
def resize(self, new_size: int, trim=False) -> None:
129+
def resize(self, new_size: int | Collection[int], trim=False) -> None:
130130
"""Set size of Array in rows. Only change capacity if it must be
131131
increased to accommodate new rows; in this case double capacity.
132-
If trim is True, capacity will be set to match size."""
132+
If trim is True, capacity will be set to match size. If new_size
133+
is an int, do not change size of inner dimensions.
133134
134-
self._size = new_size
135+
If new_size is a collection, internal memory will be re-allocated, so
136+
this should be done only rarely!"""
135137

136-
if trim and new_size != self.get_capacity:
137-
self.reserve_capacity(new_size)
138+
if isinstance(new_size, Collection):
139+
self._size = new_size[0]
140+
self._nda.resize(new_size)
141+
else:
142+
self._size = new_size
143+
144+
if trim and new_size != self.get_capacity:
145+
self.reserve_capacity(new_size)
138146

139-
# If capacity is not big enough, set to next power of 2 big enough
140-
if new_size > self.get_capacity():
141-
self.reserve_capacity(int(2 ** (np.ceil(np.log2(new_size)))))
147+
# If capacity is not big enough, set to next power of 2 big enough
148+
if new_size > self.get_capacity():
149+
self.reserve_capacity(int(2 ** (np.ceil(np.log2(new_size)))))
142150

143151
def append(self, value: np.ndarray) -> None:
144152
"Append value to end of array (with copy)"

src/lgdo/types/waveformtable.py

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -112,25 +112,22 @@ def __init__(
112112
if not isinstance(t0, Array):
113113
shape = (size,)
114114
t0_dtype = t0.dtype if hasattr(t0, "dtype") else np.float32
115-
nda = (
116-
t0 if isinstance(t0, np.ndarray) else np.full(shape, t0, dtype=t0_dtype)
117-
)
118-
if nda.shape != shape:
119-
nda.resize(shape, refcheck=True)
120-
t0 = Array(nda=nda)
115+
if isinstance(t0, np.ndarray):
116+
t0 = Array(nda=t0, shape=shape, dtype=t0_dtype)
117+
else:
118+
t0 = Array(fill_val=t0, shape=shape, dtype=t0_dtype)
121119

122120
if t0_units is not None:
123121
t0.attrs["units"] = f"{t0_units}"
124122

125123
if not isinstance(dt, Array):
126124
shape = (size,)
127125
dt_dtype = dt.dtype if hasattr(dt, "dtype") else np.float32
128-
nda = (
129-
dt if isinstance(dt, np.ndarray) else np.full(shape, dt, dtype=dt_dtype)
130-
)
131-
if nda.shape != shape:
132-
nda.resize(shape, refcheck=True)
133-
dt = Array(nda=nda)
126+
if isinstance(dt, np.ndarray):
127+
dt = Array(nda=dt, shape=shape, dtype=dt_dtype)
128+
else:
129+
dt = Array(fill_val=dt, shape=shape, dtype=dt_dtype)
130+
134131
if dt_units is not None:
135132
dt.attrs["units"] = f"{dt_units}"
136133

@@ -174,14 +171,15 @@ def __init__(
174171
if hasattr(values, "dtype")
175172
else np.dtype(np.float64)
176173
)
177-
nda = (
178-
values
179-
if isinstance(values, np.ndarray)
180-
else np.zeros(shape, dtype=dtype)
181-
)
182-
if nda.shape != shape:
183-
nda.resize(shape, refcheck=True)
184-
values = ArrayOfEqualSizedArrays(dims=(1, 1), nda=nda)
174+
if isinstance(values, np.ndarray):
175+
values = ArrayOfEqualSizedArrays(
176+
dims=(1, 1), nda=values, shape=shape, dtype=dtype
177+
)
178+
else:
179+
values = ArrayOfEqualSizedArrays(
180+
dims=(1, 1), fill_val=0, shape=shape, dtype=dtype
181+
)
182+
185183
if values_units is not None:
186184
values.attrs["units"] = f"{values_units}"
187185

@@ -215,7 +213,7 @@ def wf_len(self, wf_len) -> None:
215213
return
216214
shape = self.values.nda.shape
217215
shape = (shape[0], wf_len)
218-
self.values.nda.resize(shape, refcheck=True)
216+
self.values.resize(shape)
219217

220218
def resize_wf_len(self, new_len: int) -> None:
221219
"""Alias for `wf_len.setter`, for when we want to make it clear in

tests/types/test_waveformtable.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,15 @@ def test_init():
2525
assert (wft.values.nda == np.zeros(shape=(10, 1000))).all()
2626
assert wft.values.nda.dtype == np.float64
2727

28+
wft = WaveformTable(
29+
size=10, dt=np.zeros(5), t0=np.zeros(5), values=np.zeros((5, 50))
30+
)
31+
assert (wft.t0.nda == np.zeros(10)).all()
32+
assert (wft.dt.nda == np.zeros(10)).all()
33+
assert isinstance(wft.values, lgdo.ArrayOfEqualSizedArrays)
34+
assert (wft.values.nda == np.zeros(shape=(10, 50))).all()
35+
assert wft.values.nda.dtype == np.float64
36+
2837
wft = WaveformTable(
2938
values=lgdo.ArrayOfEqualSizedArrays(shape=(10, 1000), fill_val=69)
3039
)
@@ -85,3 +94,7 @@ def test_init():
8594

8695
wft = WaveformTable(t0=[1, 1, 1], dt=[2, 2, 2], wf_len=1000, dtype=np.float32)
8796
assert wft.values.nda.dtype == np.float32
97+
98+
wft = WaveformTable(10, wf_len=20)
99+
wft.wf_len = 30
100+
assert wft.wf_len == 30

0 commit comments

Comments
 (0)