Skip to content

Commit a20d678

Browse files
committed
Make __getitem__ typesafe by restricting input
Signed-off-by: Thijs Baaijen <[email protected]>
1 parent 6b06c3e commit a20d678

File tree

4 files changed

+49
-39
lines changed

4 files changed

+49
-39
lines changed

src/power_grid_model_ds/_core/model/arrays/base/array.py

Lines changed: 21 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
from collections import namedtuple
77
from copy import copy
88
from functools import lru_cache
9-
from typing import Any, Iterable, Literal, Type, TypeVar
9+
from typing import Any, Iterable, Literal, Type, TypeVar, overload
1010

1111
import numpy as np
1212
from numpy.typing import ArrayLike, NDArray
@@ -152,20 +152,25 @@ def __setattr__(self: Self, attr: str, value: object) -> None:
152152
except (AttributeError, ValueError) as error:
153153
raise AttributeError(f"Cannot set attribute {attr} on {self.__class__.__name__}") from error
154154

155-
def __getitem__(self: Self, item):
156-
"""Used by for-loops, slicing [0:3], column-access ['id'], row-access [0], multi-column access.
157-
Note: If a single item is requested, return a named tuple instead of a np.void object.
158-
"""
159-
160-
result = self._data.__getitem__(item)
161-
162-
if isinstance(item, (list, tuple)) and (len(item) == 0 or np.array(item).dtype.type is np.bool_):
163-
return self.__class__(data=result)
164-
if isinstance(item, (str, list, tuple)):
165-
return result
166-
if isinstance(result, np.void):
167-
return self.__class__(data=np.array([result]))
168-
return self.__class__(data=result)
155+
@overload
156+
def __getitem__(self: Self, item: slice | int | NDArray[np.bool_]) -> Self: ...
157+
158+
@overload
159+
def __getitem__(self, item: str | NDArray[np.str_]) -> NDArray[Any]: ...
160+
161+
def __getitem__(self, item):
162+
if isinstance(item, slice | int):
163+
new_data = self._data[item]
164+
if new_data.shape == ():
165+
new_data = np.array([new_data])
166+
return self.__class__(data=new_data)
167+
if isinstance(item, np.ndarray) and item.dtype == np.bool_:
168+
return self.__class__(data=self._data[item])
169+
if isinstance(item, np.ndarray) and item.size == 0:
170+
return self.__class__(data=self._data[[]])
171+
if isinstance(item, str):
172+
return self._data[item]
173+
raise NotImplementedError(f"FancyArray[{type(item)}] is not supported. Use FancyArray.data instead.")
169174

170175
def __setitem__(self: Self, key, value):
171176
if isinstance(value, FancyArray):
@@ -330,4 +335,4 @@ def from_extended(cls: Type[Self], extended: Self) -> Self:
330335
if not isinstance(extended, cls):
331336
raise TypeError(f"Extended array must be of type {cls.__name__}, got {type(extended).__name__}")
332337
dtype = cls.get_dtype()
333-
return cls(data=np.array(extended[list(dtype.names)], dtype=dtype))
338+
return cls(data=np.array(extended.data[list(dtype.names)], dtype=dtype))

