Skip to content

Commit a93c58e

Browse files
committed
fix(ConnectionField): Priorities pk__in filter for finding count of nested connection field
1 parent 75938db commit a93c58e

File tree

2 files changed

+37
-39
lines changed

2 files changed

+37
-39
lines changed

graphene_mongo/fields.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,24 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a
424424
list_length = len(iterables)
425425

426426
elif callable(getattr(self.model, "objects", None)):
427-
if (
427+
if "pk__in" in args and args["pk__in"]:
428+
count = len(args["pk__in"])
429+
skip, limit = find_skip_and_limit(
430+
first=first, last=last, after=after, before=before, count=count
431+
)
432+
if limit:
433+
args["pk__in"] = args["pk__in"][skip : skip + limit]
434+
elif skip:
435+
args["pk__in"] = args["pk__in"][skip:]
436+
iterables = self.get_queryset(self.model, info, required_fields, **args)
437+
list_length = len(iterables)
438+
if isinstance(info, GraphQLResolveInfo):
439+
if not info.context:
440+
info = info._replace(context=Context())
441+
info.context.queryset = self.get_queryset(
442+
self.model, info, required_fields, **args
443+
)
444+
elif (
428445
_root is None
429446
or args
430447
or isinstance(getattr(_root, field_name, []), MongoengineConnectionField)
@@ -486,24 +503,6 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a
486503
self.model, info, required_fields, **args
487504
)
488505

489-
elif "pk__in" in args and args["pk__in"]:
490-
count = len(args["pk__in"])
491-
skip, limit = find_skip_and_limit(
492-
first=first, last=last, after=after, before=before, count=count
493-
)
494-
if limit:
495-
args["pk__in"] = args["pk__in"][skip : skip + limit]
496-
elif skip:
497-
args["pk__in"] = args["pk__in"][skip:]
498-
iterables = self.get_queryset(self.model, info, required_fields, **args)
499-
list_length = len(iterables)
500-
if isinstance(info, GraphQLResolveInfo):
501-
if not info.context:
502-
info = info._replace(context=Context())
503-
info.context.queryset = self.get_queryset(
504-
self.model, info, required_fields, **args
505-
)
506-
507506
elif _root is not None:
508507
field_name = to_snake_case(info.field_name)
509508
items = getattr(_root, field_name, [])

graphene_mongo/fields_async.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,25 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
150150
list_length = len(iterables)
151151

152152
elif callable(getattr(self.model, "objects", None)):
153-
if (
153+
if "pk__in" in args and args["pk__in"]:
154+
count = len(args["pk__in"])
155+
skip, limit = find_skip_and_limit(
156+
first=first, last=last, after=after, before=before, count=count
157+
)
158+
if limit:
159+
args["pk__in"] = args["pk__in"][skip : skip + limit]
160+
elif skip:
161+
args["pk__in"] = args["pk__in"][skip:]
162+
iterables = self.get_queryset(self.model, info, required_fields, **args)
163+
iterables = await sync_to_async(list)(iterables)
164+
list_length = len(iterables)
165+
if isinstance(info, GraphQLResolveInfo):
166+
if not info.context:
167+
info = info._replace(context=Context())
168+
info.context.queryset = self.get_queryset(
169+
self.model, info, required_fields, **args
170+
)
171+
elif (
154172
_root is None
155173
or args
156174
or isinstance(getattr(_root, field_name, []), AsyncMongoengineConnectionField)
@@ -206,25 +224,6 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
206224
self.model, info, required_fields, **args
207225
)
208226

209-
elif "pk__in" in args and args["pk__in"]:
210-
count = len(args["pk__in"])
211-
skip, limit = find_skip_and_limit(
212-
first=first, last=last, after=after, before=before, count=count
213-
)
214-
if limit:
215-
args["pk__in"] = args["pk__in"][skip : skip + limit]
216-
elif skip:
217-
args["pk__in"] = args["pk__in"][skip:]
218-
iterables = self.get_queryset(self.model, info, required_fields, **args)
219-
iterables = await sync_to_async(list)(iterables)
220-
list_length = len(iterables)
221-
if isinstance(info, GraphQLResolveInfo):
222-
if not info.context:
223-
info = info._replace(context=Context())
224-
info.context.queryset = self.get_queryset(
225-
self.model, info, required_fields, **args
226-
)
227-
228227
elif _root is not None:
229228
field_name = to_snake_case(info.field_name)
230229
items = getattr(_root, field_name, [])

0 commit comments

Comments
 (0)