diff --git a/src/power_grid_model_ds/_core/fancypy.py b/src/power_grid_model_ds/_core/fancypy.py index ef4dc05..793618b 100644 --- a/src/power_grid_model_ds/_core/fancypy.py +++ b/src/power_grid_model_ds/_core/fancypy.py @@ -4,17 +4,19 @@ """A set of helper functions that mimic numpy functions but are specifically designed for FancyArrays.""" -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING, TypeVar, Union import numpy as np from power_grid_model_ds._core.utils.misc import array_equal_with_nan if TYPE_CHECKING: - from power_grid_model_ds._core.model.arrays.base.array import FancyArray + from power_grid_model_ds._core.model.arrays.base.array import FancyArray # noqa +T = TypeVar("T", bound="FancyArray") -def concatenate(fancy_array: "FancyArray", *other_arrays: Union["FancyArray", np.ndarray]) -> "FancyArray": + +def concatenate(fancy_array: T, *other_arrays: Union[T, np.ndarray]) -> T: """Concatenate arrays.""" np_arrays = [array if isinstance(array, np.ndarray) else array.data for array in other_arrays] try: @@ -24,7 +26,7 @@ def concatenate(fancy_array: "FancyArray", *other_arrays: Union["FancyArray", np return fancy_array.__class__(data=concatenated) -def unique(array: "FancyArray", **kwargs): +def unique(array: T, **kwargs): """Return the unique elements of the array.""" for column in array.columns: if np.issubdtype(array.dtype[column], np.floating) and np.isnan(array[column]).any(): @@ -37,13 +39,13 @@ def unique(array: "FancyArray", **kwargs): return array.__class__(data=unique_data) -def sort(array: "FancyArray", axis=-1, kind=None, order=None) -> "FancyArray": +def sort(array: T, axis=-1, kind=None, order=None) -> T: """Sort the array in-place and return sorted array.""" array.data.sort(axis=axis, kind=kind, order=order) return array -def array_equal(array1: "FancyArray", array2: "FancyArray", equal_nan: bool = True) -> bool: +def array_equal(array1: T, array2: T, equal_nan: bool = True) -> bool: """Return True if two arrays are equal.""" if equal_nan: return array_equal_with_nan(array1.data, array2.data)