66from collections import namedtuple
77from copy import copy
88from functools import lru_cache
9- from typing import Any , Iterable , Literal , Type , TypeVar
9+ from typing import Any , Iterable , Literal , Type , TypeVar , overload
1010
1111import numpy as np
1212from numpy .typing import ArrayLike , NDArray
@@ -59,14 +59,14 @@ class FancyArray(ABC):
5959 _defaults : dict [str , Any ] = {}
6060 _str_lengths : dict [str , int ] = {}
6161
62- def __init__ (self : Self , * args , data : NDArray | None = None , ** kwargs ):
62+ def __init__ (self , * args , data : NDArray | None = None , ** kwargs ):
6363 if data is None :
6464 self ._data = build_array (* args , dtype = self .get_dtype (), defaults = self .get_defaults (), ** kwargs )
6565 else :
6666 self ._data = data
6767
6868 @property
69- def data (self : Self ) -> NDArray :
69+ def data (self ) -> NDArray :
7070 return self ._data
7171
7272 @classmethod
@@ -110,7 +110,7 @@ def get_dtype(cls):
110110 dtype_list .append ((name , dtype ))
111111 return np .dtype (dtype_list )
112112
113- def __repr__ (self : Self ) -> str :
113+ def __repr__ (self ) -> str :
114114 try :
115115 data = getattr (self , "data" )
116116 if data .size > 3 :
@@ -125,7 +125,7 @@ def __str__(self) -> str:
125125 def __len__ (self ) -> int :
126126 return len (self ._data )
127127
128- def __iter__ (self : Self ):
128+ def __iter__ (self ):
129129 for record in self ._data :
130130 yield self .__class__ (data = np .array ([record ]))
131131
@@ -152,20 +152,33 @@ def __setattr__(self: Self, attr: str, value: object) -> None:
152152 except (AttributeError , ValueError ) as error :
153153 raise AttributeError (f"Cannot set attribute { attr } on { self .__class__ .__name__ } " ) from error
154154
155- def __getitem__ (self : Self , item ):
156- """Used by for-loops, slicing [0:3], column-access ['id'], row-access [0], multi-column access.
157- Note: If a single item is requested, return a named tuple instead of a np.void object.
158- """
159-
160- result = self ._data .__getitem__ (item )
161-
162- if isinstance (item , (list , tuple )) and (len (item ) == 0 or np .array (item ).dtype .type is np .bool_ ):
163- return self .__class__ (data = result )
164- if isinstance (item , (str , list , tuple )):
165- return result
166- if isinstance (result , np .void ):
167- return self .__class__ (data = np .array ([result ]))
168- return self .__class__ (data = result )
155+ @overload
156+ def __getitem__ (
157+ self : Self , item : slice | int | NDArray [np .bool_ ] | list [bool ] | NDArray [np .int_ ] | list [int ]
158+ ) -> Self : ...
159+
160+ @overload
161+ def __getitem__ (self , item : str | NDArray [np .str_ ] | list [str ]) -> NDArray [Any ]: ...
162+
163+ def __getitem__ (self , item ):
164+ if isinstance (item , slice | int ):
165+ new_data = self ._data [item ]
166+ if new_data .shape == ():
167+ new_data = np .array ([new_data ])
168+ return self .__class__ (data = new_data )
169+ if isinstance (item , str ):
170+ return self ._data [item ]
171+ if (isinstance (item , np .ndarray ) and item .size == 0 ) or (isinstance (item , list | tuple ) and len (item ) == 0 ):
172+ return self .__class__ (data = self ._data [[]])
173+ if isinstance (item , list | np .ndarray ):
174+ item_array = np .array (item )
175+ if item_array .dtype == np .bool_ or np .issubdtype (item_array .dtype , np .int_ ):
176+ return self .__class__ (data = self ._data [item_array ])
177+ if np .issubdtype (item_array .dtype , np .str_ ):
178+ return self ._data [item_array .tolist ()]
179+ raise NotImplementedError (
180+ f"FancyArray[{ type (item ).__name__ } ] is not supported. Try FancyArray.data[{ type (item ).__name__ } ] instead."
181+ )
169182
170183 def __setitem__ (self : Self , key , value ):
171184 if isinstance (value , FancyArray ):
@@ -177,16 +190,18 @@ def __contains__(self: Self, item: Self) -> bool:
177190 return item .data in self ._data
178191 return False
179192
180- def __hash__ (self : Self ):
193+ def __hash__ (self ):
181194 return hash (f"{ self .__class__ } { self } " )
182195
183- def __eq__ (self : Self , other ):
184- return self ._data .__eq__ (other .data )
196+ def __eq__ (self , other ):
197+ if not isinstance (other , self .__class__ ):
198+ return False
199+ return self .data .__eq__ (other .data )
185200
186- def __copy__ (self : Self ):
201+ def __copy__ (self ):
187202 return self .__class__ (data = copy (self ._data ))
188203
189- def copy (self : Self ):
204+ def copy (self ):
190205 """Return a copy of this array including its data"""
191206 return copy (self )
192207
@@ -281,15 +296,15 @@ def get(
281296 return self .__class__ (data = apply_get (* args , array = self ._data , mode_ = mode_ , ** kwargs ))
282297
283298 def filter_mask (
284- self : Self ,
299+ self ,
285300 * args : int | Iterable [int ] | np .ndarray ,
286301 mode_ : Literal ["AND" , "OR" ] = "AND" ,
287302 ** kwargs : Any | list [Any ] | np .ndarray ,
288303 ) -> np .ndarray :
289304 return get_filter_mask (* args , array = self ._data , mode_ = mode_ , ** kwargs )
290305
291306 def exclude_mask (
292- self : Self ,
307+ self ,
293308 * args : int | Iterable [int ] | np .ndarray ,
294309 mode_ : Literal ["AND" , "OR" ] = "AND" ,
295310 ** kwargs : Any | list [Any ] | np .ndarray ,
@@ -299,7 +314,7 @@ def exclude_mask(
299314 def re_order (self : Self , new_order : ArrayLike , column : str = "id" ) -> Self :
300315 return self .__class__ (data = re_order (self ._data , new_order , column = column ))
301316
302- def update_by_id (self : Self , ids : ArrayLike , allow_missing : bool = False , ** kwargs ) -> None :
317+ def update_by_id (self , ids : ArrayLike , allow_missing : bool = False , ** kwargs ) -> None :
303318 try :
304319 _ = update_by_id (self ._data , ids , allow_missing , ** kwargs )
305320 except ValueError as error :
@@ -312,13 +327,13 @@ def get_updated_by_id(self: Self, ids: ArrayLike, allow_missing: bool = False, *
312327 except ValueError as error :
313328 raise ValueError (f"Cannot update { self .__class__ .__name__ } . { error } " ) from error
314329
315- def check_ids (self : Self , return_duplicates : bool = False ) -> NDArray | None :
330+ def check_ids (self , return_duplicates : bool = False ) -> NDArray | None :
316331 return check_ids (self ._data , return_duplicates = return_duplicates )
317332
318- def as_table (self : Self , column_width : int | str = "auto" , rows : int = 10 ) -> str :
333+ def as_table (self , column_width : int | str = "auto" , rows : int = 10 ) -> str :
319334 return convert_array_to_string (self , column_width = column_width , rows = rows )
320335
321- def as_df (self : Self ):
336+ def as_df (self ):
322337 """Convert to pandas DataFrame"""
323338 if pandas is None :
324339 raise ImportError ("pandas is not installed" )
0 commit comments