Skip to content

Commit 9bbdb05

Browse files
aarcex3dantownsend
andauthored
refactor: add __eq__ on Table (#1098)
* refactor: add __eq__ on Table * feat: add __eq__ on Table * testing (Table): test equality * final amends * add missing delete * allow comparison with a primary key value * use `value_type` instead * add docs * fix typo in docs * show `band_1.id == band_2.id` * remove raw value comparison * fix linter errors --------- Co-authored-by: Daniel Townsend <dan@dantownsend.co.uk>
1 parent 1a18547 commit 9bbdb05

File tree

4 files changed

+189
-1
lines changed

4 files changed

+189
-1
lines changed

docs/src/piccolo/query_types/objects.rst

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,54 @@ It works with ``prefetch`` too:
361361
362362
-------------------------------------------------------------------------------
363363

364+
Comparing objects
365+
-----------------
366+
367+
If you have two objects, and you want to know whether they refer to the same
368+
row in the database, you can simply use the equality operator:
369+
370+
.. code-block:: python
371+
372+
band_1 = await Band.objects().where(Band.name == "Pythonistas").first()
373+
band_2 = await Band.objects().where(Band.name == "Pythonistas").first()
374+
375+
>>> band_1 == band_2
376+
True
377+
378+
It works by comparing the primary key value of each object. It's equivalent to
379+
this:
380+
381+
.. code-block:: python
382+
383+
>>> band_1.id == band_2.id
384+
True
385+
386+
If the object has no primary key value yet (e.g. it uses a ``Serial`` column,
387+
and it hasn't been saved in the database), then the result will always be
388+
``False``:
389+
390+
.. code-block:: python
391+
392+
band_1 = Band()
393+
band_2 = Band()
394+
395+
>>> band_1 == band_2
396+
False
397+
398+
If you want to compare every value on the objects, and not just the primary
399+
key, you can use ``to_dict``. For example:
400+
401+
.. code-block:: python
402+
403+
>>> band_1.to_dict() == band_2.to_dict()
404+
True
405+
406+
>>> band_1.popularity = 10_000
407+
>>> band_1.to_dict() == band_2.to_dict()
408+
False
409+
410+
-------------------------------------------------------------------------------
411+
364412
Query clauses
365413
-------------
366414

piccolo/columns/column_types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -785,7 +785,7 @@ def column_type(self):
785785
return "INTEGER"
786786
raise Exception("Unrecognized engine type")
787787

788-
def default(self):
788+
def default(self) -> QueryString:
789789
engine_type = self._meta.engine_type
790790

791791
if engine_type == "postgres":

piccolo/table.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -851,6 +851,72 @@ def __repr__(self) -> str:
851851
)
852852
return f"<{self.__class__.__name__}: {pk}>"
853853

854+
def __eq__(self, other: t.Any) -> bool:
855+
"""
856+
Lets us check if two ``Table`` instances represent the same row in the
857+
database, based on their primary key value::
858+
859+
band_1 = await Band.objects().where(
860+
Band.name == "Pythonistas"
861+
).first()
862+
863+
band_2 = await Band.objects().where(
864+
Band.name == "Pythonistas"
865+
).first()
866+
867+
band_3 = await Band.objects().where(
868+
Band.name == "Rustaceans"
869+
).first()
870+
871+
>>> band_1 == band_2
872+
True
873+
874+
>>> band_1 == band_3
875+
False
876+
877+
"""
878+
if not isinstance(other, Table):
879+
# This is the correct way to tell Python that this operation
880+
# isn't supported:
881+
# https://docs.python.org/3/library/constants.html#NotImplemented
882+
return NotImplemented
883+
884+
# Make sure we're comparing the same table.
885+
# There are several ways we could do this (like comparing tablename),
886+
# but this should be OK.
887+
if not isinstance(other, self.__class__):
888+
return False
889+
890+
pk = self._meta.primary_key
891+
892+
pk_value = getattr(
893+
self,
894+
pk._meta.name,
895+
)
896+
897+
other_pk_value = getattr(
898+
other,
899+
pk._meta.name,
900+
)
901+
902+
# Make sure the primary key values are of the correct type.
903+
# We need this for `Serial` columns, which have a `QueryString`
904+
# value until saved in the database. We don't want to use `==` on
905+
# two QueryString values, because QueryString has a custom `__eq__`
906+
# method which doesn't return a boolean.
907+
if isinstance(
908+
pk_value,
909+
pk.value_type,
910+
) and isinstance(
911+
other_pk_value,
912+
pk.value_type,
913+
):
914+
return pk_value == other_pk_value
915+
else:
916+
# As a fallback, even if it hasn't been saved in the database,
917+
# an object should still be equal to itself.
918+
return other is self
919+
854920
###########################################################################
855921
# Classmethods
856922

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
from piccolo.columns.column_types import UUID, Varchar
2+
from piccolo.table import Table
3+
from piccolo.testing.test_case import AsyncTableTest
4+
from tests.example_apps.music.tables import Manager
5+
6+
7+
class ManagerUUID(Table):
8+
id = UUID(primary_key=True)
9+
name = Varchar()
10+
11+
12+
class TestInstanceEquality(AsyncTableTest):
13+
tables = [
14+
Manager,
15+
ManagerUUID,
16+
]
17+
18+
async def test_instance_equality(self) -> None:
19+
"""
20+
Make sure instance equality works, for tables with a `Serial` primary
21+
key.
22+
"""
23+
manager_1 = Manager(name="Guido")
24+
await manager_1.save()
25+
26+
manager_2 = Manager(name="Graydon")
27+
await manager_2.save()
28+
29+
self.assertEqual(manager_1, manager_1)
30+
self.assertNotEqual(manager_1, manager_2)
31+
32+
# Try fetching the row from the database.
33+
manager_1_from_db = (
34+
await Manager.objects().where(Manager.id == manager_1.id).first()
35+
)
36+
self.assertEqual(manager_1, manager_1_from_db)
37+
self.assertNotEqual(manager_2, manager_1_from_db)
38+
39+
# Try rows which haven't been saved yet.
40+
# They have no primary key value (because they use Serial columns
41+
# as the primary key), so they shouldn't be equal.
42+
self.assertNotEqual(Manager(), Manager())
43+
self.assertNotEqual(manager_1, Manager())
44+
45+
# Make sure an object is equal to itself, even if not saved.
46+
manager_unsaved = Manager()
47+
self.assertEqual(manager_unsaved, manager_unsaved)
48+
49+
async def test_instance_equality_uuid(self) -> None:
50+
"""
51+
Make sure instance equality works, for tables with a `UUID` primary
52+
key.
53+
"""
54+
manager_1 = ManagerUUID(name="Guido")
55+
await manager_1.save()
56+
57+
manager_2 = ManagerUUID(name="Graydon")
58+
await manager_2.save()
59+
60+
self.assertEqual(manager_1, manager_1)
61+
self.assertNotEqual(manager_1, manager_2)
62+
63+
# Try fetching the row from the database.
64+
manager_1_from_db = (
65+
await ManagerUUID.objects()
66+
.where(ManagerUUID.id == manager_1.id)
67+
.first()
68+
)
69+
self.assertEqual(manager_1, manager_1_from_db)
70+
self.assertNotEqual(manager_2, manager_1_from_db)
71+
72+
# Make sure an object is equal to itself, even if not saved.
73+
manager_unsaved = ManagerUUID()
74+
self.assertEqual(manager_unsaved, manager_unsaved)

0 commit comments

Comments
 (0)