Skip to content

Commit ab3dd19

Browse files
committed
revert some changes
Signed-off-by: Thijs Baaijen <[email protected]>
1 parent a20d678 commit ab3dd19

File tree

3 files changed

+42
-27
lines changed

3 files changed

+42
-27
lines changed

src/power_grid_model_ds/_core/model/arrays/base/array.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -153,24 +153,32 @@ def __setattr__(self: Self, attr: str, value: object) -> None:
153153
raise AttributeError(f"Cannot set attribute {attr} on {self.__class__.__name__}") from error
154154

155155
@overload
156-
def __getitem__(self: Self, item: slice | int | NDArray[np.bool_]) -> Self: ...
156+
def __getitem__(
157+
self: Self, item: slice | int | NDArray[np.bool_] | list[bool] | NDArray[np.int_] | list[int]
158+
) -> Self: ...
157159

158160
@overload
159-
def __getitem__(self, item: str | NDArray[np.str_]) -> NDArray[Any]: ...
161+
def __getitem__(self, item: str | NDArray[np.str_] | list[str]) -> NDArray[Any]: ...
160162

161163
def __getitem__(self, item):
162164
if isinstance(item, slice | int):
163165
new_data = self._data[item]
164166
if new_data.shape == ():
165167
new_data = np.array([new_data])
166168
return self.__class__(data=new_data)
167-
if isinstance(item, np.ndarray) and item.dtype == np.bool_:
168-
return self.__class__(data=self._data[item])
169-
if isinstance(item, np.ndarray) and item.size == 0:
170-
return self.__class__(data=self._data[[]])
171169
if isinstance(item, str):
172170
return self._data[item]
173-
raise NotImplementedError(f"FancyArray[{type(item)}] is not supported. Use FancyArray.data instead.")
171+
if (isinstance(item, np.ndarray) and item.size == 0) or (isinstance(item, list | tuple) and len(item) == 0):
172+
return self.__class__(data=self._data[[]])
173+
if isinstance(item, list | np.ndarray):
174+
item_array = np.array(item)
175+
if item_array.dtype == np.bool_ or np.issubdtype(item_array.dtype, np.int_):
176+
return self.__class__(data=self._data[item_array])
177+
if np.issubdtype(item_array.dtype, np.str_):
178+
return self._data[item_array.tolist()]
179+
raise NotImplementedError(
180+
f"FancyArray[{type(item).__name__}] is not supported. Try FancyArray.data[{type(item).__name__}] instead."
181+
)
174182

175183
def __setitem__(self: Self, key, value):
176184
if isinstance(value, FancyArray):
@@ -335,4 +343,4 @@ def from_extended(cls: Type[Self], extended: Self) -> Self:
335343
if not isinstance(extended, cls):
336344
raise TypeError(f"Extended array must be of type {cls.__name__}, got {type(extended).__name__}")
337345
dtype = cls.get_dtype()
338-
return cls(data=np.array(extended.data[list(dtype.names)], dtype=dtype))
346+
return cls(data=np.array(extended[list(dtype.names)], dtype=dtype))

