diff --git a/src/power_grid_model_ds/_core/model/arrays/base/array.py b/src/power_grid_model_ds/_core/model/arrays/base/array.py index bab8452..b62aa74 100644 --- a/src/power_grid_model_ds/_core/model/arrays/base/array.py +++ b/src/power_grid_model_ds/_core/model/arrays/base/array.py @@ -6,7 +6,7 @@ from collections import namedtuple from copy import copy from functools import lru_cache -from typing import Any, Iterable, Literal, Type, TypeVar +from typing import Any, Iterable, Literal, Type, TypeVar, overload import numpy as np from numpy.typing import ArrayLike, NDArray @@ -152,20 +152,33 @@ def __setattr__(self: Self, attr: str, value: object) -> None: except (AttributeError, ValueError) as error: raise AttributeError(f"Cannot set attribute {attr} on {self.__class__.__name__}") from error - def __getitem__(self: Self, item): - """Used by for-loops, slicing [0:3], column-access ['id'], row-access [0], multi-column access. - Note: If a single item is requested, return a named tuple instead of a np.void object. - """ - - result = self._data.__getitem__(item) - - if isinstance(item, (list, tuple)) and (len(item) == 0 or np.array(item).dtype.type is np.bool_): - return self.__class__(data=result) - if isinstance(item, (str, list, tuple)): - return result - if isinstance(result, np.void): - return self.__class__(data=np.array([result])) - return self.__class__(data=result) + @overload + def __getitem__( + self: Self, item: slice | int | NDArray[np.bool_] | list[bool] | NDArray[np.int_] | list[int] + ) -> Self: ... + + @overload + def __getitem__(self, item: str | NDArray[np.str_] | list[str]) -> NDArray[Any]: ... + + def __getitem__(self, item): + if isinstance(item, slice | int): + new_data = self._data[item] + if new_data.shape == (): + new_data = np.array([new_data]) + return self.__class__(data=new_data) + if isinstance(item, str): + return self._data[item] + if (isinstance(item, np.ndarray) and item.size == 0) or (isinstance(item, list | tuple) and len(item) == 0): + return self.__class__(data=self._data[[]]) + if isinstance(item, list | np.ndarray): + item_array = np.array(item) + if item_array.dtype == np.bool_ or np.issubdtype(item_array.dtype, np.int_): + return self.__class__(data=self._data[item_array]) + if np.issubdtype(item_array.dtype, np.str_): + return self._data[item_array.tolist()] + raise NotImplementedError( + f"FancyArray[{type(item).__name__}] is not supported. Try FancyArray.data[{type(item).__name__}] instead." + ) def __setitem__(self: Self, key, value): if isinstance(value, FancyArray): diff --git a/tests/unit/model/arrays/test_array.py b/tests/unit/model/arrays/test_array.py index 8e4b7c5..c727687 100644 --- a/tests/unit/model/arrays/test_array.py +++ b/tests/unit/model/arrays/test_array.py @@ -74,6 +74,17 @@ def test_getitem_unique_multiple_columns(fancy_test_array: FancyTestArray): assert np.array_equal(np.unique(fancy_test_array[columns]), fancy_test_array[columns]) +def test_getitem_array_index(fancy_test_array: FancyTestArray): + assert fancy_test_array[0].data.tolist() == fancy_test_array.data[0:1].tolist() + + +def test_getitem_array_nested_index(fancy_test_array: FancyTestArray): + nested_array = fancy_test_array[0][0][0][0][0][0] + assert isinstance(nested_array, FancyArray) + assert nested_array.data.shape == (1,) + assert nested_array.data.tolist() == fancy_test_array.data[0:1].tolist() + + def test_getitem_array_slice(fancy_test_array: FancyTestArray): assert fancy_test_array.data[0:2].tolist() == fancy_test_array[0:2].tolist() @@ -84,18 +95,20 @@ def test_getitem_with_array_mask(fancy_test_array: FancyTestArray): assert np.array_equal(fancy_test_array.data[mask], fancy_test_array[mask].data) -def test_getitem_with_tuple_mask(fancy_test_array: FancyTestArray): - mask = (True, False, True) - assert isinstance(fancy_test_array[mask], FancyArray) - assert np.array_equal(fancy_test_array.data[mask], fancy_test_array[mask].data) - - def test_getitem_with_list_mask(fancy_test_array: FancyTestArray): mask = [True, False, True] assert isinstance(fancy_test_array[mask], FancyArray) assert np.array_equal(fancy_test_array.data[mask], fancy_test_array[mask].data) +def test_getitem_with_tuple_mask(fancy_test_array: FancyTestArray): + # Numpy gives unexpected results with tuple masks. Therefore, we raise NotImplementedError here. + # e.g: np.array([1,2,3])[(True, False, True)] returns an empty array (array([], shape=(0, 3), dtype=int64) + mask = (True, False, True) + with pytest.raises(NotImplementedError): + fancy_test_array[mask] # type: ignore[call-overload] # noqa + + def test_getitem_with_empty_list_mask(): array = FancyTestArray() mask = []