16
16
from graphene .types .argument import to_arguments
17
17
from graphene .types .dynamic import Dynamic
18
18
from graphene .types .structures import Structure
19
- from graphql_relay .connection .arrayconnection import cursor_to_offset
20
- from mongoengine import QuerySet
19
+ from graphql_relay .connection .arrayconnection import connection_from_list_slice
21
20
22
21
from .advanced_types import (
23
22
FileFieldType ,
27
26
)
28
27
from .converter import convert_mongoengine_field , MongoEngineConversionError
29
28
from .registry import get_global_registry
30
- from .utils import get_model_reference_fields , get_node_from_global_id , get_query_fields , find_skip_and_limit , \
31
- connection_from_iterables
29
+ from .utils import get_model_reference_fields , get_node_from_global_id , get_query_fields
32
30
33
31
34
32
class MongoengineConnectionField (ConnectionField ):
@@ -178,7 +176,7 @@ def get_reference_field(r, kv):
178
176
if callable (getattr (field , "get_type" , None )):
179
177
_type = field .get_type ()
180
178
if _type :
181
- node = _type ._type ._meta
179
+ node = _type .type . _meta if hasattr ( _type . type , "_meta" ) else _type . type . _of_type ._meta
182
180
if "id" in node .fields and not issubclass (
183
181
node .model , (mongoengine .EmbeddedDocument ,)
184
182
):
@@ -192,12 +190,12 @@ def fields(self):
192
190
self ._type = get_type (self ._type )
193
191
return self ._type ._meta .fields
194
192
195
- def get_queryset (self , model , info , required_fields = list (), skip = None , limit = None , reversed = False , ** args ):
193
+ def get_queryset (self , model , info , required_fields = list (), ** args ):
196
194
if args :
197
195
reference_fields = get_model_reference_fields (self .model )
198
196
hydrated_references = {}
199
197
for arg_name , arg in args .copy ().items ():
200
- if arg_name in reference_fields :
198
+ if arg_name in reference_fields and isinstance ( arg , str ) :
201
199
reference_obj = get_node_from_global_id (
202
200
reference_fields [arg_name ], info , args .pop (arg_name )
203
201
)
@@ -210,94 +208,50 @@ def get_queryset(self, model, info, required_fields=list(), skip=None, limit=Non
210
208
return queryset_or_filters
211
209
else :
212
210
args .update (queryset_or_filters )
213
- if limit is not None :
214
- if reversed :
215
- order_by = ""
216
- if self .order_by :
217
- order_by = self .order_by + ",-pk"
218
- else :
219
- order_by = "-pk"
220
- return model .objects (** args ).no_dereference ().only (* required_fields ).order_by (order_by ).skip (
221
- skip if skip else 0 ).limit (limit )
222
- else :
223
- return model .objects (** args ).no_dereference ().only (* required_fields ).order_by (self .order_by ).skip (
224
- skip if skip else 0 ).limit (limit )
225
- elif skip is not None :
226
- if reversed :
227
- order_by = ""
228
- if self .order_by :
229
- order_by = self .order_by + ",-pk"
230
- else :
231
- order_by = "-pk"
232
- return model .objects (** args ).no_dereference ().only (* required_fields ).order_by (order_by ).skip (
233
- skip )
234
- else :
235
- return model .objects (** args ).no_dereference ().only (* required_fields ).order_by (self .order_by ).skip (
236
- skip )
211
+
237
212
return model .objects (** args ).no_dereference ().only (* required_fields ).order_by (self .order_by )
238
213
239
214
def default_resolver (self , _root , info , required_fields = list (), ** args ):
240
215
args = args or {}
216
+
241
217
if _root is not None :
242
218
field_name = to_snake_case (info .field_name )
243
219
if field_name in _root ._fields_ordered :
244
220
if getattr (_root , field_name , []) is not None :
245
221
args ["pk__in" ] = [r .id for r in getattr (_root , field_name , [])]
246
222
223
+ connection_args = {
224
+ "first" : args .pop ("first" , None ),
225
+ "last" : args .pop ("last" , None ),
226
+ "before" : args .pop ("before" , None ),
227
+ "after" : args .pop ("after" , None ),
228
+ }
229
+
247
230
_id = args .pop ('id' , None )
231
+
248
232
if _id is not None :
249
233
args ['pk' ] = from_global_id (_id )[- 1 ]
250
- iterables = []
251
- list_length = 0
252
- skip = 0
253
- count = 0
254
- limit = None
255
- reverse = False
234
+
256
235
if callable (getattr (self .model , "objects" , None )):
257
- first = args .pop ("first" , None )
258
- after = cursor_to_offset (args .pop ("after" , None ))
259
- last = args .pop ("last" , None )
260
- before = cursor_to_offset (args .pop ("before" , None ))
261
- if "pk__in" in args and args ["pk__in" ]:
262
- count = len (args ["pk__in" ])
263
- skip , limit , reverse = find_skip_and_limit (first = first , last = last , after = after , before = before ,
264
- count = count )
265
- if limit :
266
- if reverse :
267
- args ["pk__in" ] = args ["pk__in" ][::- 1 ][skip :skip + limit ]
268
- else :
269
- args ["pk__in" ] = args ["pk__in" ][skip :skip + limit ]
270
- elif skip :
271
- args ["pk__in" ] = args ["pk__in" ][skip :]
272
- iterables = self .get_queryset (self .model , info , required_fields , ** args )
273
- list_length = len (iterables )
274
- if isinstance (info , ResolveInfo ):
275
- if not info .context :
276
- info .context = Context ()
277
- info .context .queryset = self .get_queryset (self .model , info , required_fields , ** args )
278
- else :
279
- count = self .get_queryset (self .model , info , required_fields , ** args ).count ()
280
- if count != 0 :
281
- skip , limit , reverse = find_skip_and_limit (first = first , after = after , last = last , before = before ,
282
- count = count )
283
- iterables = self .get_queryset (self .model , info , required_fields , skip , limit , reverse , ** args )
284
- list_length = len (iterables )
285
- if isinstance (info , ResolveInfo ):
286
- if not info .context :
287
- info .context = Context ()
288
- info .context .queryset = self .get_queryset (self .model , info , required_fields , ** args )
289
- has_next_page = True if (0 if limit is None else limit ) + (0 if skip is None else skip ) < count else False
290
- has_previous_page = True if skip else False
291
- if reverse :
292
- iterables = list (iterables )
293
- iterables .reverse ()
294
- skip = limit
295
- connection = connection_from_iterables (edges = iterables , start_offset = skip ,
296
- has_previous_page = has_previous_page ,
297
- has_next_page = has_next_page ,
298
- connection_type = self .type ,
299
- edge_type = self .type .Edge ,
300
- pageinfo_type = graphene .PageInfo )
236
+ iterables = self .get_queryset (self .model , info , required_fields , ** args )
237
+ if isinstance (info , ResolveInfo ):
238
+ if not info .context :
239
+ info .context = Context ()
240
+ info .context .queryset = iterables
241
+ list_length = iterables .count ()
242
+ else :
243
+ iterables = []
244
+ list_length = 0
245
+
246
+ connection = connection_from_list_slice (
247
+ list_slice = iterables ,
248
+ args = connection_args ,
249
+ list_length = list_length ,
250
+ list_slice_length = list_length ,
251
+ connection_type = self .type ,
252
+ edge_type = self .type .Edge ,
253
+ pageinfo_type = graphene .PageInfo ,
254
+ )
301
255
connection .iterable = iterables
302
256
connection .list_length = list_length
303
257
return connection
@@ -329,9 +283,6 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
329
283
return resolved
330
284
elif not isinstance (resolved [0 ], DBRef ):
331
285
return resolved
332
- elif isinstance (resolved , QuerySet ):
333
- args .update (resolved ._query )
334
- return self .default_resolver (root , info , required_fields , ** args )
335
286
else :
336
287
return resolved
337
288
return self .default_resolver (root , info , required_fields , ** args )
0 commit comments