Skip to content

Commit 70b9f87

Browse files
committed
waveform: Implement __eq__
1 parent 2d6664c commit 70b9f87

File tree

4 files changed

+204
-1
lines changed

4 files changed

+204
-1
lines changed

src/nitypes/waveform/_analog_waveform.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -758,3 +758,14 @@ def _increase_capacity(self, amount: int) -> None:
758758
new_capacity = self._start_index + self._sample_count + amount
759759
if new_capacity > self.capacity:
760760
self.capacity = new_capacity
761+
762+
def __eq__(self, other: object) -> bool: # noqa: D105 - Missing docstring in magic method
763+
if not isinstance(other, self.__class__):
764+
return NotImplemented
765+
return (
766+
self.dtype == other.dtype
767+
and np.array_equal(self.raw_data, other.raw_data)
768+
and self._extended_properties == other._extended_properties
769+
and self._base_timing == other._base_timing
770+
and self._scale_mode == other._scale_mode
771+
)

src/nitypes/waveform/_scaling/_linear.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@ def _transform_data(self, data: npt.NDArray[_ScalarType]) -> npt.NDArray[_Scalar
4444
# npt.NDArray[np.float32] with a float promotes dtype to Any or np.float64
4545
return data * self._gain + self._offset # type: ignore[operator,no-any-return]
4646

47+
def __eq__(self, other: object) -> bool: # noqa: D105 - Missing docstring in magic method
48+
if not isinstance(other, self.__class__):
49+
return NotImplemented
50+
return self._gain == other._gain and self._offset == other._offset
51+
4752
def __repr__( # noqa: D105 - Missing docstring in magic method (auto-generated noqa)
4853
self,
4954
) -> str:

src/nitypes/waveform/_scaling/_none.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@ class NoneScaleMode(ScaleMode):
1313
def _transform_data(self, data: npt.NDArray[_ScalarType]) -> npt.NDArray[_ScalarType]:
1414
return data
1515

16+
def __eq__(self, other: object) -> bool: # noqa: D105 - Missing docstring in magic method
17+
if not isinstance(other, self.__class__):
18+
return NotImplemented
19+
return True
20+
1621
def __repr__( # noqa: D105 - Missing docstring in magic method (auto-generated noqa)
1722
self,
1823
) -> str:

tests/unit/waveform/test_analog_waveform.py

Lines changed: 183 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import datetime as dt
55
import itertools
66
import weakref
7-
from typing import Any, SupportsIndex
7+
from typing import Any, SupportsIndex, TypeVar
88

99
import hightime as ht
1010
import numpy as np
@@ -15,6 +15,7 @@
1515
from nitypes.waveform import (
1616
NO_SCALING,
1717
AnalogWaveform,
18+
ExtendedPropertyValue,
1819
LinearScaleMode,
1920
NoneScaleMode,
2021
PrecisionTiming,
@@ -1378,3 +1379,184 @@ def test___regular_waveform_and_irregular_waveform_list___append___raises_runtim
13781379
assert list(waveform.raw_data) == [0, 1, 2]
13791380
assert waveform.timing.sample_interval_mode == SampleIntervalMode.REGULAR
13801381
assert waveform.timing.sample_interval == dt.timedelta(milliseconds=1)
1382+
1383+
1384+
###############################################################################
1385+
# magic methods
1386+
###############################################################################
1387+
_ScalarType = TypeVar("_ScalarType", bound=np.generic)
1388+
1389+
1390+
def _with_timing(
1391+
waveform: AnalogWaveform[_ScalarType], timing: Timing
1392+
) -> AnalogWaveform[_ScalarType]:
1393+
waveform.timing = timing
1394+
return waveform
1395+
1396+
1397+
def _with_precision_timing(
1398+
waveform: AnalogWaveform[_ScalarType], precision_timing: PrecisionTiming
1399+
) -> AnalogWaveform[_ScalarType]:
1400+
waveform.precision_timing = precision_timing
1401+
return waveform
1402+
1403+
1404+
def _with_extended_properties(
1405+
waveform: AnalogWaveform[_ScalarType], extended_properties: dict[str, ExtendedPropertyValue]
1406+
) -> AnalogWaveform[_ScalarType]:
1407+
waveform.extended_properties.update(extended_properties)
1408+
return waveform
1409+
1410+
1411+
def _with_scale_mode(
1412+
waveform: AnalogWaveform[_ScalarType], scale_mode: ScaleMode
1413+
) -> AnalogWaveform[_ScalarType]:
1414+
waveform.scale_mode = scale_mode
1415+
return waveform
1416+
1417+
1418+
@pytest.mark.parametrize(
1419+
"left, right",
1420+
[
1421+
(AnalogWaveform(), AnalogWaveform()),
1422+
(AnalogWaveform(10), AnalogWaveform(10)),
1423+
(AnalogWaveform(10, np.float64), AnalogWaveform(10, np.float64)),
1424+
(AnalogWaveform(10, np.int32), AnalogWaveform(10, np.int32)),
1425+
(
1426+
AnalogWaveform(10, np.int32, start_index=5, capacity=20),
1427+
AnalogWaveform(10, np.int32, start_index=5, capacity=20),
1428+
),
1429+
(
1430+
AnalogWaveform.from_array_1d([1, 2, 3], np.float64),
1431+
AnalogWaveform.from_array_1d([1, 2, 3], np.float64),
1432+
),
1433+
(
1434+
AnalogWaveform.from_array_1d([1, 2, 3], np.int32),
1435+
AnalogWaveform.from_array_1d([1, 2, 3], np.int32),
1436+
),
1437+
(
1438+
_with_timing(
1439+
AnalogWaveform(), Timing.create_with_regular_interval(dt.timedelta(milliseconds=1))
1440+
),
1441+
_with_timing(
1442+
AnalogWaveform(), Timing.create_with_regular_interval(dt.timedelta(milliseconds=1))
1443+
),
1444+
),
1445+
(
1446+
_with_precision_timing(
1447+
AnalogWaveform(),
1448+
PrecisionTiming.create_with_regular_interval(ht.timedelta(milliseconds=1)),
1449+
),
1450+
_with_precision_timing(
1451+
AnalogWaveform(),
1452+
PrecisionTiming.create_with_regular_interval(ht.timedelta(milliseconds=1)),
1453+
),
1454+
),
1455+
(
1456+
_with_extended_properties(
1457+
AnalogWaveform(), {"NI_ChannelName": "Dev1/ai0", "NI_UnitDescription": "Volts"}
1458+
),
1459+
_with_extended_properties(
1460+
AnalogWaveform(), {"NI_ChannelName": "Dev1/ai0", "NI_UnitDescription": "Volts"}
1461+
),
1462+
),
1463+
(
1464+
_with_scale_mode(AnalogWaveform(), LinearScaleMode(2.0, 1.0)),
1465+
_with_scale_mode(AnalogWaveform(), LinearScaleMode(2.0, 1.0)),
1466+
),
1467+
# start_index and capacity may differ as long as raw_data and sample_count are the same.
1468+
(
1469+
AnalogWaveform(10, np.int32, start_index=5, capacity=20),
1470+
AnalogWaveform(10, np.int32, start_index=10, capacity=25),
1471+
),
1472+
(
1473+
AnalogWaveform.from_array_1d(
1474+
[0, 0, 1, 2, 3, 4, 5, 0], np.int32, start_index=2, sample_count=5
1475+
),
1476+
AnalogWaveform.from_array_1d(
1477+
[0, 1, 2, 3, 4, 5, 0, 0, 0], np.int32, start_index=1, sample_count=5
1478+
),
1479+
),
1480+
],
1481+
)
1482+
def test___same_value___equality___equal(
1483+
left: AnalogWaveform[Any], right: AnalogWaveform[Any]
1484+
) -> None:
1485+
assert left == right
1486+
assert not (left != right)
1487+
1488+
1489+
@pytest.mark.parametrize(
1490+
"left, right",
1491+
[
1492+
(AnalogWaveform(), AnalogWaveform(10)),
1493+
(AnalogWaveform(10), AnalogWaveform(11)),
1494+
(AnalogWaveform(10, np.float64), AnalogWaveform(10, np.int32)),
1495+
(
1496+
AnalogWaveform(15, np.int32, start_index=5, capacity=20),
1497+
AnalogWaveform(10, np.int32, start_index=5, capacity=20),
1498+
),
1499+
(
1500+
AnalogWaveform.from_array_1d([1, 4, 3], np.float64),
1501+
AnalogWaveform.from_array_1d([1, 2, 3], np.float64),
1502+
),
1503+
(
1504+
AnalogWaveform.from_array_1d([1, 2, 3], np.int32),
1505+
AnalogWaveform.from_array_1d([1, 2, 3], np.float64),
1506+
),
1507+
(
1508+
_with_timing(
1509+
AnalogWaveform(), Timing.create_with_regular_interval(dt.timedelta(milliseconds=1))
1510+
),
1511+
_with_timing(
1512+
AnalogWaveform(), Timing.create_with_regular_interval(dt.timedelta(milliseconds=2))
1513+
),
1514+
),
1515+
(
1516+
_with_precision_timing(
1517+
AnalogWaveform(),
1518+
PrecisionTiming.create_with_regular_interval(ht.timedelta(milliseconds=1)),
1519+
),
1520+
_with_precision_timing(
1521+
AnalogWaveform(),
1522+
PrecisionTiming.create_with_regular_interval(ht.timedelta(milliseconds=2)),
1523+
),
1524+
),
1525+
(
1526+
_with_extended_properties(
1527+
AnalogWaveform(), {"NI_ChannelName": "Dev1/ai0", "NI_UnitDescription": "Volts"}
1528+
),
1529+
_with_extended_properties(
1530+
AnalogWaveform(), {"NI_ChannelName": "Dev1/ai0", "NI_UnitDescription": "Amps"}
1531+
),
1532+
),
1533+
(
1534+
_with_scale_mode(AnalogWaveform(), LinearScaleMode(2.0, 1.0)),
1535+
_with_scale_mode(AnalogWaveform(), LinearScaleMode(2.0, 1.1)),
1536+
),
1537+
# __eq__ does not convert timing, even if the values are equivalent.
1538+
(
1539+
_with_timing(
1540+
AnalogWaveform(), Timing.create_with_regular_interval(dt.timedelta(milliseconds=1))
1541+
),
1542+
_with_precision_timing(
1543+
AnalogWaveform(),
1544+
PrecisionTiming.create_with_regular_interval(ht.timedelta(milliseconds=1)),
1545+
),
1546+
),
1547+
(
1548+
_with_precision_timing(
1549+
AnalogWaveform(),
1550+
PrecisionTiming.create_with_regular_interval(ht.timedelta(milliseconds=1)),
1551+
),
1552+
_with_timing(
1553+
AnalogWaveform(), Timing.create_with_regular_interval(dt.timedelta(milliseconds=1))
1554+
),
1555+
),
1556+
],
1557+
)
1558+
def test___different_value___equality___not_equal(
1559+
left: AnalogWaveform[Any], right: AnalogWaveform[Any]
1560+
) -> None:
1561+
assert not (left == right)
1562+
assert left != right

0 commit comments

Comments
 (0)