15
15
from graphene .types .argument import to_arguments
16
16
from graphene .types .dynamic import Dynamic
17
17
from graphene .types .structures import Structure
18
- from graphql_relay .connection .arrayconnection import connection_from_list_slice
18
+ from graphql_relay .connection .arrayconnection import cursor_to_offset
19
+ from mongoengine import QuerySet
19
20
20
21
from .advanced_types import (
21
22
FileFieldType ,
25
26
)
26
27
from .converter import convert_mongoengine_field , MongoEngineConversionError
27
28
from .registry import get_global_registry
28
- from .utils import get_model_reference_fields , get_node_from_global_id , get_query_fields
29
+ from .utils import get_model_reference_fields , get_node_from_global_id , get_query_fields , find_skip_and_limit , \
30
+ connection_from_iterables
29
31
30
32
31
33
class MongoengineConnectionField (ConnectionField ):
@@ -188,7 +190,7 @@ def get_reference_field(r, kv):
188
190
def fields (self ):
189
191
return self ._type ._meta .fields
190
192
191
- def get_queryset (self , model , info , required_fields = list (), ** args ):
193
+ def get_queryset (self , model , info , required_fields = list (), skip = None , limit = None , reversed = False , ** args ):
192
194
if args :
193
195
reference_fields = get_model_reference_fields (self .model )
194
196
hydrated_references = {}
@@ -206,7 +208,30 @@ def get_queryset(self, model, info, required_fields=list(), **args):
206
208
return queryset_or_filters
207
209
else :
208
210
args .update (queryset_or_filters )
209
-
211
+ if limit is not None :
212
+ if reversed :
213
+ order_by = ""
214
+ if self .order_by :
215
+ order_by = self .order_by + ",-pk"
216
+ else :
217
+ order_by = "-pk"
218
+ return model .objects (** args ).no_dereference ().only (* required_fields ).order_by (order_by ).skip (
219
+ skip if skip else 0 ).limit (limit )
220
+ else :
221
+ return model .objects (** args ).no_dereference ().only (* required_fields ).order_by (self .order_by ).skip (
222
+ skip if skip else 0 ).limit (limit )
223
+ elif skip is not None :
224
+ if reversed :
225
+ order_by = ""
226
+ if self .order_by :
227
+ order_by = self .order_by + ",-pk"
228
+ else :
229
+ order_by = "-pk"
230
+ return model .objects (** args ).no_dereference ().only (* required_fields ).order_by (order_by ).skip (
231
+ skip )
232
+ else :
233
+ return model .objects (** args ).no_dereference ().only (* required_fields ).order_by (self .order_by ).skip (
234
+ skip )
210
235
return model .objects (** args ).no_dereference ().only (* required_fields ).order_by (self .order_by )
211
236
212
237
def default_resolver (self , _root , info , required_fields = list (), ** args ):
@@ -217,38 +242,66 @@ def default_resolver(self, _root, info, required_fields=list(), **args):
217
242
if getattr (_root , field_name , []) is not None :
218
243
args ["pk__in" ] = [r .id for r in getattr (_root , field_name , [])]
219
244
220
- connection_args = {
221
- "first" : args .pop ("first" , None ),
222
- "last" : args .pop ("last" , None ),
223
- "before" : args .pop ("before" , None ),
224
- "after" : args .pop ("after" , None ),
225
- }
226
-
227
245
_id = args .pop ('id' , None )
228
246
229
247
if _id is not None :
230
248
args ['pk' ] = from_global_id (_id )[- 1 ]
231
-
249
+ iterables = []
250
+ list_length = 0
251
+ skip = 0
252
+ count = 0
253
+ limit = None
254
+ reverse = False
232
255
if callable (getattr (self .model , "objects" , None )):
233
- iterables = self .get_queryset (self .model , info , required_fields , ** args )
234
- if isinstance (info , ResolveInfo ):
235
- if not info .context :
236
- info .context = Context ()
237
- info .context .queryset = iterables
238
- list_length = iterables .count ()
239
- else :
240
- iterables = []
241
- list_length = 0
242
-
243
- connection = connection_from_list_slice (
244
- list_slice = iterables ,
245
- args = connection_args ,
246
- list_length = list_length ,
247
- list_slice_length = list_length ,
248
- connection_type = self .type ,
249
- edge_type = self .type .Edge ,
250
- pageinfo_type = graphene .PageInfo ,
251
- )
256
+ first = args .pop ("first" , None )
257
+ after = cursor_to_offset (args .pop ("after" , None ))
258
+ last = args .pop ("last" , None )
259
+ before = cursor_to_offset (args .pop ("before" , None ))
260
+ if after is not None :
261
+ has_previous_page = after > 0
262
+ elif (before is not None and last is not None ):
263
+ has_previous_page = before - last <= 0
264
+ if "pk__in" in args and args ["pk__in" ]:
265
+ count = len (args ["pk__in" ])
266
+ skip , limit , reverse = find_skip_and_limit (first = first , last = last , after = after , before = before ,
267
+ count = count )
268
+ if limit :
269
+ if reverse :
270
+ args ["pk__in" ] = args ["pk__in" ][::- 1 ][skip :skip + limit ]
271
+ else :
272
+ args ["pk__in" ] = args ["pk__in" ][skip :skip + limit ]
273
+ elif skip :
274
+ args ["pk__in" ] = args ["pk__in" ][skip :]
275
+ iterables = self .get_queryset (self .model , info , required_fields , ** args )
276
+ list_length = len (iterables )
277
+ if isinstance (info , ResolveInfo ):
278
+ if not info .context :
279
+ info .context = Context ()
280
+ info .context .queryset = self .get_queryset (self .model , info , required_fields , ** args )
281
+ elif _root is None :
282
+ count = self .get_queryset (self .model , info , required_fields , ** args ).count ()
283
+ if count != 0 :
284
+ skip , limit , reverse = find_skip_and_limit (first = first , after = after , last = last , before = before ,
285
+ count = count )
286
+ iterables = self .get_queryset (self .model , info , required_fields , skip , limit , reverse , ** args )
287
+ list_length = len (iterables )
288
+ if isinstance (info , ResolveInfo ):
289
+ if not info .context :
290
+ info .context = Context ()
291
+ info .context .queryset = self .get_queryset (self .model , info , required_fields , ** args )
292
+ has_next_page = True if (0 if limit is None else limit ) + (0 if skip is None else skip ) < count else False
293
+ has_previous_page = True if skip else False
294
+ if reverse :
295
+ iterables = list (iterables )
296
+ iterables .reverse ()
297
+ skip = limit
298
+ connection = connection_from_iterables (edges = iterables , start_offset = skip ,
299
+ has_previous_page = has_previous_page ,
300
+ has_next_page = has_next_page ,
301
+ connection_type = self .type ,
302
+ edge_type = self .type .Edge ,
303
+ pageinfo_type = graphene .PageInfo )
304
+
252
305
connection .iterable = iterables
253
306
connection .list_length = list_length
254
307
return connection
@@ -280,6 +333,9 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
280
333
return resolved
281
334
elif not isinstance (resolved [0 ], DBRef ):
282
335
return resolved
336
+ elif isinstance (resolved , QuerySet ):
337
+ args .update (resolved ._query )
338
+ return self .default_resolver (root , info , required_fields , ** args )
283
339
else :
284
340
return resolved
285
341
return self .default_resolver (root , info , required_fields , ** args )
0 commit comments