src/power_grid_model_ds/_core/model/arrays/pgm_arrays.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def filter_parallel(self, n_parallel: int, mode: Literal["eq", "neq"]) -> "Branc
7171
- when n_parallel is 1 and mode is 'eq', the function returns branches that are not parallel.
7272
- when n_parallel is 1 and mode is 'neq', the function returns branches that are parallel.
7373
"""
74-
_, index, counts = np.unique(self.data[["from_node", "to_node"]], return_counts=True, return_index=True)
74+
_, index, counts = np.unique(self[["from_node", "to_node"]], return_counts=True, return_index=True)
7575

7676
match mode:
7777
case "eq":
@@ -82,9 +82,9 @@ def filter_parallel(self, n_parallel: int, mode: Literal["eq", "neq"]) -> "Branc
8282
raise ValueError(f"mode {mode} not supported")
8383

8484
if mode == "eq" and n_parallel == 1:
85-
return self.__class__(self.data[index][counts_mask])
86-
filtered_branches = self.data[index][counts_mask]
87-
return self.filter(from_node=filtered_branches["from_node"], to_node=filtered_branches["to_node"])
85+
return self[index][counts_mask]
86+
filtered_branches = self[index][counts_mask]
87+
return self.filter(from_node=filtered_branches.from_node, to_node=filtered_branches.to_node)
8888

8989

9090
class LinkArray(Link, BranchArray):

tests/unit/model/arrays/test_array.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,12 @@ def test_getitem_array_one_column(fancy_test_array: FancyTestArray):
6565

6666
def test_getitem_array_multiple_columns(fancy_test_array: FancyTestArray):
6767
columns = ["id", "test_int", "test_float"]
68-
assert_array_equal(fancy_test_array.data[columns].dtype.names, ("id", "test_int", "test_float"))
68+
assert_array_equal(fancy_test_array[columns].dtype.names, ("id", "test_int", "test_float"))
69+
70+
71+
def test_getitem_unique_multiple_columns(fancy_test_array: FancyTestArray):
72+
columns = ["id", "test_int", "test_float"]
73+
assert np.array_equal(np.unique(fancy_test_array[columns]), fancy_test_array[columns])
6974

7075

7176
def test_getitem_array_index(fancy_test_array: FancyTestArray):
@@ -80,7 +85,7 @@ def test_getitem_array_nested_index(fancy_test_array: FancyTestArray):
8085

8186

8287
def test_getitem_array_slice(fancy_test_array: FancyTestArray):
83-
assert fancy_test_array[0:2].data.tolist() == fancy_test_array.data[0:2].tolist()
88+
assert fancy_test_array.data[0:2].tolist() == fancy_test_array[0:2].tolist()
8489

8590

8691
def test_getitem_with_array_mask(fancy_test_array: FancyTestArray):
@@ -89,21 +94,23 @@ def test_getitem_with_array_mask(fancy_test_array: FancyTestArray):
8994
assert np.array_equal(fancy_test_array.data[mask], fancy_test_array[mask].data)
9095

9196

92-
def test_getitem_with_tuple_mask(fancy_test_array: FancyTestArray):
93-
mask = (True, False, True)
94-
with pytest.raises(NotImplementedError):
95-
fancy_test_array[mask] # type: ignore[call-overload] # noqa
96-
97-
9897
def test_getitem_with_list_mask(fancy_test_array: FancyTestArray):
9998
mask = [True, False, True]
99+
assert isinstance(fancy_test_array[mask], FancyArray)
100+
assert np.array_equal(fancy_test_array.data[mask], fancy_test_array[mask].data)
101+
102+
103+
def test_getitem_with_tuple_mask(fancy_test_array: FancyTestArray):
104+
# Numpy gives unexpected results with tuple masks. Therefore, we raise NotImplementedError here.
105+
# e.g: np.array([1,2,3])[(True, False, True)] returns an empty array (array([], shape=(0, 3), dtype=int64)
106+
mask = (True, False, True)
100107
with pytest.raises(NotImplementedError):
101108
fancy_test_array[mask] # type: ignore[call-overload] # noqa
102109

103110

104111
def test_getitem_with_empty_list_mask():
105112
array = FancyTestArray()
106-
mask = np.array([], dtype=bool)
113+
mask = []
107114
assert isinstance(array[mask], FancyArray)
108115
assert np.array_equal(array.data[mask], array[mask].data)
109116

@@ -238,16 +245,16 @@ def test_unique_return_counts_and_inverse(fancy_test_array: FancyTestArray):
238245

239246

240247
def test_sort(fancy_test_array: FancyTestArray):
241-
assert_array_equal(fancy_test_array["test_float"], [4.0, 4.0, 1.0])
242-
fancy_test_array.data.sort(order="test_float")
243-
assert_array_equal(fancy_test_array["test_float"], [1.0, 4.0, 4.0])
248+
assert_array_equal(fancy_test_array.test_float, [4.0, 4.0, 1.0])
249+
fancy_test_array.sort(order="test_float")
250+
assert_array_equal(fancy_test_array.test_float, [1.0, 4.0, 4.0])
244251

245252

246253
def test_copy_function(fancy_test_array: FancyTestArray):
247254
array_copy = copy(fancy_test_array)
248-
array_copy["test_int"] = 123
255+
array_copy.test_int = 123
249256
assert not id(fancy_test_array) == id(array_copy)
250-
assert not fancy_test_array["test_int"][0] == array_copy["test_int"][0]
257+
assert not fancy_test_array.test_int[0] == array_copy.test_int[0]
251258

252259

253260
def test_copy_method(fancy_test_array: FancyTestArray):
@@ -307,4 +314,4 @@ def test_from_extended_array():
307314

308315
array = LineArray.from_extended(extended_array)
309316
assert not isinstance(array, ExtendedLineArray)
310-
array_equal_with_nan(array.data, extended_array.data[array.columns])
317+
array_equal_with_nan(array.data, extended_array[array.columns])

0 commit comments

Comments
 (0)