Skip to content

Commit 6a21692

Browse files
committed
Merge branch 'separate-array-equal' into add-extended-method
2 parents 0b90af2 + c4714a3 commit 6a21692

File tree

2 files changed

+24
-19
lines changed

2 files changed

+24
-19
lines changed

src/power_grid_model_ds/_core/fancypy.py

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
import numpy as np
1010

11+
from power_grid_model_ds._core.utils.misc import array_equal_with_nan
12+
1113
if TYPE_CHECKING:
1214
from power_grid_model_ds._core.model.arrays.base.array import FancyArray
1315

@@ -44,23 +46,5 @@ def sort(array: "FancyArray", axis=-1, kind=None, order=None) -> "FancyArray":
4446
def array_equal(array1: "FancyArray", array2: "FancyArray", equal_nan: bool = True) -> bool:
4547
"""Return True if two arrays are equal."""
4648
if equal_nan:
47-
return _array_equal_with_nan(array1, array2)
49+
return array_equal_with_nan(array1.data, array2.data)
4850
return np.array_equal(array1.data, array2.data)
49-
50-
51-
def _array_equal_with_nan(array1: "FancyArray", array2: "FancyArray") -> bool:
52-
# np.array_equal does not work with NaN values in structured arrays, so we need to compare column by column.
53-
# related issue: https://github.com/numpy/numpy/issues/21539
54-
55-
if array1.columns != array2.columns:
56-
return False
57-
58-
for column in array1.columns:
59-
column_dtype = array1.dtype[column]
60-
if np.issubdtype(column_dtype, np.str_):
61-
if not np.array_equal(array1[column], array2[column]):
62-
return False
63-
continue
64-
if not np.array_equal(array1[column], array2[column], equal_nan=True):
65-
return False
66-
return True

src/power_grid_model_ds/_core/utils/misc.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,24 @@ def get_inherited_attrs(cls: Type, *private_attributes):
3939
retrieved_attributes[private_attr] = attr_dict
4040

4141
return retrieved_attributes
42+
43+
44+
def array_equal_with_nan(array1: np.ndarray, array2: np.ndarray) -> bool:
45+
"""Compare two structured arrays for equality, treating NaN values as equal.
46+
47+
np.array_equal does not work with NaN values in structured arrays, so we need to compare column by column.
48+
related issue: https://github.com/numpy/numpy/issues/21539
49+
"""
50+
if array1.dtype.names != array2.dtype.names:
51+
return False
52+
53+
columns: Sequence[str] = array1.dtype.names
54+
for column in columns:
55+
column_dtype = array1.dtype[column]
56+
if np.issubdtype(column_dtype, np.str_):
57+
if not np.array_equal(array1[column], array2[column]):
58+
return False
59+
continue
60+
if not np.array_equal(array1[column], array2[column], equal_nan=True):
61+
return False
62+
return True

0 commit comments

Comments
 (0)