Skip to content

Commit 751b6b4

Browse files
Merge branch 'feat-retrieving-queried-fields-only'
# Conflicts: # graphene_mongo/fields.py
2 parents ef4c95c + ed4b1a1 commit 751b6b4

File tree

3 files changed

+48
-91
lines changed

3 files changed

+48
-91
lines changed

graphene_mongo/converter.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -237,12 +237,14 @@ def reference_resolver(root, *args, **kwargs):
237237

238238
if isinstance(field, mongoengine.GenericReferenceField):
239239
field_resolver = None
240+
required = False
240241
if field.db_field is not None:
242+
required = field.required
241243
resolver_function = getattr(_union, "resolve_" + field.db_field, None)
242244
if resolver_function and callable(resolver_function):
243245
field_resolver = resolver_function
244246
return graphene.Field(_union, resolver=field_resolver if field_resolver else reference_resolver,
245-
description=get_field_description(field, registry))
247+
description=get_field_description(field, registry), required=required)
246248

247249
return graphene.Field(_union)
248250

@@ -281,16 +283,18 @@ def dynamic_type():
281283
return graphene.Field(_type,
282284
description=get_field_description(field, registry))
283285
field_resolver = None
286+
required = False
284287
if field.db_field is not None:
288+
required = field.required
285289
resolver_function = getattr(_type, "resolve_" + field.db_field, None)
286290
if resolver_function and callable(resolver_function):
287291
field_resolver = resolver_function
288292
if isinstance(field, mongoengine.ReferenceField):
289293
return graphene.Field(_type, resolver=field_resolver if field_resolver else reference_resolver,
290-
description=get_field_description(field, registry))
294+
description=get_field_description(field, registry), required=required)
291295
else:
292-
return graphene.Field(_type, resolver=field_resolver if field_resolver else cached_reference_resolver(),
293-
description=get_field_description(field, registry))
296+
return graphene.Field(_type, resolver=field_resolver if field_resolver else cached_reference_resolver,
297+
description=get_field_description(field, registry), required=required)
294298

295299
return graphene.Dynamic(dynamic_type)
296300

@@ -314,14 +318,16 @@ def dynamic_type():
314318
if not _type:
315319
return None
316320
field_resolver = None
321+
required = False
317322
if field.db_field is not None:
323+
required = field.required
318324
resolver_function = getattr(_type, "resolve_" + field.db_field, None)
319325
if resolver_function and callable(resolver_function):
320326
field_resolver = resolver_function
321327
return graphene.Field(
322328
_type,
323329
resolver=field_resolver if field_resolver else lazy_resolver,
324-
description=get_field_description(field, registry),
330+
description=get_field_description(field, registry), required=required,
325331
)
326332

327333
return graphene.Dynamic(dynamic_type)

graphene_mongo/fields.py

Lines changed: 35 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@
1616
from graphene.types.argument import to_arguments
1717
from graphene.types.dynamic import Dynamic
1818
from graphene.types.structures import Structure
19-
from graphql_relay.connection.arrayconnection import cursor_to_offset
20-
from mongoengine import QuerySet
19+
from graphql_relay.connection.arrayconnection import connection_from_list_slice
2120

2221
from .advanced_types import (
2322
FileFieldType,
@@ -27,8 +26,7 @@
2726
)
2827
from .converter import convert_mongoengine_field, MongoEngineConversionError
2928
from .registry import get_global_registry
30-
from .utils import get_model_reference_fields, get_node_from_global_id, get_query_fields, find_skip_and_limit, \
31-
connection_from_iterables
29+
from .utils import get_model_reference_fields, get_node_from_global_id, get_query_fields
3230

3331

3432
class MongoengineConnectionField(ConnectionField):
@@ -178,7 +176,7 @@ def get_reference_field(r, kv):
178176
if callable(getattr(field, "get_type", None)):
179177
_type = field.get_type()
180178
if _type:
181-
node = _type._type._meta
179+
node = _type.type._meta if hasattr(_type.type, "_meta") else _type.type._of_type._meta
182180
if "id" in node.fields and not issubclass(
183181
node.model, (mongoengine.EmbeddedDocument,)
184182
):
@@ -192,12 +190,12 @@ def fields(self):
192190
self._type = get_type(self._type)
193191
return self._type._meta.fields
194192

