Skip to content

Commit 5dffd2a

Browse files
authored
Merge pull request #757 from codeflash-ai/attrs-support
support attrs comparison
2 parents 32e85ee + febe0ec commit 5dffd2a

File tree

2 files changed

+165
-1
lines changed

2 files changed

+165
-1
lines changed

codeflash/verification/comparator.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,27 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
235235
):
236236
return orig == new
237237

238+
if hasattr(orig, "__attrs_attrs__") and hasattr(new, "__attrs_attrs__"):
239+
orig_dict = {}
240+
new_dict = {}
241+
242+
for attr in orig.__attrs_attrs__:
243+
if attr.eq:
244+
attr_name = attr.name
245+
orig_dict[attr_name] = getattr(orig, attr_name, None)
246+
new_dict[attr_name] = getattr(new, attr_name, None)
247+
248+
if superset_obj:
249+
new_attrs_dict = {}
250+
for attr in new.__attrs_attrs__:
251+
if attr.eq:
252+
attr_name = attr.name
253+
new_attrs_dict[attr_name] = getattr(new, attr_name, None)
254+
return all(
255+
k in new_attrs_dict and comparator(v, new_attrs_dict[k], superset_obj) for k, v in orig_dict.items()
256+
)
257+
return comparator(orig_dict, new_dict, superset_obj)
258+
238259
# re.Pattern can be made better by DFA Minimization and then comparing
239260
if isinstance(
240261
orig, (datetime.datetime, datetime.date, datetime.timedelta, datetime.time, datetime.timezone, re.Pattern)

tests/test_comparator.py

Lines changed: 144 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1502,4 +1502,147 @@ def test_collections() -> None:
15021502
d = "hello"
15031503
assert comparator(a, b)
15041504
assert not comparator(a, c)
1505-
assert not comparator(a, d)
1505+
assert not comparator(a, d)
1506+
1507+
1508+
def test_attrs():
1509+
try:
1510+
import attrs # type: ignore
1511+
except ImportError:
1512+
pytest.skip()
1513+
1514+
@attrs.define
1515+
class Person:
1516+
name: str
1517+
age: int = 10
1518+
1519+
a = Person("Alice", 25)
1520+
b = Person("Alice", 25)
1521+
c = Person("Bob", 25)
1522+
d = Person("Alice", 30)
1523+
assert comparator(a, b)
1524+
assert not comparator(a, c)
1525+
assert not comparator(a, d)
1526+
1527+
@attrs.frozen
1528+
class Point:
1529+
x: int
1530+
y: int
1531+
1532+
p1 = Point(1, 2)
1533+
p2 = Point(1, 2)
1534+
p3 = Point(2, 3)
1535+
assert comparator(p1, p2)
1536+
assert not comparator(p1, p3)
1537+
1538+
@attrs.define(slots=True)
1539+
class Vehicle:
1540+
brand: str
1541+
model: str
1542+
year: int = 2020
1543+
1544+
v1 = Vehicle("Toyota", "Camry", 2021)
1545+
v2 = Vehicle("Toyota", "Camry", 2021)
1546+
v3 = Vehicle("Honda", "Civic", 2021)
1547+
assert comparator(v1, v2)
1548+
assert not comparator(v1, v3)
1549+
1550+
@attrs.define
1551+
class ComplexClass:
1552+
public_field: str
1553+
private_field: str = attrs.field(repr=False)
1554+
non_eq_field: int = attrs.field(eq=False, default=0)
1555+
computed: str = attrs.field(init=False, eq=True)
1556+
1557+
def __attrs_post_init__(self):
1558+
self.computed = f"{self.public_field}_{self.private_field}"
1559+
1560+
c1 = ComplexClass("test", "secret")
1561+
c2 = ComplexClass("test", "secret")
1562+
c3 = ComplexClass("different", "secret")
1563+
1564+
c1.non_eq_field = 100
1565+
c2.non_eq_field = 200
1566+
1567+
assert comparator(c1, c2)
1568+
assert not comparator(c1, c3)
1569+
1570+
@attrs.define
1571+
class Address:
1572+
street: str
1573+
city: str
1574+
1575+
@attrs.define
1576+
class PersonWithAddress:
1577+
name: str
1578+
address: Address
1579+
1580+
addr1 = Address("123 Main St", "Anytown")
1581+
addr2 = Address("123 Main St", "Anytown")
1582+
addr3 = Address("456 Oak Ave", "Anytown")
1583+
1584+
person1 = PersonWithAddress("John", addr1)
1585+
person2 = PersonWithAddress("John", addr2)
1586+
person3 = PersonWithAddress("John", addr3)
1587+
1588+
assert comparator(person1, person2)
1589+
assert not comparator(person1, person3)
1590+
1591+
@attrs.define
1592+
class Container:
1593+
items: list
1594+
metadata: dict
1595+
1596+
cont1 = Container([1, 2, 3], {"type": "numbers"})
1597+
cont2 = Container([1, 2, 3], {"type": "numbers"})
1598+
cont3 = Container([1, 2, 4], {"type": "numbers"})
1599+
1600+
assert comparator(cont1, cont2)
1601+
assert not comparator(cont1, cont3)
1602+
1603+
@attrs.define
1604+
class BaseClass:
1605+
name: str
1606+
value: int
1607+
1608+
@attrs.define
1609+
class ExtendedClass:
1610+
name: str
1611+
value: int
1612+
extra_field: str = "default"
1613+
1614+
base = BaseClass("test", 42)
1615+
extended = ExtendedClass("test", 42, "extra")
1616+
1617+
assert not comparator(base, extended)
1618+
1619+
@attrs.define
1620+
class WithNonEqFields:
1621+
name: str
1622+
timestamp: float = attrs.field(eq=False) # Should be ignored
1623+
debug_info: str = attrs.field(eq=False, default="debug")
1624+
1625+
obj1 = WithNonEqFields("test", 1000.0, "info1")
1626+
obj2 = WithNonEqFields("test", 9999.0, "info2") # Different non-eq fields
1627+
obj3 = WithNonEqFields("different", 1000.0, "info1")
1628+
1629+
assert comparator(obj1, obj2) # Should be equal despite different timestamp/debug_info
1630+
assert not comparator(obj1, obj3) # Should be different due to name
1631+
@attrs.define
1632+
class MinimalClass:
1633+
name: str
1634+
value: int
1635+
1636+
@attrs.define
1637+
class ExtendedClass:
1638+
name: str
1639+
value: int
1640+
extra_field: str = "default"
1641+
metadata: dict = attrs.field(factory=dict)
1642+
timestamp: float = attrs.field(eq=False, default=0.0) # This should be ignored
1643+
1644+
minimal = MinimalClass("test", 42)
1645+
extended = ExtendedClass("test", 42, "extra", {"key": "value"}, 1000.0)
1646+
1647+
assert not comparator(minimal, extended)
1648+

0 commit comments

Comments
 (0)