7
7
import mongoengine
8
8
from bson import DBRef
9
9
from graphene import Context
10
+ from graphene .types .utils import get_type
10
11
from graphene .utils .str_converters import to_snake_case
11
12
from graphql import ResolveInfo
12
13
from promise import Promise
15
16
from graphene .types .argument import to_arguments
16
17
from graphene .types .dynamic import Dynamic
17
18
from graphene .types .structures import Structure
18
- from graphql_relay .connection .arrayconnection import connection_from_list_slice
19
+ from graphql_relay .connection .arrayconnection import cursor_to_offset
20
+ from mongoengine import QuerySet
19
21
20
22
from .advanced_types import (
21
23
FileFieldType ,
25
27
)
26
28
from .converter import convert_mongoengine_field , MongoEngineConversionError
27
29
from .registry import get_global_registry
28
- from .utils import get_model_reference_fields , get_node_from_global_id , get_query_fields
30
+ from .utils import get_model_reference_fields , get_query_fields , find_skip_and_limit , \
31
+ connection_from_iterables
29
32
30
33
31
34
class MongoengineConnectionField (ConnectionField ):
@@ -64,10 +67,8 @@ def order_by(self):
64
67
return self .node_type ._meta .order_by
65
68
66
69
@property
67
- def only_fields (self ):
68
- if isinstance (self .node_type ._meta .only_fields , str ):
69
- return self .node_type ._meta .only_fields .split ("," )
70
- return list ()
70
+ def required_fields (self ):
71
+ return tuple (set (self .node_type ._meta .required_fields + self .node_type ._meta .only_fields ))
71
72
72
73
@property
73
74
def registry (self ):
@@ -118,11 +119,13 @@ def is_filterable(k):
118
119
),
119
120
):
120
121
return False
122
+ if getattr (converted , "type" , None ) and getattr (converted .type , "_of_type" , None ) and issubclass (
123
+ (get_type (converted .type .of_type )), graphene .Union ):
124
+ return False
121
125
if isinstance (converted , (graphene .List )) and issubclass (
122
126
getattr (converted , "_of_type" , None ), graphene .Union
123
127
):
124
128
return False
125
-
126
129
return True
127
130
128
131
def get_filter_type (_type ):
@@ -177,29 +180,35 @@ def get_reference_field(r, kv):
177
180
if callable (getattr (field , "get_type" , None )):
178
181
_type = field .get_type ()
179
182
if _type :
180
- node = _type ._type ._meta
183
+ node = _type .type . _meta if hasattr ( _type . type , "_meta" ) else _type . type . _of_type ._meta
181
184
if "id" in node .fields and not issubclass (
182
185
node .model , (mongoengine .EmbeddedDocument ,)
183
186
):
184
187
r .update ({kv [0 ]: node .fields ["id" ]._type .of_type ()})
188
+
185
189
return r
186
190
187
191
return reduce (get_reference_field , self .fields .items (), {})
188
192
189
193
@property
190
194
def fields (self ):
195
+ self ._type = get_type (self ._type )
191
196
return self ._type ._meta .fields
192
197
193
- def get_queryset (self , model , info , only_fields = list (), ** args ):
198
+ def get_queryset (self , model , info , required_fields = list (), skip = None , limit = None , reversed = False , ** args ):
194
199
if args :
195
200
reference_fields = get_model_reference_fields (self .model )
196
201
hydrated_references = {}
197
202
for arg_name , arg in args .copy ().items ():
198
- if arg_name in reference_fields :
199
- reference_obj = get_node_from_global_id (
200
- reference_fields [arg_name ], info , args .pop (arg_name )
201
- )
203
+ if arg_name in reference_fields and not isinstance (arg ,
204
+ mongoengine .base .metaclasses .TopLevelDocumentMetaclass ):
205
+ try :
206
+ reference_obj = reference_fields [arg_name ].document_type (pk = from_global_id (arg )[1 ])
207
+ except TypeError :
208
+ reference_obj = reference_fields [arg_name ].document_type (pk = arg )
202
209
hydrated_references [arg_name ] = reference_obj
210
+ elif arg_name == "id" :
211
+ hydrated_references ["id" ] = from_global_id (args .pop ("id" , None ))[1 ]
203
212
args .update (hydrated_references )
204
213
205
214
if self ._get_queryset :
@@ -208,72 +217,138 @@ def get_queryset(self, model, info, only_fields=list(), **args):
208
217
return queryset_or_filters
209
218
else :
210
219
args .update (queryset_or_filters )
220
+ if limit is not None :
221
+ if reversed :
222
+ order_by = ""
223
+ if self .order_by :
224
+ order_by = self .order_by + ",-pk"
225
+ else :
226
+ order_by = "-pk"
227
+ return model .objects (** args ).no_dereference ().only (* required_fields ).order_by (order_by ).skip (
228
+ skip if skip else 0 ).limit (limit )
229
+ else :
230
+ return model .objects (** args ).no_dereference ().only (* required_fields ).order_by (self .order_by ).skip (
231
+ skip if skip else 0 ).limit (limit )
232
+ elif skip is not None :
233
+ if reversed :
234
+ order_by = ""
235
+ if self .order_by :
236
+ order_by = self .order_by + ",-pk"
237
+ else :
238
+ order_by = "-pk"
239
+ return model .objects (** args ).no_dereference ().only (* required_fields ).order_by (order_by ).skip (
240
+ skip )
241
+ else :
242
+ return model .objects (** args ).no_dereference ().only (* required_fields ).order_by (self .order_by ).skip (
243
+ skip )
244
+ return model .objects (** args ).no_dereference ().only (* required_fields ).order_by (self .order_by )
211
245
212
- return model .objects (** args ).no_dereference ().only (* only_fields ).order_by (self .order_by )
213
-
214
- def default_resolver (self , _root , info , only_fields = list (), ** args ):
246
+ def default_resolver (self , _root , info , required_fields = list (), ** args ):
215
247
args = args or {}
216
-
217
248
if _root is not None :
218
249
field_name = to_snake_case (info .field_name )
219
- if getattr (_root , field_name , []) is not None :
220
- args ["pk__in" ] = [r .id for r in getattr (_root , field_name , [])]
221
-
222
- connection_args = {
223
- "first" : args .pop ("first" , None ),
224
- "last" : args .pop ("last" , None ),
225
- "before" : args .pop ("before" , None ),
226
- "after" : args .pop ("after" , None ),
227
- }
250
+ if field_name in _root ._fields_ordered and not (isinstance (_root ._fields [field_name ].field ,
251
+ mongoengine .EmbeddedDocumentField ) or
252
+ isinstance (_root ._fields [field_name ].field ,
253
+ mongoengine .GenericEmbeddedDocumentField )):
254
+ if getattr (_root , field_name , []) is not None :
255
+ args ["pk__in" ] = [r .id for r in getattr (_root , field_name , [])]
228
256
229
257
_id = args .pop ('id' , None )
230
258
231
259
if _id is not None :
232
260
args ['pk' ] = from_global_id (_id )[- 1 ]
233
-
261
+ iterables = []
262
+ list_length = 0
263
+ skip = 0
264
+ count = 0
265
+ limit = None
266
+ reverse = False
267
+ first = args .pop ("first" , None )
268
+ after = cursor_to_offset (args .pop ("after" , None ))
269
+ last = args .pop ("last" , None )
270
+ before = cursor_to_offset (args .pop ("before" , None ))
234
271
if callable (getattr (self .model , "objects" , None )):
235
- iterables = self .get_queryset (self .model , info , only_fields , ** args )
236
- if isinstance (info , ResolveInfo ):
237
- if not info .context :
238
- info .context = Context ()
239
- info .context .queryset = iterables
240
- list_length = iterables .count ()
241
- else :
242
- iterables = []
243
- list_length = 0
244
-
245
- connection = connection_from_list_slice (
246
- list_slice = iterables ,
247
- args = connection_args ,
248
- list_length = list_length ,
249
- list_slice_length = list_length ,
250
- connection_type = self .type ,
251
- edge_type = self .type .Edge ,
252
- pageinfo_type = graphene .PageInfo ,
253
- )
272
+ if "pk__in" in args and args ["pk__in" ]:
273
+ count = len (args ["pk__in" ])
274
+ skip , limit , reverse = find_skip_and_limit (first = first , last = last , after = after , before = before ,
275
+ count = count )
276
+ if limit :
277
+ if reverse :
278
+ args ["pk__in" ] = args ["pk__in" ][::- 1 ][skip :skip + limit ]
279
+ else :
280
+ args ["pk__in" ] = args ["pk__in" ][skip :skip + limit ]
281
+ elif skip :
282
+ args ["pk__in" ] = args ["pk__in" ][skip :]
283
+ iterables = self .get_queryset (self .model , info , required_fields , ** args )
284
+ list_length = len (iterables )
285
+ if isinstance (info , ResolveInfo ):
286
+ if not info .context :
287
+ info .context = Context ()
288
+ info .context .queryset = self .get_queryset (self .model , info , required_fields , ** args )
289
+ elif _root is None :
290
+ count = self .get_queryset (self .model , info , required_fields , ** args ).count ()
291
+ if count != 0 :
292
+ skip , limit , reverse = find_skip_and_limit (first = first , after = after , last = last , before = before ,
293
+ count = count )
294
+ iterables = self .get_queryset (self .model , info , required_fields , skip , limit , reverse , ** args )
295
+ list_length = len (iterables )
296
+ if isinstance (info , ResolveInfo ):
297
+ if not info .context :
298
+ info .context = Context ()
299
+ info .context .queryset = self .get_queryset (self .model , info , required_fields , ** args )
300
+
301
+ elif _root is not None :
302
+ field_name = to_snake_case (info .field_name )
303
+ items = getattr (_root , field_name , [])
304
+ count = len (items )
305
+ skip , limit , reverse = find_skip_and_limit (first = first , last = last , after = after , before = before ,
306
+ count = count )
307
+ if limit :
308
+ if reverse :
309
+ items = items [::- 1 ][skip :skip + limit ]
310
+ else :
311
+ items = items [skip :skip + limit ]
312
+ elif skip :
313
+ items = items [skip :]
314
+ iterables = items
315
+ list_length = len (iterables )
316
+ has_next_page = True if (0 if limit is None else limit ) + (0 if skip is None else skip ) < count else False
317
+ has_previous_page = True if skip else False
318
+ if reverse :
319
+ iterables = list (iterables )
320
+ iterables .reverse ()
321
+ skip = limit
322
+ connection = connection_from_iterables (edges = iterables , start_offset = skip ,
323
+ has_previous_page = has_previous_page ,
324
+ has_next_page = has_next_page ,
325
+ connection_type = self .type ,
326
+ edge_type = self .type .Edge ,
327
+ pageinfo_type = graphene .PageInfo )
328
+
254
329
connection .iterable = iterables
255
330
connection .list_length = list_length
256
331
return connection
257
332
258
333
def chained_resolver (self , resolver , is_partial , root , info , ** args ):
259
- only_fields = list ()
260
- for field in self .only_fields :
334
+ required_fields = list ()
335
+ for field in self .required_fields :
261
336
if field in self .model ._fields_ordered :
262
- only_fields .append (field )
337
+ required_fields .append (field )
263
338
for field in get_query_fields (info ):
264
339
if to_snake_case (field ) in self .model ._fields_ordered :
265
- only_fields .append (to_snake_case (field ))
340
+ required_fields .append (to_snake_case (field ))
266
341
if not bool (args ) or not is_partial :
267
342
if isinstance (self .model , mongoengine .Document ) or isinstance (self .model ,
268
343
mongoengine .base .metaclasses .TopLevelDocumentMetaclass ):
269
344
args_copy = args .copy ()
270
345
for arg_name , arg in args .copy ().items ():
271
- if arg_name not in self .model ._fields_ordered :
346
+ if arg_name not in self .model ._fields_ordered + tuple ( self . filter_args . keys ()) :
272
347
args_copy .pop (arg_name )
273
348
if isinstance (info , ResolveInfo ):
274
349
if not info .context :
275
350
info .context = Context ()
276
- info .context .queryset = self .get_queryset (self .model , info , only_fields , ** args_copy )
351
+ info .context .queryset = self .get_queryset (self .model , info , required_fields , ** args_copy )
277
352
# XXX: Filter nested args
278
353
resolved = resolver (root , info , ** args )
279
354
if resolved is not None :
@@ -282,9 +357,17 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
282
357
return resolved
283
358
elif not isinstance (resolved [0 ], DBRef ):
284
359
return resolved
360
+ elif isinstance (resolved , QuerySet ):
361
+ args .update (resolved ._query )
362
+ args_copy = args .copy ()
363
+ for arg_name , arg in args .copy ().items ():
364
+ if arg_name not in self .model ._fields_ordered + ('first' , 'last' , 'before' , 'after' ) + tuple (
365
+ self .filter_args .keys ()):
366
+ args_copy .pop (arg_name )
367
+ return self .default_resolver (root , info , required_fields , ** args_copy )
285
368
else :
286
369
return resolved
287
- return self .default_resolver (root , info , only_fields , ** args )
370
+ return self .default_resolver (root , info , required_fields , ** args )
288
371
289
372
@classmethod
290
373
def connection_resolver (cls , resolver , connection_type , root , info , ** args ):
0 commit comments