src/power_grid_model_ds/_core/model/arrays/pgm_arrays.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -71,7 +71,7 @@ def filter_parallel(self, n_parallel: int, mode: Literal["eq", "neq"]) -> "Branc
7171
- when n_parallel is 1 and mode is 'eq', the function returns branches that are not parallel.
7272
- when n_parallel is 1 and mode is 'neq', the function returns branches that are parallel.
7373
"""
74-
_, index, counts = np.unique(self[["from_node", "to_node"]], return_counts=True, return_index=True)
74+
_, index, counts = np.unique(self.data[["from_node", "to_node"]], return_counts=True, return_index=True)
7575

7676
match mode:
7777
case "eq":
@@ -82,9 +82,9 @@ def filter_parallel(self, n_parallel: int, mode: Literal["eq", "neq"]) -> "Branc
8282
raise ValueError(f"mode {mode} not supported")
8383

8484
if mode == "eq" and n_parallel == 1:
85-
return self[index][counts_mask]
86-
filtered_branches = self[index][counts_mask]
87-
return self.filter(from_node=filtered_branches.from_node, to_node=filtered_branches.to_node)
85+
return self.__class__(self.data[index][counts_mask])
86+
filtered_branches = self.data[index][counts_mask]
87+
return self.filter(from_node=filtered_branches["from_node"], to_node=filtered_branches["to_node"])
8888

8989

9090
class LinkArray(Link, BranchArray):

tests/unit/model/arrays/test_array.py

Lines changed: 22 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -65,17 +65,22 @@ def test_getitem_array_one_column(fancy_test_array: FancyTestArray):
6565

6666
def test_getitem_array_multiple_columns(fancy_test_array: FancyTestArray):
6767
columns = ["id", "test_int", "test_float"]
68-
assert fancy_test_array.data[columns].tolist() == fancy_test_array[columns].tolist()
69-
assert_array_equal(fancy_test_array[columns].dtype.names, ("id", "test_int", "test_float"))
68+
assert_array_equal(fancy_test_array.data[columns].dtype.names, ("id", "test_int", "test_float"))
7069

7170

72-
def test_getitem_unique_multiple_columns(fancy_test_array: FancyTestArray):
73-
columns = ["id", "test_int", "test_float"]
74-
assert np.array_equal(np.unique(fancy_test_array[columns]), fancy_test_array[columns])
71+
def test_getitem_array_index(fancy_test_array: FancyTestArray):
72+
assert fancy_test_array[0].data.tolist() == fancy_test_array.data[0:1].tolist()
73+
74+
75+
def test_getitem_array_nested_index(fancy_test_array: FancyTestArray):
76+
nested_array = fancy_test_array[0][0][0][0][0][0]
77+
assert isinstance(nested_array, FancyArray)
78+
assert nested_array.data.shape == (1,)
79+
assert nested_array.data.tolist() == fancy_test_array.data[0:1].tolist()
7580

7681

7782
def test_getitem_array_slice(fancy_test_array: FancyTestArray):
78-
assert fancy_test_array.data[0:2].tolist() == fancy_test_array[0:2].tolist()
83+
assert fancy_test_array[0:2].data.tolist() == fancy_test_array.data[0:2].tolist()
7984

8085

8186
def test_getitem_with_array_mask(fancy_test_array: FancyTestArray):
@@ -86,19 +91,19 @@ def test_getitem_with_array_mask(fancy_test_array: FancyTestArray):
8691

8792
def test_getitem_with_tuple_mask(fancy_test_array: FancyTestArray):
8893
mask = (True, False, True)
89-
assert isinstance(fancy_test_array[mask], FancyArray)
90-
assert np.array_equal(fancy_test_array.data[mask], fancy_test_array[mask].data)
94+
with pytest.raises(NotImplementedError):
95+
fancy_test_array[mask] # type: ignore[call-overload] # noqa
9196

9297

9398
def test_getitem_with_list_mask(fancy_test_array: FancyTestArray):
9499
mask = [True, False, True]
95-
assert isinstance(fancy_test_array[mask], FancyArray)
96-
assert np.array_equal(fancy_test_array.data[mask], fancy_test_array[mask].data)
100+
with pytest.raises(NotImplementedError):
101+
fancy_test_array[mask] # type: ignore[call-overload] # noqa
97102

98103

99104
def test_getitem_with_empty_list_mask():
100105
array = FancyTestArray()
101-
mask = []
106+
mask = np.array([], dtype=bool)
102107
assert isinstance(array[mask], FancyArray)
103108
assert np.array_equal(array.data[mask], array[mask].data)
104109

@@ -233,16 +238,16 @@ def test_unique_return_counts_and_inverse(fancy_test_array: FancyTestArray):
233238

234239

235240
def test_sort(fancy_test_array: FancyTestArray):
236-
assert_array_equal(fancy_test_array.test_float, [4.0, 4.0, 1.0])
237-
fancy_test_array.sort(order="test_float")
238-
assert_array_equal(fancy_test_array.test_float, [1.0, 4.0, 4.0])
241+
assert_array_equal(fancy_test_array["test_float"], [4.0, 4.0, 1.0])
242+
fancy_test_array.data.sort(order="test_float")
243+
assert_array_equal(fancy_test_array["test_float"], [1.0, 4.0, 4.0])
239244

240245

241246
def test_copy_function(fancy_test_array: FancyTestArray):
242247
array_copy = copy(fancy_test_array)
243-
array_copy.test_int = 123
248+
array_copy["test_int"] = 123
244249
assert not id(fancy_test_array) == id(array_copy)
245-
assert not fancy_test_array.test_int[0] == array_copy.test_int[0]
250+
assert not fancy_test_array["test_int"][0] == array_copy["test_int"][0]
246251

247252

248253
def test_copy_method(fancy_test_array: FancyTestArray):
@@ -302,4 +307,4 @@ def test_from_extended_array():
302307

303308
array = LineArray.from_extended(extended_array)
304309
assert not isinstance(array, ExtendedLineArray)
305-
array_equal_with_nan(array.data, extended_array[array.columns])
310+
array_equal_with_nan(array.data, extended_array.data[array.columns])

tests/unit/model/arrays/test_modify.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -128,10 +128,10 @@ def test_concatenate_different_ndarray(self, fancy_test_array: FancyTestArray):
128128
fp.concatenate(fancy_test_array, different_array.data)
129129

130130
def test_concatenate_different_fancy_array_same_dtype(self, fancy_test_array: FancyTestArray):
131-
sub_array = fancy_test_array[["test_str", "test_int"]]
131+
sub_array = fancy_test_array.data[["test_str", "test_int"]]
132132

133133
different_array = FancyNonIdArray.zeros(10)
134-
different_sub_array = different_array[["test_str", "test_int"]]
134+
different_sub_array = different_array.data[["test_str", "test_int"]]
135135

136136
concatenated = np.concatenate([sub_array, different_sub_array])
137137
assert concatenated.size == 13

0 commit comments

Comments
 (0)