From c4714a3e8e5bceb2bb4eecf59acefe4abafb6c9a Mon Sep 17 00:00:00 2001 From: Thijs Baaijen <13253091+Thijss@users.noreply.github.com> Date: Thu, 17 Jul 2025 10:07:59 +0200 Subject: [PATCH] Separate array_equal_with_nan Signed-off-by: Thijs Baaijen <13253091+Thijss@users.noreply.github.com> --- src/power_grid_model_ds/_core/fancypy.py | 22 +++------------------ src/power_grid_model_ds/_core/utils/misc.py | 21 ++++++++++++++++++++ 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/src/power_grid_model_ds/_core/fancypy.py b/src/power_grid_model_ds/_core/fancypy.py index a037e4f..ef4dc05 100644 --- a/src/power_grid_model_ds/_core/fancypy.py +++ b/src/power_grid_model_ds/_core/fancypy.py @@ -8,6 +8,8 @@ import numpy as np +from power_grid_model_ds._core.utils.misc import array_equal_with_nan + if TYPE_CHECKING: from power_grid_model_ds._core.model.arrays.base.array import FancyArray @@ -44,23 +46,5 @@ def sort(array: "FancyArray", axis=-1, kind=None, order=None) -> "FancyArray": def array_equal(array1: "FancyArray", array2: "FancyArray", equal_nan: bool = True) -> bool: """Return True if two arrays are equal.""" if equal_nan: - return _array_equal_with_nan(array1, array2) + return array_equal_with_nan(array1.data, array2.data) return np.array_equal(array1.data, array2.data) - - -def _array_equal_with_nan(array1: "FancyArray", array2: "FancyArray") -> bool: - # np.array_equal does not work with NaN values in structured arrays, so we need to compare column by column. - # related issue: https://github.com/numpy/numpy/issues/21539 - - if array1.columns != array2.columns: - return False - - for column in array1.columns: - column_dtype = array1.dtype[column] - if np.issubdtype(column_dtype, np.str_): - if not np.array_equal(array1[column], array2[column]): - return False - continue - if not np.array_equal(array1[column], array2[column], equal_nan=True): - return False - return True diff --git a/src/power_grid_model_ds/_core/utils/misc.py b/src/power_grid_model_ds/_core/utils/misc.py index 0eeb64f..bf64df4 100644 --- a/src/power_grid_model_ds/_core/utils/misc.py +++ b/src/power_grid_model_ds/_core/utils/misc.py @@ -39,3 +39,24 @@ def get_inherited_attrs(cls: Type, *private_attributes): retrieved_attributes[private_attr] = attr_dict return retrieved_attributes + + +def array_equal_with_nan(array1: np.ndarray, array2: np.ndarray) -> bool: + """Compare two structured arrays for equality, treating NaN values as equal. + + np.array_equal does not work with NaN values in structured arrays, so we need to compare column by column. + related issue: https://github.com/numpy/numpy/issues/21539 + """ + if array1.dtype.names != array2.dtype.names: + return False + + columns: Sequence[str] = array1.dtype.names + for column in columns: + column_dtype = array1.dtype[column] + if np.issubdtype(column_dtype, np.str_): + if not np.array_equal(array1[column], array2[column]): + return False + continue + if not np.array_equal(array1[column], array2[column], equal_nan=True): + return False + return True