Skip to content

Commit 758924a

Browse files
authored
waveform: Implement load data (#15)
1 parent 67df337 commit 758924a

File tree

3 files changed

+243
-14
lines changed

3 files changed

+243
-14
lines changed

src/nitypes/waveform/_analog_waveform.py

Lines changed: 61 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313
from nitypes._arguments import arg_to_uint, validate_dtype, validate_unsupported_arg
1414
from nitypes._exceptions import invalid_arg_type, invalid_array_ndim
1515
from nitypes._typing import Self, TypeAlias
16+
from nitypes.waveform._exceptions import (
17+
input_array_data_type_mismatch,
18+
input_waveform_data_type_mismatch,
19+
)
1620
from nitypes.waveform._extended_properties import (
1721
CHANNEL_NAME,
1822
UNIT_DESCRIPTION,
@@ -800,16 +804,9 @@ def _append_array(
800804
timestamps: Sequence[dt.datetime] | Sequence[ht.datetime] | None = None,
801805
) -> None:
802806
if array.dtype != self.dtype:
803-
raise TypeError(
804-
"The data type of the input array must match the waveform data type.\n\n"
805-
f"Input array data type: {array.dtype}\n"
806-
f"Waveform data type: {self.dtype}"
807-
)
807+
raise input_array_data_type_mismatch(array.dtype, self.dtype)
808808
if array.ndim != 1:
809-
raise ValueError(
810-
"The input array must be a one-dimensional array.\n\n"
811-
f"Number of dimensions: {array.ndim}"
812-
)
809+
raise invalid_array_ndim("input array", "one-dimensional array", array.ndim)
813810
if timestamps is not None and len(array) != len(timestamps):
814811
raise ValueError(
815812
"The number of irregular timestamps must be equal to the input array length.\n\n"
@@ -832,11 +829,7 @@ def _append_waveform(self, waveform: AnalogWaveform[_ScalarType_co]) -> None:
832829
def _append_waveforms(self, waveforms: Sequence[AnalogWaveform[_ScalarType_co]]) -> None:
833830
for waveform in waveforms:
834831
if waveform.dtype != self.dtype:
835-
raise TypeError(
836-
"The data type of the input waveform must match the waveform data type.\n\n"
837-
f"Input waveform data type: {waveform.dtype}\n"
838-
f"Waveform data type: {self.dtype}"
839-
)
832+
raise input_waveform_data_type_mismatch(waveform.dtype, self.dtype)
840833
if waveform._scale_mode != self._scale_mode:
841834
warnings.warn(scale_mode_mismatch())
842835

@@ -859,6 +852,60 @@ def _increase_capacity(self, amount: int) -> None:
859852
if new_capacity > self.capacity:
860853
self.capacity = new_capacity
861854

855+
def load_data(
856+
self,
857+
array: npt.NDArray[_ScalarType_co],
858+
*,
859+
copy: bool = True,
860+
start_index: SupportsIndex | None = 0,
861+
sample_count: SupportsIndex | None = None,
862+
) -> None:
863+
"""Load new data into an existing waveform.
864+
865+
Args:
866+
array: A NumPy array containing the data to load.
867+
copy: Specifies whether to copy the array or save a reference to it.
868+
start_index: The sample index at which the analog waveform data begins.
869+
sample_count: The number of samples in the analog waveform.
870+
"""
871+
if isinstance(array, np.ndarray):
872+
self._load_array(array, copy=copy, start_index=start_index, sample_count=sample_count)
873+
else:
874+
raise invalid_arg_type("input array", "array", array)
875+
876+
def _load_array(
877+
self,
878+
array: npt.NDArray[_ScalarType_co],
879+
*,
880+
copy: bool = True,
881+
start_index: SupportsIndex | None = 0,
882+
sample_count: SupportsIndex | None = None,
883+
) -> None:
884+
if array.dtype != self.dtype:
885+
raise input_array_data_type_mismatch(array.dtype, self.dtype)
886+
if array.ndim != 1:
887+
raise invalid_array_ndim("input array", "one-dimensional array", array.ndim)
888+
if self._timing._timestamps is not None and len(array) != len(self._timing._timestamps):
889+
raise ValueError(
890+
"The input array length must be equal to the number of irregular timestamps.\n\n"
891+
f"Array length: {len(array)}\n"
892+
f"Number of timestamps: {len(self._timing._timestamps)}"
893+
)
894+
895+
start_index = arg_to_uint("start index", start_index, 0)
896+
sample_count = arg_to_uint("sample count", sample_count, len(array) - start_index)
897+
898+
if copy:
899+
if sample_count > len(self._data):
900+
self.capacity = sample_count
901+
self._data[0:sample_count] = array[start_index : start_index + sample_count]
902+
self._start_index = 0
903+
self._sample_count = sample_count
904+
else:
905+
self._data = array
906+
self._start_index = start_index
907+
self._sample_count = sample_count
908+
862909
def __eq__(self, value: object, /) -> bool:
863910
"""Return self==value."""
864911
if not isinstance(value, self.__class__):

src/nitypes/waveform/_exceptions.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,24 @@ class TimingMismatchError(RuntimeError):
77
pass
88

99

10+
def input_array_data_type_mismatch(input_dtype: object, waveform_dtype: object) -> TypeError:
11+
"""Create a TypeError for an input array data type mismatch."""
12+
return TypeError(
13+
"The data type of the input array must match the waveform data type.\n\n"
14+
f"Input array data type: {input_dtype}\n"
15+
f"Waveform data type: {waveform_dtype}"
16+
)
17+
18+
19+
def input_waveform_data_type_mismatch(input_dtype: object, waveform_dtype: object) -> TypeError:
20+
"""Create a TypeError for an input waveform data type mismatch."""
21+
return TypeError(
22+
"The data type of the input waveform must match the waveform data type.\n\n"
23+
f"Input waveform data type: {input_dtype}\n"
24+
f"Waveform data type: {waveform_dtype}"
25+
)
26+
27+
1028
def no_timestamp_information() -> RuntimeError:
1129
"""Create a RuntimeError for waveform timing with no timestamp information."""
1230
return RuntimeError(

tests/unit/waveform/test_analog_waveform.py

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1421,6 +1421,170 @@ def test___regular_waveform_and_irregular_waveform_list___append___raises_runtim
14211421
assert waveform.timing.sample_interval == dt.timedelta(milliseconds=1)
14221422

14231423

1424+
###############################################################################
1425+
# load data
1426+
###############################################################################
1427+
def test___empty_ndarray___load_data___clears_data() -> None:
1428+
waveform = AnalogWaveform.from_array_1d([0, 1, 2], np.int32)
1429+
array = np.array([], np.int32)
1430+
1431+
waveform.load_data(array)
1432+
1433+
assert list(waveform.raw_data) == []
1434+
1435+
1436+
def test___int32_ndarray___load_data___overwrites_data() -> None:
1437+
waveform = AnalogWaveform.from_array_1d([0, 1, 2], np.int32)
1438+
array = np.array([3, 4, 5], np.int32)
1439+
1440+
waveform.load_data(array)
1441+
1442+
assert list(waveform.raw_data) == [3, 4, 5]
1443+
1444+
1445+
def test___float64_ndarray___load_data___overwrites_data() -> None:
1446+
waveform = AnalogWaveform.from_array_1d([0, 1, 2], np.float64)
1447+
array = np.array([3, 4, 5], np.float64)
1448+
1449+
waveform.load_data(array)
1450+
1451+
assert list(waveform.raw_data) == [3, 4, 5]
1452+
1453+
1454+
def test___ndarray_with_mismatched_dtype___load_data___raises_type_error() -> None:
1455+
waveform = AnalogWaveform.from_array_1d([0, 1, 2], np.float64)
1456+
array = np.array([3, 4, 5], np.int32)
1457+
1458+
with pytest.raises(TypeError) as exc:
1459+
waveform.load_data(array) # type: ignore[arg-type]
1460+
1461+
assert exc.value.args[0].startswith(
1462+
"The data type of the input array must match the waveform data type."
1463+
)
1464+
1465+
1466+
def test___ndarray_2d___load_data___raises_value_error() -> None:
1467+
waveform = AnalogWaveform.from_array_1d([0, 1, 2], np.float64)
1468+
array = np.array([[3, 4, 5], [6, 7, 8]], np.float64)
1469+
1470+
with pytest.raises(ValueError) as exc:
1471+
waveform.load_data(array)
1472+
1473+
assert exc.value.args[0].startswith("The input array must be a one-dimensional array.")
1474+
1475+
1476+
def test___smaller_ndarray___load_data___preserves_capacity() -> None:
1477+
waveform = AnalogWaveform.from_array_1d([0, 1, 2], np.int32)
1478+
array = np.array([3], np.int32)
1479+
1480+
waveform.load_data(array)
1481+
1482+
assert list(waveform.raw_data) == [3]
1483+
assert waveform.capacity == 3
1484+
1485+
1486+
def test___larger_ndarray___load_data___grows_capacity() -> None:
1487+
waveform = AnalogWaveform.from_array_1d([0, 1, 2], np.int32)
1488+
array = np.array([3, 4, 5, 6], np.int32)
1489+
1490+
waveform.load_data(array)
1491+
1492+
assert list(waveform.raw_data) == [3, 4, 5, 6]
1493+
assert waveform.capacity == 4
1494+
1495+
1496+
def test___waveform_with_start_index___load_data___clears_start_index() -> None:
1497+
waveform = AnalogWaveform.from_array_1d(
1498+
np.array([0, 1, 2], np.int32), np.int32, copy=False, start_index=1, sample_count=1
1499+
)
1500+
assert waveform._start_index == 1
1501+
array = np.array([3], np.int32)
1502+
1503+
waveform.load_data(array)
1504+
1505+
assert list(waveform.raw_data) == [3]
1506+
assert waveform._start_index == 0
1507+
1508+
1509+
def test___ndarray_subset___load_data___overwrites_data() -> None:
1510+
waveform = AnalogWaveform.from_array_1d([0, 1, 2], np.int32)
1511+
array = np.array([3, 4, 5], np.int32)
1512+
1513+
waveform.load_data(array, start_index=1, sample_count=1)
1514+
1515+
assert list(waveform.raw_data) == [4]
1516+
assert waveform._start_index == 0
1517+
assert waveform.capacity == 3
1518+
1519+
1520+
def test___smaller_ndarray_no_copy___load_data___takes_ownership_of_array() -> None:
1521+
waveform = AnalogWaveform.from_array_1d([0, 1, 2], np.int32)
1522+
array = np.array([3], np.int32)
1523+
1524+
waveform.load_data(array, copy=False)
1525+
1526+
assert list(waveform.raw_data) == [3]
1527+
assert waveform._data is array
1528+
1529+
1530+
def test___larger_ndarray_no_copy___load_data___takes_ownership_of_array() -> None:
1531+
waveform = AnalogWaveform.from_array_1d([0, 1, 2], np.int32)
1532+
array = np.array([3, 4, 5, 6], np.int32)
1533+
1534+
waveform.load_data(array, copy=False)
1535+
1536+
assert list(waveform.raw_data) == [3, 4, 5, 6]
1537+
assert waveform._data is array
1538+
1539+
1540+
def test___ndarray_subset_no_copy___load_data___takes_ownership_of_array_subset() -> None:
1541+
waveform = AnalogWaveform.from_array_1d([0, 1, 2], np.int32)
1542+
array = np.array([3, 4, 5, 6], np.int32)
1543+
1544+
waveform.load_data(array, copy=False, start_index=1, sample_count=2)
1545+
1546+
assert list(waveform.raw_data) == [4, 5]
1547+
assert waveform._data is array
1548+
1549+
1550+
def test___irregular_waveform_and_int32_ndarray_with_timestamps___load_data___overwrites_data_but_not_timestamps() -> (
1551+
None
1552+
):
1553+
start_time = dt.datetime.now(dt.timezone.utc)
1554+
waveform_offsets = [dt.timedelta(0), dt.timedelta(1), dt.timedelta(2)]
1555+
waveform_timestamps = [start_time + offset for offset in waveform_offsets]
1556+
waveform = AnalogWaveform.from_array_1d([0, 1, 2], np.int32)
1557+
waveform.timing = Timing.create_with_irregular_interval(waveform_timestamps)
1558+
array = np.array([3, 4, 5], np.int32)
1559+
1560+
waveform.load_data(array)
1561+
1562+
assert list(waveform.raw_data) == [3, 4, 5]
1563+
assert waveform.timing.sample_interval_mode == SampleIntervalMode.IRREGULAR
1564+
assert waveform.timing._timestamps == waveform_timestamps
1565+
1566+
1567+
def test___irregular_waveform_and_int32_ndarray_with_wrong_sample_count___load_data___raises_value_error_and_does_not_overwrite_data() -> (
1568+
None
1569+
):
1570+
start_time = dt.datetime.now(dt.timezone.utc)
1571+
waveform_offsets = [dt.timedelta(0), dt.timedelta(1), dt.timedelta(2)]
1572+
waveform_timestamps = [start_time + offset for offset in waveform_offsets]
1573+
waveform = AnalogWaveform.from_array_1d([0, 1, 2], np.int32)
1574+
waveform.timing = Timing.create_with_irregular_interval(waveform_timestamps)
1575+
array = np.array([3, 4], np.int32)
1576+
1577+
with pytest.raises(ValueError) as exc:
1578+
waveform.load_data(array)
1579+
1580+
assert exc.value.args[0].startswith(
1581+
"The input array length must be equal to the number of irregular timestamps."
1582+
)
1583+
assert list(waveform.raw_data) == [0, 1, 2]
1584+
assert waveform.timing.sample_interval_mode == SampleIntervalMode.IRREGULAR
1585+
assert waveform.timing._timestamps == waveform_timestamps
1586+
1587+
14241588
###############################################################################
14251589
# magic methods
14261590
###############################################################################

0 commit comments

Comments
 (0)