Skip to content

Commit 9c2f8a8

Browse files
authored
Merge pull request #157 from arunsureshkumar/feat-pagination-performance
Feat pagination performance
2 parents 5d08133 + 863ab88 commit 9c2f8a8

File tree

8 files changed

+365
-123
lines changed

8 files changed

+365
-123
lines changed

graphene_mongo/converter.py

Lines changed: 162 additions & 47 deletions
Large diffs are not rendered by default.

graphene_mongo/fields.py

Lines changed: 117 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import mongoengine
88
from bson import DBRef
99
from graphene import Context
10+
from graphene.types.utils import get_type
1011
from graphene.utils.str_converters import to_snake_case
1112
from graphql import ResolveInfo
1213
from promise import Promise
@@ -15,7 +16,8 @@
1516
from graphene.types.argument import to_arguments
1617
from graphene.types.dynamic import Dynamic
1718
from graphene.types.structures import Structure
18-
from graphql_relay.connection.arrayconnection import connection_from_list_slice
19+
from graphql_relay.connection.arrayconnection import cursor_to_offset
20+
from mongoengine import QuerySet
1921

2022
from .advanced_types import (
2123
FileFieldType,
@@ -25,7 +27,8 @@
2527
)
2628
from .converter import convert_mongoengine_field, MongoEngineConversionError
2729
from .registry import get_global_registry
28-
from .utils import get_model_reference_fields, get_node_from_global_id, get_query_fields
30+
from .utils import get_model_reference_fields, get_query_fields, find_skip_and_limit, \
31+
connection_from_iterables
2932

3033

3134
class MongoengineConnectionField(ConnectionField):
@@ -64,10 +67,8 @@ def order_by(self):
6467
return self.node_type._meta.order_by
6568

6669
@property
67-
def only_fields(self):
68-
if isinstance(self.node_type._meta.only_fields, str):
69-
return self.node_type._meta.only_fields.split(",")
70-
return list()
70+
def required_fields(self):
71+
return tuple(set(self.node_type._meta.required_fields + self.node_type._meta.only_fields))
7172

7273
@property
7374
def registry(self):
@@ -118,11 +119,13 @@ def is_filterable(k):
118119
),
119120
):
120121
return False
122+
if getattr(converted, "type", None) and getattr(converted.type, "_of_type", None) and issubclass(
123+
(get_type(converted.type.of_type)), graphene.Union):
124+
return False
121125
if isinstance(converted, (graphene.List)) and issubclass(
122126
getattr(converted, "_of_type", None), graphene.Union
123127
):
124128
return False
125-
126129
return True
127130

128131
def get_filter_type(_type):
@@ -177,29 +180,35 @@ def get_reference_field(r, kv):
177180
if callable(getattr(field, "get_type", None)):
178181
_type = field.get_type()
179182
if _type:
180-
node = _type._type._meta
183+
node = _type.type._meta if hasattr(_type.type, "_meta") else _type.type._of_type._meta
181184
if "id" in node.fields and not issubclass(
182185
node.model, (mongoengine.EmbeddedDocument,)
183186
):
184187
r.update({kv[0]: node.fields["id"]._type.of_type()})
188+
185189
return r
186190

187191
return reduce(get_reference_field, self.fields.items(), {})
188192

189193
@property
190194
def fields(self):
195+
self._type = get_type(self._type)
191196
return self._type._meta.fields
192197

193-
def get_queryset(self, model, info, only_fields=list(), **args):
198+
def get_queryset(self, model, info, required_fields=list(), skip=None, limit=None, reversed=False, **args):
194199
if args:
195200
reference_fields = get_model_reference_fields(self.model)
196201
hydrated_references = {}
197202
for arg_name, arg in args.copy().items():
198-
if arg_name in reference_fields:
199-
reference_obj = get_node_from_global_id(
200-
reference_fields[arg_name], info, args.pop(arg_name)
201-
)
203+
if arg_name in reference_fields and not isinstance(arg,
204+
mongoengine.base.metaclasses.TopLevelDocumentMetaclass):
205+
try:
206+
reference_obj = reference_fields[arg_name].document_type(pk=from_global_id(arg)[1])
207+
except TypeError:
208+
reference_obj = reference_fields[arg_name].document_type(pk=arg)
202209
hydrated_references[arg_name] = reference_obj
210+
elif arg_name == "id":
211+
hydrated_references["id"] = from_global_id(args.pop("id", None))[1]
203212
args.update(hydrated_references)
204213

