5
5
6
6
import graphene
7
7
import mongoengine
8
- from bson import DBRef
8
+ from bson import DBRef , ObjectId
9
9
from graphene import Context
10
10
from graphene .types .utils import get_type
11
11
from graphene .utils .str_converters import to_snake_case
@@ -215,7 +215,10 @@ def fields(self):
215
215
self ._type = get_type (self ._type )
216
216
return self ._type ._meta .fields
217
217
218
- def get_queryset (self , model , info , required_fields = list (), skip = None , limit = None , reversed = False , ** args ):
218
+ def get_queryset (self , model , info , required_fields = None , skip = None , limit = None , reversed = False , ** args ):
219
+ if required_fields is None :
220
+ required_fields = list ()
221
+
219
222
if args :
220
223
reference_fields = get_model_reference_fields (self .model )
221
224
hydrated_references = {}
@@ -276,7 +279,9 @@ def get_queryset(self, model, info, required_fields=list(), skip=None, limit=Non
276
279
skip )
277
280
return model .objects (** args ).no_dereference ().only (* required_fields ).order_by (self .order_by )
278
281
279
- def default_resolver (self , _root , info , required_fields = list (), ** args ):
282
+ def default_resolver (self , _root , info , required_fields = None , ** args ):
283
+ if required_fields is None :
284
+ required_fields = list ()
280
285
args = args or {}
281
286
for key , value in dict (args ).items ():
282
287
if value is None :
@@ -400,39 +405,13 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
400
405
if isinstance (self .model , mongoengine .Document ) or isinstance (self .model ,
401
406
mongoengine .base .metaclasses .TopLevelDocumentMetaclass ):
402
407
403
- skip = 0
404
- count = 0
405
- limit = None
406
- reverse = False
407
- first = args_copy .get ("first" )
408
- after = args_copy .get ("after" )
409
- if after :
410
- after = cursor_to_offset (after )
411
- last = args_copy .get ("last" )
412
- before = args_copy .get ("before" )
413
408
for arg_name , arg in args .copy ().items ():
414
409
if arg_name not in self .model ._fields_ordered + tuple (self .filter_args .keys ()):
415
410
args_copy .pop (arg_name )
416
411
if isinstance (info , GraphQLResolveInfo ):
417
412
if not info .context :
418
413
info = info ._replace (context = Context ())
419
- args_count_copy = args .copy ()
420
- for key in args .copy ():
421
- if key not in self .model ._fields_ordered :
422
- args_count_copy .pop (key )
423
- elif isinstance (getattr (self .model , key ),
424
- mongoengine .fields .ReferenceField ) or isinstance (getattr (self .model , key ),
425
- mongoengine .fields .GenericReferenceField ) or isinstance (
426
- getattr (self .model , key ),
427
- mongoengine .fields .LazyReferenceField ) or isinstance (getattr (self .model , key ),
428
- mongoengine .fields .CachedReferenceField ):
429
- args_count_copy [key ] = from_global_id (args_count_copy [key ])[1 ]
430
- count = mongoengine .get_db ()[self .model ._get_collection_name ()].find (args_count_copy ).count ()
431
- if count != 0 :
432
- skip , limit , reverse = find_skip_and_limit (first = first , after = after , last = last ,
433
- before = before ,
434
- count = count )
435
- info .context .queryset = self .get_queryset (self .model , info , required_fields , skip , limit , reverse )
414
+ info .context .queryset = self .get_queryset (self .model , info , required_fields , ** args )
436
415
437
416
# XXX: Filter nested args
438
417
resolved = resolver (root , info , ** args )
@@ -454,7 +433,7 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
454
433
if arg_name == '_id' and isinstance (arg , dict ):
455
434
operation = list (arg .keys ())[0 ]
456
435
args_copy ['pk' + operation .replace ('$' , '__' )] = arg [operation ]
457
- if '.' in arg_name :
436
+ if '.' in arg_name and not isinstance ( arg , ObjectId ) :
458
437
operation = list (arg .keys ())[0 ]
459
438
args_copy [arg_name .replace ('.' , '__' ) + operation .replace ('$' , '__' )] = arg [operation ]
460
439
else :
0 commit comments