Skip to content

Commit 94b467f

Browse files
Merge branch 'feat-pagination-performance'
# Conflicts: # graphene_mongo/fields.py
2 parents 751b6b4 + 1ff79aa commit 94b467f

File tree

1 file changed

+83
-31
lines changed

1 file changed

+83
-31
lines changed

graphene_mongo/fields.py

Lines changed: 83 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
from graphene.types.argument import to_arguments
1717
from graphene.types.dynamic import Dynamic
1818
from graphene.types.structures import Structure
19-
from graphql_relay.connection.arrayconnection import connection_from_list_slice
19+
from graphql_relay.connection.arrayconnection import cursor_to_offset
20+
from mongoengine import QuerySet
2021

2122
from .advanced_types import (
2223
FileFieldType,
@@ -26,7 +27,8 @@
2627
)
2728
from .converter import convert_mongoengine_field, MongoEngineConversionError
2829
from .registry import get_global_registry
29-
from .utils import get_model_reference_fields, get_node_from_global_id, get_query_fields
30+
from .utils import get_model_reference_fields, get_node_from_global_id, get_query_fields, find_skip_and_limit, \
31+
connection_from_iterables
3032

3133

3234
class MongoengineConnectionField(ConnectionField):
@@ -190,7 +192,7 @@ def fields(self):
190192
self._type = get_type(self._type)
191193
return self._type._meta.fields
192194

193-
def get_queryset(self, model, info, required_fields=list(), **args):
195+
def get_queryset(self, model, info, required_fields=list(), skip=None, limit=None, reversed=False, **args):
194196
if args:
195197
reference_fields = get_model_reference_fields(self.model)
196198
hydrated_references = {}
@@ -208,7 +210,30 @@ def get_queryset(self, model, info, required_fields=list(), **args):
208210
return queryset_or_filters
209211
else:
210212
args.update(queryset_or_filters)
211-
213+
if limit is not None:
214+
if reversed:
215+
order_by = ""
216+
if self.order_by:
217+
order_by = self.order_by + ",-pk"
218+
else:
219+
order_by = "-pk"
220+
return model.objects(**args).no_dereference().only(*required_fields).order_by(order_by).skip(
221+
skip if skip else 0).limit(limit)
222+
else:
223+
return model.objects(**args).no_dereference().only(*required_fields).order_by(self.order_by).skip(
224+
skip if skip else 0).limit(limit)
225+
elif skip is not None:
226+
if reversed:
227+
order_by = ""
228+
if self.order_by:
229+
order_by = self.order_by + ",-pk"
230+
else:
231+
order_by = "-pk"
232+
return model.objects(**args).no_dereference().only(*required_fields).order_by(order_by).skip(
233+
skip)
234+
else:
235+
return model.objects(**args).no_dereference().only(*required_fields).order_by(self.order_by).skip(
236+
skip)
212237
return model.objects(**args).no_dereference().only(*required_fields).order_by(self.order_by)
213238

214239
def default_resolver(self, _root, info, required_fields=list(), **args):
@@ -220,38 +245,62 @@ def default_resolver(self, _root, info, required_fields=list(), **args):
220245
if getattr(_root, field_name, []) is not None:
221246
args["pk__in"] = [r.id for r in getattr(_root, field_name, [])]
222247

223-
connection_args = {
224-
"first": args.pop("first", None),
225-
"last": args.pop("last", None),
226-
"before": args.pop("before", None),
227-
"after": args.pop("after", None),
228-
}
229-
230248
_id = args.pop('id', None)
231249

232250
if _id is not None:
233251
args['pk'] = from_global_id(_id)[-1]
234-
252+
iterables = []
253+
list_length = 0
254+
skip = 0
255+
count = 0
256+
limit = None
257+
reverse = False
235258
if callable(getattr(self.model, "objects", None)):
236-
iterables = self.get_queryset(self.model, info, required_fields, **args)
237-
if isinstance(info, ResolveInfo):
238-
if not info.context:
239-
info.context = Context()
240-
info.context.queryset = iterables
241-
list_length = iterables.count()
242-
else:
243-
iterables = []
244-
list_length = 0
245-
246-
connection = connection_from_list_slice(
247-
list_slice=iterables,
248-
args=connection_args,
249-
list_length=list_length,
250-
list_slice_length=list_length,
251-
connection_type=self.type,
252-
edge_type=self.type.Edge,
253-
pageinfo_type=graphene.PageInfo,
254-
)
259+
first = args.pop("first", None)
260+
after = cursor_to_offset(args.pop("after", None))
261+
last = args.pop("last", None)
262+
before = cursor_to_offset(args.pop("before", None))
263+
if "pk__in" in args and args["pk__in"]:
264+
count = len(args["pk__in"])
265+
skip, limit, reverse = find_skip_and_limit(first=first, last=last, after=after, before=before,
266+
count=count)
267+
if limit:
268+
if reverse:
269+
args["pk__in"] = args["pk__in"][::-1][skip:skip + limit]
270+
else:
271+
args["pk__in"] = args["pk__in"][skip:skip + limit]
272+
elif skip:
273+
args["pk__in"] = args["pk__in"][skip:]
274+
iterables = self.get_queryset(self.model, info, required_fields, **args)
275+
list_length = len(iterables)
276+
if isinstance(info, ResolveInfo):
277+
if not info.context:
278+
info.context = Context()
279+
info.context.queryset = self.get_queryset(self.model, info, required_fields, **args)
280+
elif _root is None:
281+
count = self.get_queryset(self.model, info, required_fields, **args).count()
282+
if count != 0:
283+
skip, limit, reverse = find_skip_and_limit(first=first, after=after, last=last, before=before,
284+
count=count)
285+
iterables = self.get_queryset(self.model, info, required_fields, skip, limit, reverse, **args)
286+
list_length = len(iterables)
287+
if isinstance(info, ResolveInfo):
288+
if not info.context:
289+
info.context = Context()
290+
info.context.queryset = self.get_queryset(self.model, info, required_fields, **args)
291+
has_next_page = True if (0 if limit is None else limit) + (0 if skip is None else skip) < count else False
292+
has_previous_page = True if skip else False
293+
if reverse:
294+
iterables = list(iterables)
295+
iterables.reverse()
296+
skip = limit
297+
connection = connection_from_iterables(edges=iterables, start_offset=skip,
298+
has_previous_page=has_previous_page,
299+
has_next_page=has_next_page,
300+
connection_type=self.type,
301+
edge_type=self.type.Edge,
302+
pageinfo_type=graphene.PageInfo)
303+
255304
connection.iterable = iterables
256305
connection.list_length = list_length
257306
return connection
@@ -283,6 +332,9 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
283332
return resolved
284333
elif not isinstance(resolved[0], DBRef):
285334
return resolved
335+
elif isinstance(resolved, QuerySet):
336+
args.update(resolved._query)
337+
return self.default_resolver(root, info, required_fields, **args)
286338
else:
287339
return resolved
288340
return self.default_resolver(root, info, required_fields, **args)

0 commit comments

Comments
 (0)