Skip to content

Commit a8a17c3

Browse files
committed
feat: Add GenericRichComparison
This is to help w/ migration away from inject_richcmp_from_cmp for instances that use __attr_comparison__ logic. Signed-off-by: Brian Harring <ferringb@gmail.com>
1 parent 01299fc commit a8a17c3

File tree

2 files changed

+103
-6
lines changed

2 files changed

+103
-6
lines changed

src/snakeoil/klass/__init__.py

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242
from functools import wraps
4343
from importlib import import_module
4444
from operator import attrgetter
45-
from typing import Any
4645

4746
from snakeoil.deprecation import deprecated as warn_deprecated
4847

@@ -167,14 +166,22 @@ def __attr_comparison__(self) -> tuple[str, ...]: # pyright: ignore[reportRedec
167166

168167
__attr_comparison__: tuple[str, ...]
169168

170-
def __eq__(self, other: Any) -> bool:
169+
def __eq__(
170+
self, value, /, attr_comparison_override: tuple[str, ...] | None = None
171+
) -> bool:
171172
"""
172-
Comparison is down via comparing attributes listed in self.__attr_comparison__
173+
Comparison is down via comparing attributes listed in self.__attr_comparison__,
174+
or via the passed in attr_comparison_override. That exists specifically to
175+
simplify subclass partial reuse of the class when logic gets complex.
173176
"""
174-
if self is other:
177+
if self is value:
175178
return True
176-
for attr in self.__attr_comparison__:
177-
if getattr(self, attr, sentinel) != getattr(other, attr, sentinel):
179+
for attr in (
180+
self.__attr_comparison__
181+
if attr_comparison_override is None
182+
else attr_comparison_override
183+
):
184+
if getattr(self, attr, sentinel) != getattr(value, attr, sentinel):
178185
return False
179186
return True
180187

@@ -195,6 +202,56 @@ def __init_subclass__(cls) -> None:
195202
return super().__init_subclass__()
196203

197204

205+
class GenericRichComparison(GenericEquality):
206+
__slots__ = ()
207+
208+
def __lt__(self, value, attr_comparison_override: tuple[str, ...] | None = None):
209+
if self is value:
210+
return False
211+
attrlist = (
212+
self.__attr_comparison__
213+
if attr_comparison_override is None
214+
else attr_comparison_override
215+
)
216+
for attr in attrlist:
217+
obj1, obj2 = getattr(self, attr, sentinel), getattr(value, attr, sentinel)
218+
if obj1 is sentinel:
219+
if obj2 is sentinel:
220+
continue
221+
return True
222+
elif obj2 is sentinel:
223+
return False
224+
if not (obj1 >= obj2): # pyright: ignore[reportOperatorIssue]
225+
return True
226+
return False
227+
228+
def __le__(self, value, attr_comparison_override: tuple[str, ...] | None = None):
229+
if self is value:
230+
return True
231+
attrlist = (
232+
self.__attr_comparison__
233+
if attr_comparison_override is None
234+
else attr_comparison_override
235+
)
236+
for attr in attrlist:
237+
obj1, obj2 = getattr(self, attr, sentinel), getattr(value, attr, sentinel)
238+
if obj1 is sentinel:
239+
if obj2 is sentinel:
240+
continue
241+
return True
242+
elif obj2 is sentinel:
243+
return False
244+
if not (obj1 > obj2): # pyright: ignore[reportOperatorIssue]
245+
return True
246+
return False
247+
248+
def __gt__(self, value, attr_comparison_override: tuple[str, ...] | None = None):
249+
return not self.__le__(value, attr_comparison_override=attr_comparison_override)
250+
251+
def __ge__(self, value, attr_comparison_override: tuple[str, ...] | None = None):
252+
return not self.__lt__(value, attr_comparison_override=attr_comparison_override)
253+
254+
198255
@warn_deprecated(
199256
"generic_equality metaclass usage is deprecated; inherit from snakeoil.klass.GenericEquality instead."
200257
)

tests/klass/test_init.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -661,3 +661,43 @@ def __eq__(self, other):
661661
subclass_disabling_must_be_allowed.__annotations__["__attr_comparison__"]
662662
is None
663663
), "annotations weren't updated"
664+
665+
666+
class TestGenericRichComparison:
667+
def test_it(self):
668+
class kls(klass.GenericRichComparison):
669+
__attr_comparison__ = ("x", "y")
670+
671+
def __init__(self, x=1, y=2):
672+
self.x, self.y = x, y
673+
674+
obj1 = kls()
675+
# these use identity check shortcuts, validate how it handles
676+
# the same instance.
677+
assert obj1 == obj1
678+
assert obj1 <= obj1
679+
assert obj1 >= obj1
680+
assert not (obj1 < obj1)
681+
assert not (obj1 > obj1)
682+
assert not (obj1 != obj1)
683+
684+
# validate the usual scenarios
685+
obj2 = kls()
686+
assert not (obj1 < obj2)
687+
assert obj1 <= obj2
688+
assert not (obj1 != obj2)
689+
assert obj1 >= obj2
690+
assert not (obj1 > obj2)
691+
692+
obj1.x = 0
693+
assert obj1 < obj2
694+
assert obj1 <= obj2
695+
assert not (obj1 > obj2)
696+
assert not (obj1 >= obj2)
697+
698+
del obj1.x
699+
assert obj1 != obj2
700+
assert obj1 < obj2
701+
assert obj1 <= obj2
702+
assert not (obj1 > obj2)
703+
assert not (obj1 >= obj2)

0 commit comments

Comments
 (0)