14
14
from .advanced_types import PointFieldType , MultiPolygonFieldType
15
15
from .converter import convert_mongoengine_field , MongoEngineConversionError
16
16
from .registry import get_global_registry
17
- from .utils import get_model_reference_fields , global_id_via_node
17
+ from .utils import get_model_reference_fields , get_node_from_global_id
18
18
19
19
20
20
class MongoengineConnectionField (ConnectionField ):
@@ -113,18 +113,24 @@ def fields(self):
113
113
return self ._type ._meta .fields
114
114
115
115
def get_queryset (self , model , info , ** args ):
116
+
117
+ if args :
118
+ reference_fields = get_model_reference_fields (self .model )
119
+ hydrated_references = {}
120
+ for arg_name , arg in args .copy ().items ():
121
+ if arg_name in reference_fields :
122
+ reference_obj = get_node_from_global_id (reference_fields [arg_name ], info , args .pop (arg_name ))
123
+ hydrated_references [arg_name ] = reference_obj
124
+ args .update (hydrated_references )
116
125
if self ._get_queryset :
117
126
queryset_or_filters = self ._get_queryset (model , info , ** args )
118
127
if isinstance (queryset_or_filters , mongoengine .QuerySet ):
119
128
return queryset_or_filters
120
129
else :
121
- return model . objects ( ** queryset_or_filters )
122
- return model .objects ()
130
+ args . update ( queryset_or_filters )
131
+ return model .objects (** args )
123
132
124
133
def default_resolver (self , _root , info , ** args ):
125
- if not callable (getattr (self .model , 'objects' , None )):
126
- return [], 0
127
-
128
134
args = args or {}
129
135
130
136
connection_args = {
@@ -134,29 +140,22 @@ def default_resolver(self, _root, info, **args):
134
140
'after' : args .pop ('after' , None )
135
141
}
136
142
137
- objs = self .get_queryset (self .model , info , ** args )
138
-
139
- if args :
140
- reference_fields = get_model_reference_fields (self .model )
141
- reference_args = {}
142
- for arg_name , arg in args .copy ().items ():
143
- if arg_name in reference_fields :
144
- reference_model = self .model ._fields [arg_name ]
145
- pk = global_id_via_node (self .node_type , args .pop (arg_name ))[- 1 ]
146
- reference_obj = reference_model .document_type_obj .objects (pk = pk ).get ()
147
- reference_args [arg_name ] = reference_obj
148
-
149
- args .update (reference_args )
150
- _id = args .pop ('id' , None )
151
- if _id is not None :
152
- args ['pk' ] = global_id_via_node (self .node_type , _id )[- 1 ]
143
+ _id = args .pop ('id' , None )
153
144
154
- objs = objs .filter (** args )
145
+ if _id is not None :
146
+ objs = [get_node_from_global_id (self .node_type , info , _id )]
147
+ list_length = 1
148
+ elif callable (getattr (self .model , 'objects' , None )):
149
+ objs = self .get_queryset (self .model , info , ** args )
150
+ list_length = objs .count ()
151
+ else :
152
+ objs = []
153
+ list_length = 0
155
154
156
155
connection = connection_from_list_slice (
157
156
list_slice = objs ,
158
157
args = connection_args ,
159
- list_length = objs . count () ,
158
+ list_length = list_length ,
160
159
connection_type = self .type ,
161
160
edge_type = self .type .Edge ,
162
161
pageinfo_type = PageInfo ,
0 commit comments