205214
if self._get_queryset:
@@ -208,72 +217,120 @@ def get_queryset(self, model, info, only_fields=list(), **args):
208217
return queryset_or_filters
209218
else:
210219
args.update(queryset_or_filters)
220+
if limit is not None:
221+
if reversed:
222+
order_by = ""
223+
if self.order_by:
224+
order_by = self.order_by + ",-pk"
225+
else:
226+
order_by = "-pk"
227+
return model.objects(**args).no_dereference().only(*required_fields).order_by(order_by).skip(
228+
skip if skip else 0).limit(limit)
229+
else:
230+
return model.objects(**args).no_dereference().only(*required_fields).order_by(self.order_by).skip(
231+
skip if skip else 0).limit(limit)
232+
elif skip is not None:
233+
if reversed:
234+
order_by = ""
235+
if self.order_by:
236+
order_by = self.order_by + ",-pk"
237+
else:
238+
order_by = "-pk"
239+
return model.objects(**args).no_dereference().only(*required_fields).order_by(order_by).skip(
240+
skip)
241+
else:
242+
return model.objects(**args).no_dereference().only(*required_fields).order_by(self.order_by).skip(
243+
skip)
244+
return model.objects(**args).no_dereference().only(*required_fields).order_by(self.order_by)
211245

212-
return model.objects(**args).no_dereference().only(*only_fields).order_by(self.order_by)
213-
214-
def default_resolver(self, _root, info, only_fields=list(), **args):
246+
def default_resolver(self, _root, info, required_fields=list(), **args):
215247
args = args or {}
216248

217249
if _root is not None:
218250
field_name = to_snake_case(info.field_name)
219-
if getattr(_root, field_name, []) is not None:
220-
args["pk__in"] = [r.id for r in getattr(_root, field_name, [])]
221-
222-
connection_args = {
223-
"first": args.pop("first", None),
224-
"last": args.pop("last", None),
225-
"before": args.pop("before", None),
226-
"after": args.pop("after", None),
227-
}
251+
if field_name in _root._fields_ordered:
252+
if getattr(_root, field_name, []) is not None:
253+
args["pk__in"] = [r.id for r in getattr(_root, field_name, [])]
228254

229255
_id = args.pop('id', None)
230256

231257
if _id is not None:
232258
args['pk'] = from_global_id(_id)[-1]
233-
259+
iterables = []
260+
list_length = 0
261+
skip = 0
262+
count = 0
263+
limit = None
264+
reverse = False
234265
if callable(getattr(self.model, "objects", None)):
235-
iterables = self.get_queryset(self.model, info, only_fields, **args)
236-
if isinstance(info, ResolveInfo):
237-
if not info.context:
238-
info.context = Context()
239-
info.context.queryset = iterables
240-
list_length = iterables.count()
241-
else:
242-
iterables = []
243-
list_length = 0
244-
245-
connection = connection_from_list_slice(
246-
list_slice=iterables,
247-
args=connection_args,
248-
list_length=list_length,
249-
list_slice_length=list_length,
250-
connection_type=self.type,
251-
edge_type=self.type.Edge,
252-
pageinfo_type=graphene.PageInfo,
253-
)
266+
first = args.pop("first", None)
267+
after = cursor_to_offset(args.pop("after", None))
268+
last = args.pop("last", None)
269+
before = cursor_to_offset(args.pop("before", None))
270+
if "pk__in" in args and args["pk__in"]:
271+
count = len(args["pk__in"])
272+
skip, limit, reverse = find_skip_and_limit(first=first, last=last, after=after, before=before,
273+
count=count)
274+
if limit:
275+
if reverse:
276+
args["pk__in"] = args["pk__in"][::-1][skip:skip + limit]
277+
else:
278+
args["pk__in"] = args["pk__in"][skip:skip + limit]
279+
elif skip:
280+
args["pk__in"] = args["pk__in"][skip:]
281+
iterables = self.get_queryset(self.model, info, required_fields, **args)
282+
list_length = len(iterables)
283+
if isinstance(info, ResolveInfo):
284+
if not info.context:
285+
info.context = Context()
286+
info.context.queryset = self.get_queryset(self.model, info, required_fields, **args)
287+
elif _root is None:
288+
count = self.get_queryset(self.model, info, required_fields, **args).count()
289+
if count != 0:
290+
skip, limit, reverse = find_skip_and_limit(first=first, after=after, last=last, before=before,
291+
count=count)
292+
iterables = self.get_queryset(self.model, info, required_fields, skip, limit, reverse, **args)
293+
list_length = len(iterables)
294+
if isinstance(info, ResolveInfo):
295+
if not info.context:
296+
info.context = Context()
297+
info.context.queryset = self.get_queryset(self.model, info, required_fields, **args)
298+
has_next_page = True if (0 if limit is None else limit) + (0 if skip is None else skip) < count else False
299+
has_previous_page = True if skip else False
300+
if reverse:
301+
iterables = list(iterables)
302+
iterables.reverse()
303+
skip = limit
304+
connection = connection_from_iterables(edges=iterables, start_offset=skip,
305+
has_previous_page=has_previous_page,
306+
has_next_page=has_next_page,
307+
connection_type=self.type,
308+
edge_type=self.type.Edge,
309+
pageinfo_type=graphene.PageInfo)
310+
254311
connection.iterable = iterables
255312
connection.list_length = list_length
256313
return connection
257314

