16
16
from graphene .types .argument import to_arguments
17
17
from graphene .types .dynamic import Dynamic
18
18
from graphene .types .structures import Structure
19
- from graphql_relay .connection .arrayconnection import connection_from_list_slice
19
+ from graphql_relay .connection .arrayconnection import cursor_to_offset
20
+ from mongoengine import QuerySet
20
21
21
22
from .advanced_types import (
22
23
FileFieldType ,
26
27
)
27
28
from .converter import convert_mongoengine_field , MongoEngineConversionError
28
29
from .registry import get_global_registry
29
- from .utils import get_model_reference_fields , get_node_from_global_id , get_query_fields
30
+ from .utils import get_model_reference_fields , get_node_from_global_id , get_query_fields , find_skip_and_limit , \
31
+ connection_from_iterables
30
32
31
33
32
34
class MongoengineConnectionField (ConnectionField ):
@@ -190,7 +192,7 @@ def fields(self):
190
192
self ._type = get_type (self ._type )
191
193
return self ._type ._meta .fields
192
194
193
- def get_queryset (self , model , info , required_fields = list (), ** args ):
195
+ def get_queryset (self , model , info , required_fields = list (), skip = None , limit = None , reversed = False , ** args ):
194
196
if args :
195
197
reference_fields = get_model_reference_fields (self .model )
196
198
hydrated_references = {}
@@ -208,7 +210,30 @@ def get_queryset(self, model, info, required_fields=list(), **args):
208
210
return queryset_or_filters
209
211
else :
210
212
args .update (queryset_or_filters )
211
-
213
+ if limit is not None :
214
+ if reversed :
215
+ order_by = ""
216
+ if self .order_by :
217
+ order_by = self .order_by + ",-pk"
218
+ else :
219
+ order_by = "-pk"
220
+ return model .objects (** args ).no_dereference ().only (* required_fields ).order_by (order_by ).skip (
221
+ skip if skip else 0 ).limit (limit )
222
+ else :
223
+ return model .objects (** args ).no_dereference ().only (* required_fields ).order_by (self .order_by ).skip (
224
+ skip if skip else 0 ).limit (limit )
225
+ elif skip is not None :
226
+ if reversed :
227
+ order_by = ""
228
+ if self .order_by :
229
+ order_by = self .order_by + ",-pk"
230
+ else :
231
+ order_by = "-pk"
232
+ return model .objects (** args ).no_dereference ().only (* required_fields ).order_by (order_by ).skip (
233
+ skip )
234
+ else :
235
+ return model .objects (** args ).no_dereference ().only (* required_fields ).order_by (self .order_by ).skip (
236
+ skip )
212
237
return model .objects (** args ).no_dereference ().only (* required_fields ).order_by (self .order_by )
213
238
214
239
def default_resolver (self , _root , info , required_fields = list (), ** args ):
@@ -220,38 +245,62 @@ def default_resolver(self, _root, info, required_fields=list(), **args):
220
245
if getattr (_root , field_name , []) is not None :
221
246
args ["pk__in" ] = [r .id for r in getattr (_root , field_name , [])]
222
247
223
- connection_args = {
224
- "first" : args .pop ("first" , None ),
225
- "last" : args .pop ("last" , None ),
226
- "before" : args .pop ("before" , None ),
227
- "after" : args .pop ("after" , None ),
228
- }
229
-
230
248
_id = args .pop ('id' , None )
231
249
232
250
if _id is not None :
233
251
args ['pk' ] = from_global_id (_id )[- 1 ]
234
-
252
+ iterables = []
253
+ list_length = 0
254
+ skip = 0
255
+ count = 0
256
+ limit = None
257
+ reverse = False
235
258
if callable (getattr (self .model , "objects" , None )):
236
- iterables = self .get_queryset (self .model , info , required_fields , ** args )
237
- if isinstance (info , ResolveInfo ):
238
- if not info .context :
239
- info .context = Context ()
240
- info .context .queryset = iterables
241
- list_length = iterables .count ()
242
- else :
243
- iterables = []
244
- list_length = 0
245
-
246
- connection = connection_from_list_slice (
247
- list_slice = iterables ,
248
- args = connection_args ,
249
- list_length = list_length ,
250
- list_slice_length = list_length ,
251
- connection_type = self .type ,
252
- edge_type = self .type .Edge ,
253
- pageinfo_type = graphene .PageInfo ,
254
- )
259
+ first = args .pop ("first" , None )
260
+ after = cursor_to_offset (args .pop ("after" , None ))
261
+ last = args .pop ("last" , None )
262
+ before = cursor_to_offset (args .pop ("before" , None ))
263
+ if "pk__in" in args and args ["pk__in" ]:
264
+ count = len (args ["pk__in" ])
265
+ skip , limit , reverse = find_skip_and_limit (first = first , last = last , after = after , before = before ,
266
+ count = count )
267
+ if limit :
268
+ if reverse :
269
+ args ["pk__in" ] = args ["pk__in" ][::- 1 ][skip :skip + limit ]
270
+ else :
271
+ args ["pk__in" ] = args ["pk__in" ][skip :skip + limit ]
272
+ elif skip :
273
+ args ["pk__in" ] = args ["pk__in" ][skip :]
274
+ iterables = self .get_queryset (self .model , info , required_fields , ** args )
275
+ list_length = len (iterables )
276
+ if isinstance (info , ResolveInfo ):
277
+ if not info .context :
278
+ info .context = Context ()
279
+ info .context .queryset = self .get_queryset (self .model , info , required_fields , ** args )
280
+ elif _root is None :
281
+ count = self .get_queryset (self .model , info , required_fields , ** args ).count ()
282
+ if count != 0 :
283
+ skip , limit , reverse = find_skip_and_limit (first = first , after = after , last = last , before = before ,
284
+ count = count )
285
+ iterables = self .get_queryset (self .model , info , required_fields , skip , limit , reverse , ** args )
286
+ list_length = len (iterables )
287
+ if isinstance (info , ResolveInfo ):
288
+ if not info .context :
289
+ info .context = Context ()
290
+ info .context .queryset = self .get_queryset (self .model , info , required_fields , ** args )
291
+ has_next_page = True if (0 if limit is None else limit ) + (0 if skip is None else skip ) < count else False
292
+ has_previous_page = True if skip else False
293
+ if reverse :
294
+ iterables = list (iterables )
295
+ iterables .reverse ()
296
+ skip = limit
297
+ connection = connection_from_iterables (edges = iterables , start_offset = skip ,
298
+ has_previous_page = has_previous_page ,
299
+ has_next_page = has_next_page ,
300
+ connection_type = self .type ,
301
+ edge_type = self .type .Edge ,
302
+ pageinfo_type = graphene .PageInfo )
303
+
255
304
connection .iterable = iterables
256
305
connection .list_length = list_length
257
306
return connection
@@ -283,6 +332,9 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
283
332
return resolved
284
333
elif not isinstance (resolved [0 ], DBRef ):
285
334
return resolved
335
+ elif isinstance (resolved , QuerySet ):
336
+ args .update (resolved ._query )
337
+ return self .default_resolver (root , info , required_fields , ** args )
286
338
else :
287
339
return resolved
288
340
return self .default_resolver (root , info , required_fields , ** args )
0 commit comments