195-
def get_queryset(self, model, info, required_fields=list(), skip=None, limit=None, reversed=False, **args):
193+
def get_queryset(self, model, info, required_fields=list(), **args):
196194
if args:
197195
reference_fields = get_model_reference_fields(self.model)
198196
hydrated_references = {}
199197
for arg_name, arg in args.copy().items():
200-
if arg_name in reference_fields:
198+
if arg_name in reference_fields and isinstance(arg, str):
201199
reference_obj = get_node_from_global_id(
202200
reference_fields[arg_name], info, args.pop(arg_name)
203201
)
@@ -210,94 +208,50 @@ def get_queryset(self, model, info, required_fields=list(), skip=None, limit=Non
210208
return queryset_or_filters
211209
else:
212210
args.update(queryset_or_filters)
213-
if limit is not None:
214-
if reversed:
215-
order_by = ""
216-
if self.order_by:
217-
order_by = self.order_by + ",-pk"
218-
else:
219-
order_by = "-pk"
220-
return model.objects(**args).no_dereference().only(*required_fields).order_by(order_by).skip(
221-
skip if skip else 0).limit(limit)
222-
else:
223-
return model.objects(**args).no_dereference().only(*required_fields).order_by(self.order_by).skip(
224-
skip if skip else 0).limit(limit)
225-
elif skip is not None:
226-
if reversed:
227-
order_by = ""
228-
if self.order_by:
229-
order_by = self.order_by + ",-pk"
230-
else:
231-
order_by = "-pk"
232-
return model.objects(**args).no_dereference().only(*required_fields).order_by(order_by).skip(
233-
skip)
234-
else:
235-
return model.objects(**args).no_dereference().only(*required_fields).order_by(self.order_by).skip(
236-
skip)
211+
237212
return model.objects(**args).no_dereference().only(*required_fields).order_by(self.order_by)
238213

239214
def default_resolver(self, _root, info, required_fields=list(), **args):
240215
args = args or {}
216+
241217
if _root is not None:
242218
field_name = to_snake_case(info.field_name)
243219
if field_name in _root._fields_ordered:
244220
if getattr(_root, field_name, []) is not None:
245221
args["pk__in"] = [r.id for r in getattr(_root, field_name, [])]
246222

223+
connection_args = {
224+
"first": args.pop("first", None),
225+
"last": args.pop("last", None),
226+
"before": args.pop("before", None),
227+
"after": args.pop("after", None),
228+
}
229+
247230
_id = args.pop('id', None)
231+
248232
if _id is not None:
249233
args['pk'] = from_global_id(_id)[-1]
250-
iterables = []
251-
list_length = 0
252-
skip = 0
253-
count = 0
254-
limit = None
255-
reverse = False
234+
256235
if callable(getattr(self.model, "objects", None)):
257-
first = args.pop("first", None)
258-
after = cursor_to_offset(args.pop("after", None))
259-
last = args.pop("last", None)
260-
before = cursor_to_offset(args.pop("before", None))
261-
if "pk__in" in args and args["pk__in"]:
262-
count = len(args["pk__in"])
263-
skip, limit, reverse = find_skip_and_limit(first=first, last=last, after=after, before=before,
264-
count=count)
265-
if limit:
266-
if reverse:
267-
args["pk__in"] = args["pk__in"][::-1][skip:skip + limit]
268-
else:
269-
args["pk__in"] = args["pk__in"][skip:skip + limit]
270-
elif skip:
271-
args["pk__in"] = args["pk__in"][skip:]
272-
iterables = self.get_queryset(self.model, info, required_fields, **args)
273-
list_length = len(iterables)
274-
if isinstance(info, ResolveInfo):
275-
if not info.context:
276-
info.context = Context()
277-
info.context.queryset = self.get_queryset(self.model, info, required_fields, **args)
278-
else:
279-
count = self.get_queryset(self.model, info, required_fields, **args).count()
280-
if count != 0:
281-
skip, limit, reverse = find_skip_and_limit(first=first, after=after, last=last, before=before,
282-
count=count)
283-
iterables = self.get_queryset(self.model, info, required_fields, skip, limit, reverse, **args)
284-
list_length = len(iterables)
285-
if isinstance(info, ResolveInfo):
286-
if not info.context:
287-
info.context = Context()
288-
info.context.queryset = self.get_queryset(self.model, info, required_fields, **args)
289-
has_next_page = True if (0 if limit is None else limit) + (0 if skip is None else skip) < count else False
290-
has_previous_page = True if skip else False
291-
if reverse:
292-
iterables = list(iterables)
293-
iterables.reverse()
294-
skip = limit
295-
connection = connection_from_iterables(edges=iterables, start_offset=skip,
296-
has_previous_page=has_previous_page,
297-
has_next_page=has_next_page,
298-
connection_type=self.type,
299-
edge_type=self.type.Edge,
300-
pageinfo_type=graphene.PageInfo)
236+
iterables = self.get_queryset(self.model, info, required_fields, **args)
237+
if isinstance(info, ResolveInfo):
238+
if not info.context:
239+
info.context = Context()
240+
info.context.queryset = iterables
241+
list_length = iterables.count()
242+
else:
243+
iterables = []
244+
list_length = 0
245+
246+
connection = connection_from_list_slice(
247+
list_slice=iterables,
248+
args=connection_args,
249+
list_length=list_length,
250+
list_slice_length=list_length,
251+
connection_type=self.type,
252+
edge_type=self.type.Edge,
253+
pageinfo_type=graphene.PageInfo,
254+
)
301255
connection.iterable = iterables
302256
connection.list_length = list_length
303257
return connection
@@ -329,9 +283,6 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
329283
return resolved
330284
elif not isinstance(resolved[0], DBRef):
331285
return resolved
332-
elif isinstance(resolved, QuerySet):
333-
args.update(resolved._query)
334-
return self.default_resolver(root, info, required_fields, **args)
335286
else:
336287
return resolved
337288
return self.default_resolver(root, info, required_fields, **args)

graphene_mongo/types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from .utils import get_model_fields, is_valid_mongoengine_model, get_query_fields
1414

1515

16-
def construct_fields(model, registry, required_fields, exclude_fields):
16+
def construct_fields(model, registry, only_fields, exclude_fields):
1717
"""
1818
Args:
1919
model (mongoengine.Document):
@@ -29,7 +29,7 @@ def construct_fields(model, registry, required_fields, exclude_fields):
2929
fields = OrderedDict()
3030
self_referenced = OrderedDict()
3131
for name, field in _model_fields.items():
32-
is_not_in_only = required_fields and name not in required_fields
32+
is_not_in_only = only_fields and name not in only_fields
3333
is_excluded = name in exclude_fields
3434
if is_not_in_only or is_excluded:
3535
# We skip this field if we specify required_fields and is not

0 commit comments

Comments
 (0)