Skip to content

Commit 4b1410d

Browse files
authored
Update __eq__ so it is a bit stricter (#108)
* improve FancyArray.__eq__ Signed-off-by: Thijs Baaijen <[email protected]> * piigyback Fix issue found by pyright: warning: TypeVar "Self" appears only once in generic function signature Signed-off-by: Thijs Baaijen <[email protected]> * Apply suggestion from @Thijss Signed-off-by: Thijs Baaijen <[email protected]> --------- Signed-off-by: Thijs Baaijen <[email protected]>
1 parent 6b06c3e commit 4b1410d

File tree

1 file changed

+17
-15
lines changed
  • src/power_grid_model_ds/_core/model/arrays/base

1 file changed

+17
-15
lines changed

src/power_grid_model_ds/_core/model/arrays/base/array.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,14 @@ class FancyArray(ABC):
5959
_defaults: dict[str, Any] = {}
6060
_str_lengths: dict[str, int] = {}
6161

62-
def __init__(self: Self, *args, data: NDArray | None = None, **kwargs):
62+
def __init__(self, *args, data: NDArray | None = None, **kwargs):
6363
if data is None:
6464
self._data = build_array(*args, dtype=self.get_dtype(), defaults=self.get_defaults(), **kwargs)
6565
else:
6666
self._data = data
6767

6868
@property
69-
def data(self: Self) -> NDArray:
69+
def data(self) -> NDArray:
7070
return self._data
7171

7272
@classmethod
@@ -110,7 +110,7 @@ def get_dtype(cls):
110110
dtype_list.append((name, dtype))
111111
return np.dtype(dtype_list)
112112

113-
def __repr__(self: Self) -> str:
113+
def __repr__(self) -> str:
114114
try:
115115
data = getattr(self, "data")
116116
if data.size > 3:
@@ -125,7 +125,7 @@ def __str__(self) -> str:
125125
def __len__(self) -> int:
126126
return len(self._data)
127127

128-
def __iter__(self: Self):
128+
def __iter__(self):
129129
for record in self._data:
130130
yield self.__class__(data=np.array([record]))
131131

@@ -177,16 +177,18 @@ def __contains__(self: Self, item: Self) -> bool:
177177
return item.data in self._data
178178
return False
179179

180-
def __hash__(self: Self):
180+
def __hash__(self):
181181
return hash(f"{self.__class__} {self}")
182182

183-
def __eq__(self: Self, other):
184-
return self._data.__eq__(other.data)
183+
def __eq__(self, other):
184+
if not isinstance(other, self.__class__):
185+
return False
186+
return self.data.__eq__(other.data)
185187

186-
def __copy__(self: Self):
188+
def __copy__(self):
187189
return self.__class__(data=copy(self._data))
188190

189-
def copy(self: Self):
191+
def copy(self):
190192
"""Return a copy of this array including its data"""
191193
return copy(self)
192194

@@ -281,15 +283,15 @@ def get(
281283
return self.__class__(data=apply_get(*args, array=self._data, mode_=mode_, **kwargs))
282284

283285
def filter_mask(
284-
self: Self,
286+
self,
285287
*args: int | Iterable[int] | np.ndarray,
286288
mode_: Literal["AND", "OR"] = "AND",
287289
**kwargs: Any | list[Any] | np.ndarray,
288290
) -> np.ndarray:
289291
return get_filter_mask(*args, array=self._data, mode_=mode_, **kwargs)
290292

291293
def exclude_mask(
292-
self: Self,
294+
self,
293295
*args: int | Iterable[int] | np.ndarray,
294296
mode_: Literal["AND", "OR"] = "AND",
295297
**kwargs: Any | list[Any] | np.ndarray,
@@ -299,7 +301,7 @@ def exclude_mask(
299301
def re_order(self: Self, new_order: ArrayLike, column: str = "id") -> Self:
300302
return self.__class__(data=re_order(self._data, new_order, column=column))
301303

302-
def update_by_id(self: Self, ids: ArrayLike, allow_missing: bool = False, **kwargs) -> None:
304+
def update_by_id(self, ids: ArrayLike, allow_missing: bool = False, **kwargs) -> None:
303305
try:
304306
_ = update_by_id(self._data, ids, allow_missing, **kwargs)
305307
except ValueError as error:
@@ -312,13 +314,13 @@ def get_updated_by_id(self: Self, ids: ArrayLike, allow_missing: bool = False, *
312314
except ValueError as error:
313315
raise ValueError(f"Cannot update {self.__class__.__name__}. {error}") from error
314316

315-
def check_ids(self: Self, return_duplicates: bool = False) -> NDArray | None:
317+
def check_ids(self, return_duplicates: bool = False) -> NDArray | None:
316318
return check_ids(self._data, return_duplicates=return_duplicates)
317319

318-
def as_table(self: Self, column_width: int | str = "auto", rows: int = 10) -> str:
320+
def as_table(self, column_width: int | str = "auto", rows: int = 10) -> str:
319321
return convert_array_to_string(self, column_width=column_width, rows=rows)
320322

321-
def as_df(self: Self):
323+
def as_df(self):
322324
"""Convert to pandas DataFrame"""
323325
if pandas is None:
324326
raise ImportError("pandas is not installed")

0 commit comments

Comments
 (0)