Skip to content

Commit a03329a

Browse files
authored
Add typehint overloads for __getitem__ (#109)
* Make __getitem__ typesafe by restricting input Signed-off-by: Thijs Baaijen <[email protected]> * revert some changes Signed-off-by: Thijs Baaijen <[email protected]> * revert some changes Signed-off-by: Thijs Baaijen <[email protected]> * merge main Signed-off-by: Thijs Baaijen <[email protected]> --------- Signed-off-by: Thijs Baaijen <[email protected]>
1 parent 4877fd8 commit a03329a

File tree

2 files changed

+47
-21
lines changed

2 files changed

+47
-21
lines changed

src/power_grid_model_ds/_core/model/arrays/base/array.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from collections import namedtuple
77
from copy import copy
88
from functools import lru_cache
9-
from typing import Any, Iterable, Literal, Type, TypeVar
9+
from typing import Any, Iterable, Literal, Type, TypeVar, overload
1010

1111
import numpy as np
1212
from numpy.typing import ArrayLike, NDArray
@@ -152,20 +152,33 @@ def __setattr__(self: Self, attr: str, value: object) -> None:
152152
except (AttributeError, ValueError) as error:
153153
raise AttributeError(f"Cannot set attribute {attr} on {self.__class__.__name__}") from error
154154

155-
def __getitem__(self: Self, item):
156-
"""Used by for-loops, slicing [0:3], column-access ['id'], row-access [0], multi-column access.
157-
Note: If a single item is requested, return a named tuple instead of a np.void object.
158-
"""
159-
160-
result = self._data.__getitem__(item)
161-
162-
if isinstance(item, (list, tuple)) and (len(item) == 0 or np.array(item).dtype.type is np.bool_):
163-
return self.__class__(data=result)
164-
if isinstance(item, (str, list, tuple)):
165-
return result
166-
if isinstance(result, np.void):
167-
return self.__class__(data=np.array([result]))
168-
return self.__class__(data=result)
155+
@overload
156+
def __getitem__(
157+
self: Self, item: slice | int | NDArray[np.bool_] | list[bool] | NDArray[np.int_] | list[int]
158+
) -> Self: ...
159+
160+
@overload
161+
def __getitem__(self, item: str | NDArray[np.str_] | list[str]) -> NDArray[Any]: ...
162+
163+
def __getitem__(self, item):
164+
if isinstance(item, slice | int):
165+
new_data = self._data[item]
166+
if new_data.shape == ():
167+
new_data = np.array([new_data])
168+
return self.__class__(data=new_data)
169+
if isinstance(item, str):
170+
return self._data[item]
171+
if (isinstance(item, np.ndarray) and item.size == 0) or (isinstance(item, list | tuple) and len(item) == 0):
172+
return self.__class__(data=self._data[[]])
173+
if isinstance(item, list | np.ndarray):
174+
item_array = np.array(item)
175+
if item_array.dtype == np.bool_ or np.issubdtype(item_array.dtype, np.int_):
176+
return self.__class__(data=self._data[item_array])
177+
if np.issubdtype(item_array.dtype, np.str_):
178+
return self._data[item_array.tolist()]
179+
raise NotImplementedError(
180+
f"FancyArray[{type(item).__name__}] is not supported. Try FancyArray.data[{type(item).__name__}] instead."
181+
)
169182

170183
def __setitem__(self: Self, key, value):
171184
if isinstance(value, FancyArray):

tests/unit/model/arrays/test_array.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,17 @@ def test_getitem_unique_multiple_columns(fancy_test_array: FancyTestArray):
7474
assert np.array_equal(np.unique(fancy_test_array[columns]), fancy_test_array[columns])
7575

7676

77+
def test_getitem_array_index(fancy_test_array: FancyTestArray):
78+
assert fancy_test_array[0].data.tolist() == fancy_test_array.data[0:1].tolist()
79+
80+
81+
def test_getitem_array_nested_index(fancy_test_array: FancyTestArray):
82+
nested_array = fancy_test_array[0][0][0][0][0][0]
83+
assert isinstance(nested_array, FancyArray)
84+
assert nested_array.data.shape == (1,)
85+
assert nested_array.data.tolist() == fancy_test_array.data[0:1].tolist()
86+
87+
7788
def test_getitem_array_slice(fancy_test_array: FancyTestArray):
7889
assert fancy_test_array.data[0:2].tolist() == fancy_test_array[0:2].tolist()
7990

@@ -84,18 +95,20 @@ def test_getitem_with_array_mask(fancy_test_array: FancyTestArray):
8495
assert np.array_equal(fancy_test_array.data[mask], fancy_test_array[mask].data)
8596

8697

87-
def test_getitem_with_tuple_mask(fancy_test_array: FancyTestArray):
88-
mask = (True, False, True)
89-
assert isinstance(fancy_test_array[mask], FancyArray)
90-
assert np.array_equal(fancy_test_array.data[mask], fancy_test_array[mask].data)
91-
92-
9398
def test_getitem_with_list_mask(fancy_test_array: FancyTestArray):
9499
mask = [True, False, True]
95100
assert isinstance(fancy_test_array[mask], FancyArray)
96101
assert np.array_equal(fancy_test_array.data[mask], fancy_test_array[mask].data)
97102

98103

104+
def test_getitem_with_tuple_mask(fancy_test_array: FancyTestArray):
105+
# Numpy gives unexpected results with tuple masks. Therefore, we raise NotImplementedError here.
106+
# e.g: np.array([1,2,3])[(True, False, True)] returns an empty array (array([], shape=(0, 3), dtype=int64)
107+
mask = (True, False, True)
108+
with pytest.raises(NotImplementedError):
109+
fancy_test_array[mask] # type: ignore[call-overload] # noqa
110+
111+
99112
def test_getitem_with_empty_list_mask():
100113
array = FancyTestArray()
101114
mask = []

0 commit comments

Comments
 (0)