Skip to content

Commit bb87fb5

Browse files
committed
Fix typing issues
Signed-off-by: Thijs Baaijen <[email protected]>
1 parent 8c4a8a4 commit bb87fb5

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

src/power_grid_model_ds/_core/fancypy.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
"""A set of helper functions that mimic numpy functions but are specifically designed for FancyArrays."""
66

7-
from typing import TYPE_CHECKING, Union
7+
from typing import TYPE_CHECKING, TypeVar, Union
88

99
import numpy as np
1010

@@ -13,8 +13,13 @@
1313
if TYPE_CHECKING:
1414
from power_grid_model_ds._core.model.arrays.base.array import FancyArray
1515

16+
T = TypeVar("T", bound="FancyArray")
1617

17-
def concatenate(fancy_array: "FancyArray", *other_arrays: Union["FancyArray", np.ndarray]) -> "FancyArray":
18+
if TYPE_CHECKING:
19+
pass
20+
21+
22+
def concatenate(fancy_array: T, *other_arrays: Union[T, np.ndarray]) -> T:
1823
"""Concatenate arrays."""
1924
np_arrays = [array if isinstance(array, np.ndarray) else array.data for array in other_arrays]
2025
try:
@@ -24,7 +29,7 @@ def concatenate(fancy_array: "FancyArray", *other_arrays: Union["FancyArray", np
2429
return fancy_array.__class__(data=concatenated)
2530

2631

27-
def unique(array: "FancyArray", **kwargs):
32+
def unique(array: T, **kwargs):
2833
"""Return the unique elements of the array."""
2934
for column in array.columns:
3035
if np.issubdtype(array.dtype[column], np.floating) and np.isnan(array[column]).any():
@@ -37,13 +42,13 @@ def unique(array: "FancyArray", **kwargs):
3742
return array.__class__(data=unique_data)
3843

3944

40-
def sort(array: "FancyArray", axis=-1, kind=None, order=None) -> "FancyArray":
45+
def sort(array: T, axis=-1, kind=None, order=None) -> T:
4146
"""Sort the array in-place and return sorted array."""
4247
array.data.sort(axis=axis, kind=kind, order=order)
4348
return array
4449

4550

46-
def array_equal(array1: "FancyArray", array2: "FancyArray", equal_nan: bool = True) -> bool:
51+
def array_equal(array1: T, array2: T, equal_nan: bool = True) -> bool:
4752
"""Return True if two arrays are equal."""
4853
if equal_nan:
4954
return array_equal_with_nan(array1.data, array2.data)

0 commit comments

Comments
 (0)