Skip to content

Commit c86f760

Browse files
authored
Add DataType equal and all_equal helpers (#289)
1 parent 49e799e commit c86f760

File tree

4 files changed

+90
-7
lines changed

4 files changed

+90
-7
lines changed

src/fastcs/datatypes/datatype.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import enum
22
from abc import abstractmethod
3+
from collections.abc import Sequence
34
from dataclasses import dataclass
45
from typing import Any, Generic, TypeVar
56

@@ -28,6 +29,11 @@ class DataType(Generic[DType_T]):
2829
def dtype(self) -> type[DType_T]: # Using property due to lack of Generic ClassVars
2930
raise NotImplementedError()
3031

32+
@property
33+
@abstractmethod
34+
def initial_value(self) -> DType_T:
35+
raise NotImplementedError()
36+
3137
def validate(self, value: Any) -> DType_T:
3238
"""Validate a value against the datatype.
3339
@@ -55,7 +61,32 @@ def validate(self, value: Any) -> DType_T:
5561
except (ValueError, TypeError) as e:
5662
raise ValueError(f"Failed to cast {value} to type {self.dtype}") from e
5763

58-
@property
59-
@abstractmethod
60-
def initial_value(self) -> DType_T:
61-
raise NotImplementedError()
64+
@staticmethod
65+
def equal(value1: DType_T, value2: DType_T) -> bool:
66+
"""Compare two values for equality
67+
68+
Child classes can override this if the underlying type does not implement
69+
``__eq__`` or to define custom logic.
70+
71+
Args:
72+
value1: The first value to compare
73+
value2: The second value to compare
74+
75+
Returns:
76+
`True` if the values are equal
77+
78+
"""
79+
return value1 == value2
80+
81+
@classmethod
82+
def all_equal(cls, values: Sequence[DType_T]) -> bool:
83+
"""Compare a sequence of values for equality
84+
85+
Args:
86+
values: Values to compare
87+
88+
Returns:
89+
`True` if all values are equal, else `False`
90+
91+
"""
92+
return all(cls.equal(values[0], value) for value in values[1:])

src/fastcs/datatypes/table.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,7 @@ def validate(self, value: Any) -> np.ndarray:
3030
)
3131

3232
return _value
33+
34+
@staticmethod
35+
def equal(value1: np.ndarray, value2: np.ndarray) -> bool:
36+
return np.array_equal(value1, value2)

src/fastcs/datatypes/waveform.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,7 @@ def validate(self, value: np.ndarray) -> np.ndarray:
3838
)
3939

4040
return _value
41+
42+
@staticmethod
43+
def equal(value1: np.ndarray, value2: np.ndarray) -> bool:
44+
return np.array_equal(value1, value2)

tests/test_datatypes.py

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,8 @@
33
import numpy as np
44
import pytest
55

6-
from fastcs.datatypes import DataType, Enum, Float, Int, Waveform
6+
from fastcs.datatypes import Bool, DataType, Enum, Float, Int, String, Table, Waveform
77
from fastcs.datatypes._util import numpy_to_fastcs_datatype
8-
from fastcs.datatypes.bool import Bool
9-
from fastcs.datatypes.string import String
108

119

1210
def test_base_validate():
@@ -61,3 +59,49 @@ def test_validate(datatype, init_args, value):
6159
)
6260
def test_numpy_to_fastcs_datatype(numpy_type, fastcs_datatype):
6361
assert fastcs_datatype == numpy_to_fastcs_datatype(numpy_type)
62+
63+
64+
@pytest.mark.parametrize(
65+
"fastcs_datatype, value1, value2, expected",
66+
[
67+
(Int(), 1, 1, True),
68+
(Int(), 1, 2, False),
69+
(Float(), 1.0, 1.0, True),
70+
(Float(), 1.0, 2.0, False),
71+
(Bool(), True, True, True),
72+
(Bool(), True, False, False),
73+
(String(), "foo", "foo", True),
74+
(String(), "foo", "bar", False),
75+
(Waveform(np.int16), np.array([1]), np.array([1]), True),
76+
(Waveform(np.int16), np.array([1]), np.array([2]), False),
77+
(
78+
Table([("int", np.int16), ("bool", np.bool), ("str", np.dtype("S10"))]),
79+
np.array([1, True, "foo"]),
80+
np.array([1, True, "foo"]),
81+
True,
82+
),
83+
(
84+
Table([("int", np.int16), ("bool", np.bool), ("str", np.dtype("S10"))]),
85+
np.array([1, True, "foo"]),
86+
np.array([2, False, "bar"]),
87+
False,
88+
),
89+
],
90+
)
91+
def test_dataset_equal(fastcs_datatype: DataType, value1, value2, expected):
92+
assert fastcs_datatype.equal(value1, value2) is expected
93+
94+
95+
@pytest.mark.parametrize(
96+
"fastcs_datatype, values, expected",
97+
[
98+
(Int(), [1, 1], True),
99+
(Int(), [1, 2], False),
100+
(Float(), [1.0, 1.0], True),
101+
(Float(), [1.0, 2.0], False),
102+
(Bool(), [True, True], True),
103+
(Bool(), [True, False], False),
104+
],
105+
)
106+
def test_dataset_all_equal(fastcs_datatype: DataType, values, expected):
107+
assert fastcs_datatype.all_equal(values) is expected

0 commit comments

Comments
 (0)