Skip to content

Commit 75b3012

Browse files
committed
Merge remote-tracking branch 'upstream/master' into merge/upstream_master
2 parents 9050760 + d1b8ace commit 75b3012

File tree

8 files changed

+502
-243
lines changed

8 files changed

+502
-243
lines changed

graphene_mongo/converter.py

Lines changed: 162 additions & 47 deletions
Large diffs are not rendered by default.

graphene_mongo/fields.py

Lines changed: 136 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import mongoengine
88
from bson import DBRef
99
from graphene import Context
10+
from graphene.types.utils import get_type
1011
from graphene.utils.str_converters import to_snake_case
1112
from graphql import ResolveInfo
1213
from promise import Promise
@@ -15,7 +16,8 @@
1516
from graphene.types.argument import to_arguments
1617
from graphene.types.dynamic import Dynamic
1718
from graphene.types.structures import Structure
18-
from graphql_relay.connection.arrayconnection import connection_from_list_slice
19+
from graphql_relay.connection.arrayconnection import cursor_to_offset
20+
from mongoengine import QuerySet
1921

2022
from .advanced_types import (
2123
FileFieldType,
@@ -25,7 +27,8 @@
2527
)
2628
from .converter import convert_mongoengine_field, MongoEngineConversionError
2729
from .registry import get_global_registry
28-
from .utils import get_model_reference_fields, get_node_from_global_id, get_query_fields
30+
from .utils import get_model_reference_fields, get_query_fields, find_skip_and_limit, \
31+
connection_from_iterables
2932

3033

3134
class MongoengineConnectionField(ConnectionField):
@@ -64,10 +67,8 @@ def order_by(self):
6467
return self.node_type._meta.order_by
6568

6669
@property
67-
def only_fields(self):
68-
if isinstance(self.node_type._meta.only_fields, str):
69-
return self.node_type._meta.only_fields.split(",")
70-
return list()
70+
def required_fields(self):
71+
return tuple(set(self.node_type._meta.required_fields + self.node_type._meta.only_fields))
7172

7273
@property
7374
def registry(self):
@@ -118,11 +119,13 @@ def is_filterable(k):
118119
),
119120
):
120121
return False
122+
if getattr(converted, "type", None) and getattr(converted.type, "_of_type", None) and issubclass(
123+
(get_type(converted.type.of_type)), graphene.Union):
124+
return False
121125
if isinstance(converted, (graphene.List)) and issubclass(
122126
getattr(converted, "_of_type", None), graphene.Union
123127
):
124128
return False
125-
126129
return True
127130

128131
def get_filter_type(_type):
@@ -177,29 +180,35 @@ def get_reference_field(r, kv):
177180
if callable(getattr(field, "get_type", None)):
178181
_type = field.get_type()
179182
if _type:
180-
node = _type._type._meta
183+
node = _type.type._meta if hasattr(_type.type, "_meta") else _type.type._of_type._meta
181184
if "id" in node.fields and not issubclass(
182185
node.model, (mongoengine.EmbeddedDocument,)
183186
):
184187
r.update({kv[0]: node.fields["id"]._type.of_type()})
188+
185189
return r
186190

187191
return reduce(get_reference_field, self.fields.items(), {})
188192

189193
@property
190194
def fields(self):
195+
self._type = get_type(self._type)
191196
return self._type._meta.fields
192197

193-
def get_queryset(self, model, info, only_fields=list(), **args):
198+
def get_queryset(self, model, info, required_fields=list(), skip=None, limit=None, reversed=False, **args):
194199
if args:
195200
reference_fields = get_model_reference_fields(self.model)
196201
hydrated_references = {}
197202
for arg_name, arg in args.copy().items():
198-
if arg_name in reference_fields:
199-
reference_obj = get_node_from_global_id(
200-
reference_fields[arg_name], info, args.pop(arg_name)
201-
)
203+
if arg_name in reference_fields and not isinstance(arg,
204+
mongoengine.base.metaclasses.TopLevelDocumentMetaclass):
205+
try:
206+
reference_obj = reference_fields[arg_name].document_type(pk=from_global_id(arg)[1])
207+
except TypeError:
208+
reference_obj = reference_fields[arg_name].document_type(pk=arg)
202209
hydrated_references[arg_name] = reference_obj
210+
elif arg_name == "id":
211+
hydrated_references["id"] = from_global_id(args.pop("id", None))[1]
203212
args.update(hydrated_references)
204213

