Skip to content

Commit 1cb68e0

Browse files
committed
Change version to dev and fix recursion error with related list sort
1 parent 1fe60ab commit 1cb68e0

File tree

3 files changed

+50
-26
lines changed

3 files changed

+50
-26
lines changed

atomdb/sql.py

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,12 @@ class RelatedList(atomclist):
197197

198198
__slots__ = ()
199199

200-
async def load(self) -> list:
200+
def sort(self, *, key=None, reverse=False):
201+
# AtomCListHandler calls super(type(self), self).sort which causes a loop
202+
# in a loop because we fiddle with the __class__
203+
super(atomclist, self).sort(key=key, reverse=reverse)
204+
205+
async def load(self, inplace: bool = True) -> list:
201206
"""Returns a list of the related values."""
202207
Model = cast(Type[SQLModel], type(owner))
203208
ThroughModel = relation.through
@@ -234,38 +239,45 @@ async def load(self) -> list:
234239
f"relation between {Model} and through model {ThroughModel}"
235240
f": Tried {Model.__backrefs__}"
236241
)
237-
return [
242+
items = [
238243
getattr(row, relation_backref.name)
239244
for row in await ThroughModel.objects.select_related(
240245
relation_backref.name
241246
).filter(**{owner_backref.name: owner})
242247
]
243-
# A many to one relation case. For example:
244-
#
245-
# class Page(SQLModel):
246-
# comments = Relation(lambda: Comment)
247-
# class Comment(SQLModel):
248-
# page = Instance(Page)
249-
#
250-
# When we have:
251-
# comments = await page.comments.load()
252-
# The page is the owner, and the comments member is the relation.
253-
# So inlining it will be the same as the following:
254-
# comments = await Comments.objects.filter(page=page)
255-
owner_backref = resolve_backref(Model, RelModel)
256-
if owner_backref is None:
257-
raise UnresolvableError(
258-
f"relation between {Model} and {RelModel}"
259-
f": Tried {Model.__backrefs__}"
260-
)
261-
return await RelModel.objects.filter(**{owner_backref.name: owner})
248+
else:
249+
# A many to one relation case. For example:
250+
#
251+
# class Page(SQLModel):
252+
# comments = Relation(lambda: Comment)
253+
# class Comment(SQLModel):
254+
# page = Instance(Page)
255+
#
256+
# When we have:
257+
# comments = await page.comments.load()
258+
# The page is the owner, and the comments member is the relation.
259+
# So inlining it will be the same as the following:
260+
# comments = await Comments.objects.filter(page=page)
261+
owner_backref = resolve_backref(Model, RelModel)
262+
if owner_backref is None:
263+
raise UnresolvableError(
264+
f"relation between {Model} and {RelModel}"
265+
f": Tried {Model.__backrefs__}"
266+
)
267+
items = await RelModel.objects.filter(**{owner_backref.name: owner})
268+
269+
if inplace:
270+
for item in items:
271+
if item not in self:
272+
self.append(item)
273+
return items
262274

263275
async def save(self, connection=None):
264276
"""Save the current list as the complete set of related items. This
265277
should only be used for small sets of items.
266278
"""
267279
current = set(self)
268-
saved = set(await self.load())
280+
saved = set(await self.load(inplace=False))
269281
ThroughModel = relation.through
270282
RelModel = relation.to
271283
if ThroughModel is not None:

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
setup(
1515
name="atom-db",
16-
version="0.8.1",
16+
version="0.8.1.dev",
1717
author="CodeLV",
1818
author_email="[email protected]",
1919
url="https://github.com/codelv/atom-db",

tests/test_sql.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -552,12 +552,14 @@ async def test_query_related(db):
552552
job2 = await Job.objects.create(name="Manager")
553553

554554
await JobRole.objects.create(job=job, name="Cooking")
555+
await JobRole.objects.create(job=job, name="Grilling")
556+
555557
await JobRole.objects.create(job=job1, name="Serving")
556558
role2 = await JobRole.objects.create(job=job2, name="Managing")
557559

558560
roles = await JobRole.objects.filter(job__name__in=[job.name, job2.name])
559-
assert len(roles) == 2
560-
assert await JobRole.objects.count(job__name__in=[job.name, job2.name]) == 2
561+
assert len(roles) == 3
562+
assert await JobRole.objects.count(job__name__in=[job.name, job2.name]) == 3
561563

562564
roles = await JobRole.objects.filter(job__name=job2.name)
563565
assert len(roles) == 1
@@ -570,7 +572,17 @@ async def test_query_related(db):
570572
assert len(roles) == 1 and roles[0] == role2
571573

572574
roles = await JobRole.objects.filter(job__name__not="none of the above")
573-
assert len(roles) == 3
575+
assert len(roles) == 4
576+
577+
# Test related list
578+
assert len(job.roles) == 0 # Not loaded
579+
await job.roles.load()
580+
assert len(job.roles) == 2
581+
job.roles.append(JobRole(name="Baking", job=job))
582+
job.roles.sort(key=lambda it: it.name)
583+
assert [it.name for it in job.roles] == ["Baking", "Cooking", "Grilling"]
584+
await job.roles.save()
585+
assert await JobRole.objects.filter(job__name=job.name).count() == 3
574586

575587
# Cant do multiple joins
576588
with pytest.raises(ValueError):

0 commit comments

Comments
 (0)