Skip to content

Commit abb7bf5

Browse files
Fix: While querying embedded document with key value as ObjectID
1 parent ecb2012 commit abb7bf5

File tree

1 file changed

+10
-31
lines changed

1 file changed

+10
-31
lines changed

graphene_mongo/fields.py

Lines changed: 10 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import graphene
77
import mongoengine
8-
from bson import DBRef
8+
from bson import DBRef, ObjectId
99
from graphene import Context
1010
from graphene.types.utils import get_type
1111
from graphene.utils.str_converters import to_snake_case
@@ -215,7 +215,10 @@ def fields(self):
215215
self._type = get_type(self._type)
216216
return self._type._meta.fields
217217

218-
def get_queryset(self, model, info, required_fields=list(), skip=None, limit=None, reversed=False, **args):
218+
def get_queryset(self, model, info, required_fields=None, skip=None, limit=None, reversed=False, **args):
219+
if required_fields is None:
220+
required_fields = list()
221+
219222
if args:
220223
reference_fields = get_model_reference_fields(self.model)
221224
hydrated_references = {}
@@ -276,7 +279,9 @@ def get_queryset(self, model, info, required_fields=list(), skip=None, limit=Non
276279
skip)
277280
return model.objects(**args).no_dereference().only(*required_fields).order_by(self.order_by)
278281

279-
def default_resolver(self, _root, info, required_fields=list(), **args):
282+
def default_resolver(self, _root, info, required_fields=None, **args):
283+
if required_fields is None:
284+
required_fields = list()
280285
args = args or {}
281286
for key, value in dict(args).items():
282287
if value is None:
@@ -400,39 +405,13 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
400405
if isinstance(self.model, mongoengine.Document) or isinstance(self.model,
401406
mongoengine.base.metaclasses.TopLevelDocumentMetaclass):
402407

403-
skip = 0
404-
count = 0
405-
limit = None
406-
reverse = False
407-
first = args_copy.get("first")
408-
after = args_copy.get("after")
409-
if after:
410-
after = cursor_to_offset(after)
411-
last = args_copy.get("last")
412-
before = args_copy.get("before")
413408
for arg_name, arg in args.copy().items():
414409
if arg_name not in self.model._fields_ordered + tuple(self.filter_args.keys()):
415410
args_copy.pop(arg_name)
416411
if isinstance(info, GraphQLResolveInfo):
417412
if not info.context:
418413
info = info._replace(context=Context())
419-
args_count_copy = args.copy()
420-
for key in args.copy():
421-
if key not in self.model._fields_ordered:
422-
args_count_copy.pop(key)
423-
elif isinstance(getattr(self.model, key),
424-
mongoengine.fields.ReferenceField) or isinstance(getattr(self.model, key),
425-
mongoengine.fields.GenericReferenceField) or isinstance(
426-
getattr(self.model, key),
427-
mongoengine.fields.LazyReferenceField) or isinstance(getattr(self.model, key),
428-
mongoengine.fields.CachedReferenceField):
429-
args_count_copy[key] = from_global_id(args_count_copy[key])[1]
430-
count = mongoengine.get_db()[self.model._get_collection_name()].find(args_count_copy).count()
431-
if count != 0:
432-
skip, limit, reverse = find_skip_and_limit(first=first, after=after, last=last,
433-
before=before,
434-
count=count)
435-
info.context.queryset = self.get_queryset(self.model, info, required_fields, skip, limit, reverse)
414+
info.context.queryset = self.get_queryset(self.model, info, required_fields, **args)
436415

437416
# XXX: Filter nested args
438417
resolved = resolver(root, info, **args)
@@ -454,7 +433,7 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
454433
if arg_name == '_id' and isinstance(arg, dict):
455434
operation = list(arg.keys())[0]
456435
args_copy['pk' + operation.replace('$', '__')] = arg[operation]
457-
if '.' in arg_name:
436+
if '.' in arg_name and not isinstance(arg, ObjectId):
458437
operation = list(arg.keys())[0]
459438
args_copy[arg_name.replace('.', '__') + operation.replace('$', '__')] = arg[operation]
460439
else:

0 commit comments

Comments
 (0)