5
5
6
6
import graphene
7
7
import mongoengine
8
- from bson import DBRef
8
+ from bson import DBRef , ObjectId
9
9
from graphene import Context
10
10
from graphene .types .utils import get_type
11
11
from graphene .utils .str_converters import to_snake_case
12
- from graphql import ResolveInfo
12
+ from graphql import GraphQLResolveInfo
13
13
from mongoengine .base import get_document
14
14
from promise import Promise
15
15
from graphql_relay import from_global_id
@@ -168,7 +168,7 @@ def filter_args(self):
168
168
}
169
169
filter_type = advanced_filter_types .get (each , filter_type )
170
170
filter_args [field + "__" + each ] = graphene .Argument (
171
- type = filter_type
171
+ type_ = filter_type
172
172
)
173
173
return filter_args
174
174
@@ -215,7 +215,10 @@ def fields(self):
215
215
self ._type = get_type (self ._type )
216
216
return self ._type ._meta .fields
217
217
218
- def get_queryset (self , model , info , required_fields = list (), skip = None , limit = None , reversed = False , ** args ):
218
+ def get_queryset (self , model , info , required_fields = None , skip = None , limit = None , reversed = False , ** args ):
219
+ if required_fields is None :
220
+ required_fields = list ()
221
+
219
222
if args :
220
223
reference_fields = get_model_reference_fields (self .model )
221
224
hydrated_references = {}
@@ -276,8 +279,13 @@ def get_queryset(self, model, info, required_fields=list(), skip=None, limit=Non
276
279
skip )
277
280
return model .objects (** args ).no_dereference ().only (* required_fields ).order_by (self .order_by )
278
281
279
- def default_resolver (self , _root , info , required_fields = list (), ** args ):
282
+ def default_resolver (self , _root , info , required_fields = None , ** args ):
283
+ if required_fields is None :
284
+ required_fields = list ()
280
285
args = args or {}
286
+ for key , value in dict (args ).items ():
287
+ if value is None :
288
+ del args [key ]
281
289
if _root is not None :
282
290
field_name = to_snake_case (info .field_name )
283
291
if not hasattr (_root , "_fields_ordered" ):
@@ -301,9 +309,13 @@ def default_resolver(self, _root, info, required_fields=list(), **args):
301
309
limit = None
302
310
reverse = False
303
311
first = args .pop ("first" , None )
304
- after = cursor_to_offset (args .pop ("after" , None ))
312
+ after = args .pop ("after" , None )
313
+ if after :
314
+ after = cursor_to_offset (after )
305
315
last = args .pop ("last" , None )
306
- before = cursor_to_offset (args .pop ("before" , None ))
316
+ before = args .pop ("before" , None )
317
+ if before :
318
+ before = cursor_to_offset (before )
307
319
if callable (getattr (self .model , "objects" , None )):
308
320
if "pk__in" in args and args ["pk__in" ]:
309
321
count = len (args ["pk__in" ])
@@ -318,20 +330,32 @@ def default_resolver(self, _root, info, required_fields=list(), **args):
318
330
args ["pk__in" ] = args ["pk__in" ][skip :]
319
331
iterables = self .get_queryset (self .model , info , required_fields , ** args )
320
332
list_length = len (iterables )
321
- if isinstance (info , ResolveInfo ):
333
+ if isinstance (info , GraphQLResolveInfo ):
322
334
if not info .context :
323
- info . context = Context ()
335
+ info = info . _replace ( context = Context () )
324
336
info .context .queryset = self .get_queryset (self .model , info , required_fields , ** args )
325
337
elif _root is None or args :
326
- count = self .get_queryset (self .model , info , required_fields , ** args ).count ()
338
+ args_copy = args .copy ()
339
+ for key in args .copy ():
340
+ if key not in self .model ._fields_ordered :
341
+ args_copy .pop (key )
342
+ elif isinstance (getattr (self .model , key ),
343
+ mongoengine .fields .ReferenceField ) or isinstance (getattr (self .model , key ),
344
+ mongoengine .fields .GenericReferenceField ) or isinstance (
345
+ getattr (self .model , key ),
346
+ mongoengine .fields .LazyReferenceField ) or isinstance (getattr (self .model , key ),
347
+ mongoengine .fields .CachedReferenceField ):
348
+ if not isinstance (args_copy [key ], ObjectId ):
349
+ args_copy [key ] = from_global_id (args_copy [key ])[1 ]
350
+ count = mongoengine .get_db ()[self .model ._get_collection_name ()].find (args_copy ).count ()
327
351
if count != 0 :
328
352
skip , limit , reverse = find_skip_and_limit (first = first , after = after , last = last , before = before ,
329
353
count = count )
330
354
iterables = self .get_queryset (self .model , info , required_fields , skip , limit , reverse , ** args )
331
355
list_length = len (iterables )
332
- if isinstance (info , ResolveInfo ):
356
+ if isinstance (info , GraphQLResolveInfo ):
333
357
if not info .context :
334
- info . context = Context ()
358
+ info = info . _replace ( context = Context () )
335
359
info .context .queryset = self .get_queryset (self .model , info , required_fields , ** args )
336
360
337
361
elif _root is not None :
@@ -367,6 +391,9 @@ def default_resolver(self, _root, info, required_fields=list(), **args):
367
391
return connection
368
392
369
393
def chained_resolver (self , resolver , is_partial , root , info , ** args ):
394
+ for key , value in dict (args ).items ():
395
+ if value is None :
396
+ del args [key ]
370
397
required_fields = list ()
371
398
for field in self .required_fields :
372
399
if field in self .model ._fields_ordered :
@@ -378,13 +405,15 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
378
405
if not bool (args ) or not is_partial :
379
406
if isinstance (self .model , mongoengine .Document ) or isinstance (self .model ,
380
407
mongoengine .base .metaclasses .TopLevelDocumentMetaclass ):
408
+
381
409
for arg_name , arg in args .copy ().items ():
382
410
if arg_name not in self .model ._fields_ordered + tuple (self .filter_args .keys ()):
383
411
args_copy .pop (arg_name )
384
- if isinstance (info , ResolveInfo ):
412
+ if isinstance (info , GraphQLResolveInfo ):
385
413
if not info .context :
386
- info .context = Context ()
387
- info .context .queryset = self .get_queryset (self .model , info , required_fields , ** args_copy )
414
+ info = info ._replace (context = Context ())
415
+ info .context .queryset = self .get_queryset (self .model , info , required_fields , ** args )
416
+
388
417
# XXX: Filter nested args
389
418
resolved = resolver (root , info , ** args )
390
419
if resolved is not None :
@@ -405,7 +434,7 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
405
434
if arg_name == '_id' and isinstance (arg , dict ):
406
435
operation = list (arg .keys ())[0 ]
407
436
args_copy ['pk' + operation .replace ('$' , '__' )] = arg [operation ]
408
- if '.' in arg_name :
437
+ if not isinstance ( arg , ObjectId ) and '.' in arg_name :
409
438
operation = list (arg .keys ())[0 ]
410
439
args_copy [arg_name .replace ('.' , '__' ) + operation .replace ('$' , '__' )] = arg [operation ]
411
440
else :
@@ -415,6 +444,8 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
415
444
args_copy [arg_name + operation .replace ('$' , '__' )] = arg [operation ]
416
445
del args_copy [arg_name ]
417
446
return self .default_resolver (root , info , required_fields , ** args_copy )
447
+ elif isinstance (resolved , Promise ):
448
+ return resolved .value
418
449
else :
419
450
return resolved
420
451
return self .default_resolver (root , info , required_fields , ** args )
0 commit comments