205214
if self._get_queryset:
@@ -208,72 +217,138 @@ def get_queryset(self, model, info, only_fields=list(), **args):
208217
return queryset_or_filters
209218
else:
210219
args.update(queryset_or_filters)
220+
if limit is not None:
221+
if reversed:
222+
order_by = ""
223+
if self.order_by:
224+
order_by = self.order_by + ",-pk"
225+
else:
226+
order_by = "-pk"
227+
return model.objects(**args).no_dereference().only(*required_fields).order_by(order_by).skip(
228+
skip if skip else 0).limit(limit)
229+
else:
230+
return model.objects(**args).no_dereference().only(*required_fields).order_by(self.order_by).skip(
231+
skip if skip else 0).limit(limit)
232+
elif skip is not None:
233+
if reversed:
234+
order_by = ""
235+
if self.order_by:
236+
order_by = self.order_by + ",-pk"
237+
else:
238+
order_by = "-pk"
239+
return model.objects(**args).no_dereference().only(*required_fields).order_by(order_by).skip(
240+
skip)
241+
else:
242+
return model.objects(**args).no_dereference().only(*required_fields).order_by(self.order_by).skip(
243+
skip)
244+
return model.objects(**args).no_dereference().only(*required_fields).order_by(self.order_by)
211245

212-
return model.objects(**args).no_dereference().only(*only_fields).order_by(self.order_by)
213-
214-
def default_resolver(self, _root, info, only_fields=list(), **args):
246+
def default_resolver(self, _root, info, required_fields=list(), **args):
215247
args = args or {}
216-
217248
if _root is not None:
218249
field_name = to_snake_case(info.field_name)
219-
if getattr(_root, field_name, []) is not None:
220-
args["pk__in"] = [r.id for r in getattr(_root, field_name, [])]
221-
222-
connection_args = {
223-
"first": args.pop("first", None),
224-
"last": args.pop("last", None),
225-
"before": args.pop("before", None),
226-
"after": args.pop("after", None),
227-
}
250+
if field_name in _root._fields_ordered and not (isinstance(_root._fields[field_name].field,
251+
mongoengine.EmbeddedDocumentField) or
252+
isinstance(_root._fields[field_name].field,
253+
mongoengine.GenericEmbeddedDocumentField)):
254+
if getattr(_root, field_name, []) is not None:
255+
args["pk__in"] = [r.id for r in getattr(_root, field_name, [])]
228256

229257
_id = args.pop('id', None)
230258

231259
if _id is not None:
232260
args['pk'] = from_global_id(_id)[-1]
233-
261+
iterables = []
262+
list_length = 0
263+
skip = 0
264+
count = 0
265+
limit = None
266+
reverse = False
267+
first = args.pop("first", None)
268+
after = cursor_to_offset(args.pop("after", None))
269+
last = args.pop("last", None)
270+
before = cursor_to_offset(args.pop("before", None))
234271
if callable(getattr(self.model, "objects", None)):
235-
iterables = self.get_queryset(self.model, info, only_fields, **args)
236-
if isinstance(info, ResolveInfo):
237-
if not info.context:
238-
info.context = Context()
239-
info.context.queryset = iterables
240-
list_length = iterables.count()
241-
else:
242-
iterables = []
243-
list_length = 0
244-
245-
connection = connection_from_list_slice(
246-
list_slice=iterables,
247-
args=connection_args,
248-
list_length=list_length,
249-
list_slice_length=list_length,
250-
connection_type=self.type,
251-
edge_type=self.type.Edge,
252-
pageinfo_type=graphene.PageInfo,
253-
)
272+
if "pk__in" in args and args["pk__in"]:
273+
count = len(args["pk__in"])
274+
skip, limit, reverse = find_skip_and_limit(first=first, last=last, after=after, before=before,
275+
count=count)
276+
if limit:
277+
if reverse:
278+
args["pk__in"] = args["pk__in"][::-1][skip:skip + limit]
279+
else:
280+
args["pk__in"] = args["pk__in"][skip:skip + limit]
281+
elif skip:
282+
args["pk__in"] = args["pk__in"][skip:]
283+
iterables = self.get_queryset(self.model, info, required_fields, **args)
284+
list_length = len(iterables)
285+
if isinstance(info, ResolveInfo):
286+
if not info.context:
287+
info.context = Context()
288+
info.context.queryset = self.get_queryset(self.model, info, required_fields, **args)
289+
elif _root is None:
290+
count = self.get_queryset(self.model, info, required_fields, **args).count()
291+
if count != 0:
292+
skip, limit, reverse = find_skip_and_limit(first=first, after=after, last=last, before=before,
293+
count=count)
294+
iterables = self.get_queryset(self.model, info, required_fields, skip, limit, reverse, **args)
295+
list_length = len(iterables)
296+
if isinstance(info, ResolveInfo):
297+
if not info.context:
298+
info.context = Context()
299+
info.context.queryset = self.get_queryset(self.model, info, required_fields, **args)
300+
301+
elif _root is not None:
302+
field_name = to_snake_case(info.field_name)
303+
items = getattr(_root, field_name, [])
304+
count = len(items)
305+
skip, limit, reverse = find_skip_and_limit(first=first, last=last, after=after, before=before,
306+
count=count)
307+
if limit:
308+
if reverse:
309+
items = items[::-1][skip:skip + limit]
310+
else:
311+
items = items[skip:skip + limit]
312+
elif skip:
313+
items = items[skip:]
314+
iterables = items
315+
list_length = len(iterables)
316+
has_next_page = True if (0 if limit is None else limit) + (0 if skip is None else skip) < count else False
317+
has_previous_page = True if skip else False
318+
if reverse:
319+
iterables = list(iterables)
320+
iterables.reverse()
321+
skip = limit
322+
connection = connection_from_iterables(edges=iterables, start_offset=skip,
323+
has_previous_page=has_previous_page,
324+
has_next_page=has_next_page,
325+
connection_type=self.type,
326+
edge_type=self.type.Edge,
327+
pageinfo_type=graphene.PageInfo)
328+
254329
connection.iterable = iterables
255330
connection.list_length = list_length
256331
return connection
257332

