44
55"""A set of helper functions that mimic numpy functions but are specifically designed for FancyArrays."""
66
7- from typing import TYPE_CHECKING , Union
7+ from typing import TYPE_CHECKING , TypeVar , Union
88
99import numpy as np
1010
1313if TYPE_CHECKING :
1414 from power_grid_model_ds ._core .model .arrays .base .array import FancyArray
1515
16+ T = TypeVar ("T" , bound = "FancyArray" )
1617
17- def concatenate (fancy_array : "FancyArray" , * other_arrays : Union ["FancyArray" , np .ndarray ]) -> "FancyArray" :
18+ if TYPE_CHECKING :
19+ pass
20+
21+
22+ def concatenate (fancy_array : T , * other_arrays : Union [T , np .ndarray ]) -> T :
1823 """Concatenate arrays."""
1924 np_arrays = [array if isinstance (array , np .ndarray ) else array .data for array in other_arrays ]
2025 try :
@@ -24,7 +29,7 @@ def concatenate(fancy_array: "FancyArray", *other_arrays: Union["FancyArray", np
2429 return fancy_array .__class__ (data = concatenated )
2530
2631
27- def unique (array : "FancyArray" , ** kwargs ):
32+ def unique (array : T , ** kwargs ):
2833 """Return the unique elements of the array."""
2934 for column in array .columns :
3035 if np .issubdtype (array .dtype [column ], np .floating ) and np .isnan (array [column ]).any ():
@@ -37,13 +42,13 @@ def unique(array: "FancyArray", **kwargs):
3742 return array .__class__ (data = unique_data )
3843
3944
40- def sort (array : "FancyArray" , axis = - 1 , kind = None , order = None ) -> "FancyArray" :
45+ def sort (array : T , axis = - 1 , kind = None , order = None ) -> T :
4146 """Sort the array in-place and return sorted array."""
4247 array .data .sort (axis = axis , kind = kind , order = order )
4348 return array
4449
4550
46- def array_equal (array1 : "FancyArray" , array2 : "FancyArray" , equal_nan : bool = True ) -> bool :
51+ def array_equal (array1 : T , array2 : T , equal_nan : bool = True ) -> bool :
4752 """Return True if two arrays are equal."""
4853 if equal_nan :
4954 return array_equal_with_nan (array1 .data , array2 .data )
0 commit comments