7
7
import mongoengine
8
8
from bson import DBRef
9
9
from graphene import Context
10
+ from graphene .types .utils import get_type
10
11
from graphene .utils .str_converters import to_snake_case
11
12
from graphql import ResolveInfo
12
13
from promise import Promise
15
16
from graphene .types .argument import to_arguments
16
17
from graphene .types .dynamic import Dynamic
17
18
from graphene .types .structures import Structure
18
- from graphql_relay .connection .arrayconnection import connection_from_list_slice
19
+ from graphql_relay .connection .arrayconnection import cursor_to_offset
20
+ from mongoengine import QuerySet
19
21
20
22
from .advanced_types import (
21
23
FileFieldType ,
25
27
)
26
28
from .converter import convert_mongoengine_field , MongoEngineConversionError
27
29
from .registry import get_global_registry
28
- from .utils import get_model_reference_fields , get_node_from_global_id , get_query_fields
30
+ from .utils import get_model_reference_fields , get_query_fields , find_skip_and_limit , \
31
+ connection_from_iterables
29
32
30
33
31
34
class MongoengineConnectionField (ConnectionField ):
@@ -64,10 +67,8 @@ def order_by(self):
64
67
return self .node_type ._meta .order_by
65
68
66
69
@property
67
- def only_fields (self ):
68
- if isinstance (self .node_type ._meta .only_fields , str ):
69
- return self .node_type ._meta .only_fields .split ("," )
70
- return list ()
70
+ def required_fields (self ):
71
+ return tuple (set (self .node_type ._meta .required_fields + self .node_type ._meta .only_fields ))
71
72
72
73
@property
73
74
def registry (self ):
@@ -118,11 +119,13 @@ def is_filterable(k):
118
119
),
119
120
):
120
121
return False
122
+ if getattr (converted , "type" , None ) and getattr (converted .type , "_of_type" , None ) and issubclass (
123
+ (get_type (converted .type .of_type )), graphene .Union ):
124
+ return False
121
125
if isinstance (converted , (graphene .List )) and issubclass (
122
126
getattr (converted , "_of_type" , None ), graphene .Union
123
127
):
124
128
return False
125
-
126
129
return True
127
130
128
131
def get_filter_type (_type ):
@@ -177,29 +180,35 @@ def get_reference_field(r, kv):
177
180
if callable (getattr (field , "get_type" , None )):
178
181
_type = field .get_type ()
179
182
if _type :
180
- node = _type ._type ._meta
183
+ node = _type .type . _meta if hasattr ( _type . type , "_meta" ) else _type . type . _of_type ._meta
181
184
if "id" in node .fields and not issubclass (
182
185
node .model , (mongoengine .EmbeddedDocument ,)
183
186
):
184
187
r .update ({kv [0 ]: node .fields ["id" ]._type .of_type ()})
188
+
185
189
return r
186
190
187
191
return reduce (get_reference_field , self .fields .items (), {})
188
192
189
193
@property
190
194
def fields (self ):
195
+ self ._type = get_type (self ._type )
191
196
return self ._type ._meta .fields
192
197
193
- def get_queryset (self , model , info , only_fields = list (), ** args ):
198
+ def get_queryset (self , model , info , required_fields = list (), skip = None , limit = None , reversed = False , ** args ):
194
199
if args :
195
200
reference_fields = get_model_reference_fields (self .model )
196
201
hydrated_references = {}
197
202
for arg_name , arg in args .copy ().items ():
198
- if arg_name in reference_fields :
199
- reference_obj = get_node_from_global_id (
200
- reference_fields [arg_name ], info , args .pop (arg_name )
201
- )
203
+ if arg_name in reference_fields and not isinstance (arg ,
204
+ mongoengine .base .metaclasses .TopLevelDocumentMetaclass ):
205
+ try :
206
+ reference_obj = reference_fields [arg_name ].document_type (pk = from_global_id (arg )[1 ])
207
+ except TypeError :
208
+ reference_obj = reference_fields [arg_name ].document_type (pk = arg )
202
209
hydrated_references [arg_name ] = reference_obj
210
+ elif arg_name == "id" :
211
+ hydrated_references ["id" ] = from_global_id (args .pop ("id" , None ))[1 ]
203
212
args .update (hydrated_references )
204
213
205
214
if self ._get_queryset :
@@ -208,72 +217,120 @@ def get_queryset(self, model, info, only_fields=list(), **args):
208
217
return queryset_or_filters
209
218
else :
210
219
args .update (queryset_or_filters )
220
+ if limit is not None :
221
+ if reversed :
222
+ order_by = ""
223
+ if self .order_by :
224
+ order_by = self .order_by + ",-pk"
225
+ else :
226
+ order_by = "-pk"
227
+ return model .objects (** args ).no_dereference ().only (* required_fields ).order_by (order_by ).skip (
228
+ skip if skip else 0 ).limit (limit )
229
+ else :
230
+ return model .objects (** args ).no_dereference ().only (* required_fields ).order_by (self .order_by ).skip (
231
+ skip if skip else 0 ).limit (limit )
232
+ elif skip is not None :
233
+ if reversed :
234
+ order_by = ""
235
+ if self .order_by :
236
+ order_by = self .order_by + ",-pk"
237
+ else :
238
+ order_by = "-pk"
239
+ return model .objects (** args ).no_dereference ().only (* required_fields ).order_by (order_by ).skip (
240
+ skip )
241
+ else :
242
+ return model .objects (** args ).no_dereference ().only (* required_fields ).order_by (self .order_by ).skip (
243
+ skip )
244
+ return model .objects (** args ).no_dereference ().only (* required_fields ).order_by (self .order_by )
211
245
212
- return model .objects (** args ).no_dereference ().only (* only_fields ).order_by (self .order_by )
213
-
214
- def default_resolver (self , _root , info , only_fields = list (), ** args ):
246
+ def default_resolver (self , _root , info , required_fields = list (), ** args ):
215
247
args = args or {}
216
248
217
249
if _root is not None :
218
250
field_name = to_snake_case (info .field_name )
219
- if getattr (_root , field_name , []) is not None :
220
- args ["pk__in" ] = [r .id for r in getattr (_root , field_name , [])]
221
-
222
- connection_args = {
223
- "first" : args .pop ("first" , None ),
224
- "last" : args .pop ("last" , None ),
225
- "before" : args .pop ("before" , None ),
226
- "after" : args .pop ("after" , None ),
227
- }
251
+ if field_name in _root ._fields_ordered :
252
+ if getattr (_root , field_name , []) is not None :
253
+ args ["pk__in" ] = [r .id for r in getattr (_root , field_name , [])]
228
254
229
255
_id = args .pop ('id' , None )
230
256
231
257
if _id is not None :
232
258
args ['pk' ] = from_global_id (_id )[- 1 ]
233
-
259
+ iterables = []
260
+ list_length = 0
261
+ skip = 0
262
+ count = 0
263
+ limit = None
264
+ reverse = False
234
265
if callable (getattr (self .model , "objects" , None )):
235
- iterables = self .get_queryset (self .model , info , only_fields , ** args )
236
- if isinstance (info , ResolveInfo ):
237
- if not info .context :
238
- info .context = Context ()
239
- info .context .queryset = iterables
240
- list_length = iterables .count ()
241
- else :
242
- iterables = []
243
- list_length = 0
244
-
245
- connection = connection_from_list_slice (
246
- list_slice = iterables ,
247
- args = connection_args ,
248
- list_length = list_length ,
249
- list_slice_length = list_length ,
250
- connection_type = self .type ,
251
- edge_type = self .type .Edge ,
252
- pageinfo_type = graphene .PageInfo ,
253
- )
266
+ first = args .pop ("first" , None )
267
+ after = cursor_to_offset (args .pop ("after" , None ))
268
+ last = args .pop ("last" , None )
269
+ before = cursor_to_offset (args .pop ("before" , None ))
270
+ if "pk__in" in args and args ["pk__in" ]:
271
+ count = len (args ["pk__in" ])
272
+ skip , limit , reverse = find_skip_and_limit (first = first , last = last , after = after , before = before ,
273
+ count = count )
274
+ if limit :
275
+ if reverse :
276
+ args ["pk__in" ] = args ["pk__in" ][::- 1 ][skip :skip + limit ]
277
+ else :
278
+ args ["pk__in" ] = args ["pk__in" ][skip :skip + limit ]
279
+ elif skip :
280
+ args ["pk__in" ] = args ["pk__in" ][skip :]
281
+ iterables = self .get_queryset (self .model , info , required_fields , ** args )
282
+ list_length = len (iterables )
283
+ if isinstance (info , ResolveInfo ):
284
+ if not info .context :
285
+ info .context = Context ()
286
+ info .context .queryset = self .get_queryset (self .model , info , required_fields , ** args )
287
+ elif _root is None :
288
+ count = self .get_queryset (self .model , info , required_fields , ** args ).count ()
289
+ if count != 0 :
290
+ skip , limit , reverse = find_skip_and_limit (first = first , after = after , last = last , before = before ,
291
+ count = count )
292
+ iterables = self .get_queryset (self .model , info , required_fields , skip , limit , reverse , ** args )
293
+ list_length = len (iterables )
294
+ if isinstance (info , ResolveInfo ):
295
+ if not info .context :
296
+ info .context = Context ()
297
+ info .context .queryset = self .get_queryset (self .model , info , required_fields , ** args )
298
+ has_next_page = True if (0 if limit is None else limit ) + (0 if skip is None else skip ) < count else False
299
+ has_previous_page = True if skip else False
300
+ if reverse :
301
+ iterables = list (iterables )
302
+ iterables .reverse ()
303
+ skip = limit
304
+ connection = connection_from_iterables (edges = iterables , start_offset = skip ,
305
+ has_previous_page = has_previous_page ,
306
+ has_next_page = has_next_page ,
307
+ connection_type = self .type ,
308
+ edge_type = self .type .Edge ,
309
+ pageinfo_type = graphene .PageInfo )
310
+
254
311
connection .iterable = iterables
255
312
connection .list_length = list_length
256
313
return connection
257
314
258
315
def chained_resolver (self , resolver , is_partial , root , info , ** args ):
259
- only_fields = list ()
260
- for field in self .only_fields :
316
+ required_fields = list ()
317
+ for field in self .required_fields :
261
318
if field in self .model ._fields_ordered :
262
- only_fields .append (field )
319
+ required_fields .append (field )
263
320
for field in get_query_fields (info ):
264
321
if to_snake_case (field ) in self .model ._fields_ordered :
265
- only_fields .append (to_snake_case (field ))
322
+ required_fields .append (to_snake_case (field ))
266
323
if not bool (args ) or not is_partial :
267
324
if isinstance (self .model , mongoengine .Document ) or isinstance (self .model ,
268
325
mongoengine .base .metaclasses .TopLevelDocumentMetaclass ):
269
326
args_copy = args .copy ()
270
327
for arg_name , arg in args .copy ().items ():
271
- if arg_name not in self .model ._fields_ordered :
328
+ if arg_name not in self .model ._fields_ordered + tuple ( self . filter_args . keys ()) :
272
329
args_copy .pop (arg_name )
273
330
if isinstance (info , ResolveInfo ):
274
331
if not info .context :
275
332
info .context = Context ()
276
- info .context .queryset = self .get_queryset (self .model , info , only_fields , ** args_copy )
333
+ info .context .queryset = self .get_queryset (self .model , info , required_fields , ** args_copy )
277
334
# XXX: Filter nested args
278
335
resolved = resolver (root , info , ** args )
279
336
if resolved is not None :
@@ -282,9 +339,17 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
282
339
return resolved
283
340
elif not isinstance (resolved [0 ], DBRef ):
284
341
return resolved
342
+ elif isinstance (resolved , QuerySet ):
343
+ args .update (resolved ._query )
344
+ args_copy = args .copy ()
345
+ for arg_name , arg in args .copy ().items ():
346
+ if arg_name not in self .model ._fields_ordered + ('first' , 'last' , 'before' , 'after' ) + tuple (
347
+ self .filter_args .keys ()):
348
+ args_copy .pop (arg_name )
349
+ return self .default_resolver (root , info , required_fields , ** args_copy )
285
350
else :
286
351
return resolved
287
- return self .default_resolver (root , info , only_fields , ** args )
352
+ return self .default_resolver (root , info , required_fields , ** args )
288
353
289
354
@classmethod
290
355
def connection_resolver (cls , resolver , connection_type , root , info , ** args ):
0 commit comments