258315
def chained_resolver(self, resolver, is_partial, root, info, **args):
259-
only_fields = list()
260-
for field in self.only_fields:
316+
required_fields = list()
317+
for field in self.required_fields:
261318
if field in self.model._fields_ordered:
262-
only_fields.append(field)
319+
required_fields.append(field)
263320
for field in get_query_fields(info):
264321
if to_snake_case(field) in self.model._fields_ordered:
265-
only_fields.append(to_snake_case(field))
322+
required_fields.append(to_snake_case(field))
266323
if not bool(args) or not is_partial:
267324
if isinstance(self.model, mongoengine.Document) or isinstance(self.model,
268325
mongoengine.base.metaclasses.TopLevelDocumentMetaclass):
269326
args_copy = args.copy()
270327
for arg_name, arg in args.copy().items():
271-
if arg_name not in self.model._fields_ordered:
328+
if arg_name not in self.model._fields_ordered + tuple(self.filter_args.keys()):
272329
args_copy.pop(arg_name)
273330
if isinstance(info, ResolveInfo):
274331
if not info.context:
275332
info.context = Context()
276-
info.context.queryset = self.get_queryset(self.model, info, only_fields, **args_copy)
333+
info.context.queryset = self.get_queryset(self.model, info, required_fields, **args_copy)
277334
# XXX: Filter nested args
278335
resolved = resolver(root, info, **args)
279336
if resolved is not None:
@@ -282,9 +339,17 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
282339
return resolved
283340
elif not isinstance(resolved[0], DBRef):
284341
return resolved
342+
elif isinstance(resolved, QuerySet):
343+
args.update(resolved._query)
344+
args_copy = args.copy()
345+
for arg_name, arg in args.copy().items():
346+
if arg_name not in self.model._fields_ordered + ('first', 'last', 'before', 'after') + tuple(
347+
self.filter_args.keys()):
348+
args_copy.pop(arg_name)
349+
return self.default_resolver(root, info, required_fields, **args_copy)
285350
else:
286351
return resolved
287-
return self.default_resolver(root, info, only_fields, **args)
352+
return self.default_resolver(root, info, required_fields, **args)
288353

289354
@classmethod
290355
def connection_resolver(cls, resolver, connection_type, root, info, **args):

graphene_mongo/registry.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
class Registry(object):
22
def __init__(self):
33
self._registry = {}
4+
self._registry_string_map = {}
45

56
def register(self, cls):
67
from .types import MongoengineObjectType
@@ -12,6 +13,7 @@ def register(self, cls):
1213
)
1314
assert cls._meta.registry == self, "Registry for a Model have to match."
1415
self._registry[cls._meta.model] = cls
16+
self._registry_string_map[cls.__name__] = cls._meta.model.__name__
1517

1618
# Rescan all fields
1719
for model, cls in self._registry.items():

0 commit comments

Comments
 (0)