258333
def chained_resolver(self, resolver, is_partial, root, info, **args):
259-
only_fields = list()
260-
for field in self.only_fields:
334+
required_fields = list()
335+
for field in self.required_fields:
261336
if field in self.model._fields_ordered:
262-
only_fields.append(field)
337+
required_fields.append(field)
263338
for field in get_query_fields(info):
264339
if to_snake_case(field) in self.model._fields_ordered:
265-
only_fields.append(to_snake_case(field))
340+
required_fields.append(to_snake_case(field))
266341
if not bool(args) or not is_partial:
267342
if isinstance(self.model, mongoengine.Document) or isinstance(self.model,
268343
mongoengine.base.metaclasses.TopLevelDocumentMetaclass):
269344
args_copy = args.copy()
270345
for arg_name, arg in args.copy().items():
271-
if arg_name not in self.model._fields_ordered:
346+
if arg_name not in self.model._fields_ordered + tuple(self.filter_args.keys()):
272347
args_copy.pop(arg_name)
273348
if isinstance(info, ResolveInfo):
274349
if not info.context:
275350
info.context = Context()
276-
info.context.queryset = self.get_queryset(self.model, info, only_fields, **args_copy)
351+
info.context.queryset = self.get_queryset(self.model, info, required_fields, **args_copy)
277352
# XXX: Filter nested args
278353
resolved = resolver(root, info, **args)
279354
if resolved is not None:
@@ -282,9 +357,17 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
282357
return resolved
283358
elif not isinstance(resolved[0], DBRef):
284359
return resolved
360+
elif isinstance(resolved, QuerySet):
361+
args.update(resolved._query)
362+
args_copy = args.copy()
363+
for arg_name, arg in args.copy().items():
364+
if arg_name not in self.model._fields_ordered + ('first', 'last', 'before', 'after') + tuple(
365+
self.filter_args.keys()):
366+
args_copy.pop(arg_name)
367+
return self.default_resolver(root, info, required_fields, **args_copy)
285368
else:
286369
return resolved
287-
return self.default_resolver(root, info, only_fields, **args)
370+
return self.default_resolver(root, info, required_fields, **args)
288371

289372
@classmethod
290373
def connection_resolver(cls, resolver, connection_type, root, info, **args):

graphene_mongo/registry.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
class Registry(object):
22
def __init__(self):
33
self._registry = {}
4+
self._registry_string_map = {}
45

56
def register(self, cls):
67
from .types import GrapheneMongoengineObjectTypes
@@ -13,6 +14,7 @@ def register(self, cls):
1314
)
1415
assert cls._meta.registry == self, "Registry for a Model have to match."
1516
self._registry[cls._meta.model] = cls
17+
self._registry_string_map[cls.__name__] = cls._meta.model.__name__
1618

1719
# Rescan all fields
1820
for model, cls in self._registry.items():

0 commit comments

Comments
 (0)