diff --git a/codeflash/verification/comparator.py b/codeflash/verification/comparator.py index e6244b78a..3881a6888 100644 --- a/codeflash/verification/comparator.py +++ b/codeflash/verification/comparator.py @@ -235,6 +235,27 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 ): return orig == new + if hasattr(orig, "__attrs_attrs__") and hasattr(new, "__attrs_attrs__"): + orig_dict = {} + new_dict = {} + + for attr in orig.__attrs_attrs__: + if attr.eq: + attr_name = attr.name + orig_dict[attr_name] = getattr(orig, attr_name, None) + new_dict[attr_name] = getattr(new, attr_name, None) + + if superset_obj: + new_attrs_dict = {} + for attr in new.__attrs_attrs__: + if attr.eq: + attr_name = attr.name + new_attrs_dict[attr_name] = getattr(new, attr_name, None) + return all( + k in new_attrs_dict and comparator(v, new_attrs_dict[k], superset_obj) for k, v in orig_dict.items() + ) + return comparator(orig_dict, new_dict, superset_obj) + # re.Pattern can be made better by DFA Minimization and then comparing if isinstance( orig, (datetime.datetime, datetime.date, datetime.timedelta, datetime.time, datetime.timezone, re.Pattern) diff --git a/tests/test_comparator.py b/tests/test_comparator.py index 4e404a512..f3f14b86c 100644 --- a/tests/test_comparator.py +++ b/tests/test_comparator.py @@ -1502,4 +1502,147 @@ def test_collections() -> None: d = "hello" assert comparator(a, b) assert not comparator(a, c) - assert not comparator(a, d) \ No newline at end of file + assert not comparator(a, d) + + +def test_attrs(): + try: + import attrs # type: ignore + except ImportError: + pytest.skip() + + @attrs.define + class Person: + name: str + age: int = 10 + + a = Person("Alice", 25) + b = Person("Alice", 25) + c = Person("Bob", 25) + d = Person("Alice", 30) + assert comparator(a, b) + assert not comparator(a, c) + assert not comparator(a, d) + + @attrs.frozen + class Point: + x: int + y: int + + p1 = Point(1, 2) + p2 = Point(1, 2) + p3 = Point(2, 3) + assert comparator(p1, p2) + assert not comparator(p1, p3) + + @attrs.define(slots=True) + class Vehicle: + brand: str + model: str + year: int = 2020 + + v1 = Vehicle("Toyota", "Camry", 2021) + v2 = Vehicle("Toyota", "Camry", 2021) + v3 = Vehicle("Honda", "Civic", 2021) + assert comparator(v1, v2) + assert not comparator(v1, v3) + + @attrs.define + class ComplexClass: + public_field: str + private_field: str = attrs.field(repr=False) + non_eq_field: int = attrs.field(eq=False, default=0) + computed: str = attrs.field(init=False, eq=True) + + def __attrs_post_init__(self): + self.computed = f"{self.public_field}_{self.private_field}" + + c1 = ComplexClass("test", "secret") + c2 = ComplexClass("test", "secret") + c3 = ComplexClass("different", "secret") + + c1.non_eq_field = 100 + c2.non_eq_field = 200 + + assert comparator(c1, c2) + assert not comparator(c1, c3) + + @attrs.define + class Address: + street: str + city: str + + @attrs.define + class PersonWithAddress: + name: str + address: Address + + addr1 = Address("123 Main St", "Anytown") + addr2 = Address("123 Main St", "Anytown") + addr3 = Address("456 Oak Ave", "Anytown") + + person1 = PersonWithAddress("John", addr1) + person2 = PersonWithAddress("John", addr2) + person3 = PersonWithAddress("John", addr3) + + assert comparator(person1, person2) + assert not comparator(person1, person3) + + @attrs.define + class Container: + items: list + metadata: dict + + cont1 = Container([1, 2, 3], {"type": "numbers"}) + cont2 = Container([1, 2, 3], {"type": "numbers"}) + cont3 = Container([1, 2, 4], {"type": "numbers"}) + + assert comparator(cont1, cont2) + assert not comparator(cont1, cont3) + + @attrs.define + class BaseClass: + name: str + value: int + + @attrs.define + class ExtendedClass: + name: str + value: int + extra_field: str = "default" + + base = BaseClass("test", 42) + extended = ExtendedClass("test", 42, "extra") + + assert not comparator(base, extended) + + @attrs.define + class WithNonEqFields: + name: str + timestamp: float = attrs.field(eq=False) # Should be ignored + debug_info: str = attrs.field(eq=False, default="debug") + + obj1 = WithNonEqFields("test", 1000.0, "info1") + obj2 = WithNonEqFields("test", 9999.0, "info2") # Different non-eq fields + obj3 = WithNonEqFields("different", 1000.0, "info1") + + assert comparator(obj1, obj2) # Should be equal despite different timestamp/debug_info + assert not comparator(obj1, obj3) # Should be different due to name + @attrs.define + class MinimalClass: + name: str + value: int + + @attrs.define + class ExtendedClass: + name: str + value: int + extra_field: str = "default" + metadata: dict = attrs.field(factory=dict) + timestamp: float = attrs.field(eq=False, default=0.0) # This should be ignored + + minimal = MinimalClass("test", 42) + extended = ExtendedClass("test", 42, "extra", {"key": "value"}, 1000.0) + + assert not comparator(minimal, extended) + \ No newline at end of file