|
6 | 6 | from collections import namedtuple |
7 | 7 | from copy import copy |
8 | 8 | from functools import lru_cache |
9 | | -from typing import Any, Iterable, Literal, Type, TypeVar |
| 9 | +from typing import Any, Iterable, Literal, Type, TypeVar, overload |
10 | 10 |
|
11 | 11 | import numpy as np |
12 | 12 | from numpy.typing import ArrayLike, NDArray |
@@ -152,20 +152,33 @@ def __setattr__(self: Self, attr: str, value: object) -> None: |
152 | 152 | except (AttributeError, ValueError) as error: |
153 | 153 | raise AttributeError(f"Cannot set attribute {attr} on {self.__class__.__name__}") from error |
154 | 154 |
|
155 | | - def __getitem__(self: Self, item): |
156 | | - """Used by for-loops, slicing [0:3], column-access ['id'], row-access [0], multi-column access. |
157 | | - Note: If a single item is requested, return a named tuple instead of a np.void object. |
158 | | - """ |
159 | | - |
160 | | - result = self._data.__getitem__(item) |
161 | | - |
162 | | - if isinstance(item, (list, tuple)) and (len(item) == 0 or np.array(item).dtype.type is np.bool_): |
163 | | - return self.__class__(data=result) |
164 | | - if isinstance(item, (str, list, tuple)): |
165 | | - return result |
166 | | - if isinstance(result, np.void): |
167 | | - return self.__class__(data=np.array([result])) |
168 | | - return self.__class__(data=result) |
| 155 | + @overload |
| 156 | + def __getitem__( |
| 157 | + self: Self, item: slice | int | NDArray[np.bool_] | list[bool] | NDArray[np.int_] | list[int] |
| 158 | + ) -> Self: ... |
| 159 | + |
| 160 | + @overload |
| 161 | + def __getitem__(self, item: str | NDArray[np.str_] | list[str]) -> NDArray[Any]: ... |
| 162 | + |
| 163 | + def __getitem__(self, item): |
| 164 | + if isinstance(item, slice | int): |
| 165 | + new_data = self._data[item] |
| 166 | + if new_data.shape == (): |
| 167 | + new_data = np.array([new_data]) |
| 168 | + return self.__class__(data=new_data) |
| 169 | + if isinstance(item, str): |
| 170 | + return self._data[item] |
| 171 | + if (isinstance(item, np.ndarray) and item.size == 0) or (isinstance(item, list | tuple) and len(item) == 0): |
| 172 | + return self.__class__(data=self._data[[]]) |
| 173 | + if isinstance(item, list | np.ndarray): |
| 174 | + item_array = np.array(item) |
| 175 | + if item_array.dtype == np.bool_ or np.issubdtype(item_array.dtype, np.int_): |
| 176 | + return self.__class__(data=self._data[item_array]) |
| 177 | + if np.issubdtype(item_array.dtype, np.str_): |
| 178 | + return self._data[item_array.tolist()] |
| 179 | + raise NotImplementedError( |
| 180 | + f"FancyArray[{type(item).__name__}] is not supported. Try FancyArray.data[{type(item).__name__}] instead." |
| 181 | + ) |
169 | 182 |
|
170 | 183 | def __setitem__(self: Self, key, value): |
171 | 184 | if isinstance(value, FancyArray): |
|
0 commit comments