Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions src/power_grid_model_ds/_core/fancypy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,19 @@

"""A set of helper functions that mimic numpy functions but are specifically designed for FancyArrays."""

from typing import TYPE_CHECKING, Union
from typing import TYPE_CHECKING, TypeVar, Union

import numpy as np

from power_grid_model_ds._core.utils.misc import array_equal_with_nan

if TYPE_CHECKING:
from power_grid_model_ds._core.model.arrays.base.array import FancyArray
from power_grid_model_ds._core.model.arrays.base.array import FancyArray # noqa

T = TypeVar("T", bound="FancyArray")

def concatenate(fancy_array: "FancyArray", *other_arrays: Union["FancyArray", np.ndarray]) -> "FancyArray":

def concatenate(fancy_array: T, *other_arrays: Union[T, np.ndarray]) -> T:
"""Concatenate arrays."""
np_arrays = [array if isinstance(array, np.ndarray) else array.data for array in other_arrays]
try:
Expand All @@ -24,7 +26,7 @@ def concatenate(fancy_array: "FancyArray", *other_arrays: Union["FancyArray", np
return fancy_array.__class__(data=concatenated)


def unique(array: "FancyArray", **kwargs):
def unique(array: T, **kwargs):
"""Return the unique elements of the array."""
for column in array.columns:
if np.issubdtype(array.dtype[column], np.floating) and np.isnan(array[column]).any():
Expand All @@ -37,13 +39,13 @@ def unique(array: "FancyArray", **kwargs):
return array.__class__(data=unique_data)


def sort(array: "FancyArray", axis=-1, kind=None, order=None) -> "FancyArray":
def sort(array: T, axis=-1, kind=None, order=None) -> T:
"""Sort the array in-place and return sorted array."""
array.data.sort(axis=axis, kind=kind, order=order)
return array


def array_equal(array1: "FancyArray", array2: "FancyArray", equal_nan: bool = True) -> bool:
def array_equal(array1: T, array2: T, equal_nan: bool = True) -> bool:
"""Return True if two arrays are equal."""
if equal_nan:
return array_equal_with_nan(array1.data, array2.data)
Expand Down
Loading