@@ -107,40 +107,52 @@ def convert_field_to_list(field, registry=None):
107
107
if isinstance (base_type , graphene .Field ):
108
108
if isinstance (field .field , mongoengine .GenericReferenceField ):
109
109
def get_reference_objects (* args , ** kwargs ):
110
- if args [0 ][1 ]:
111
- document = get_document (args [0 ][0 ])
112
- document_field = mongoengine .ReferenceField (document )
113
- document_field = convert_mongoengine_field (document_field , registry )
114
- document_field_type = document_field .get_type ().type ._meta .name
115
- required_fields = [to_snake_case (i ) for i in
116
- get_query_fields (args [0 ][3 ][0 ])[document_field_type ].keys ()]
117
- return document .objects ().no_dereference ().only (* required_fields ).filter (pk__in = args [0 ][1 ])
118
- else :
119
- return []
110
+ document = get_document (args [0 ][0 ])
111
+ document_field = mongoengine .ReferenceField (document )
112
+ document_field = convert_mongoengine_field (document_field , registry )
113
+ document_field_type = document_field .get_type ().type ._meta .name
114
+ required_fields = [to_snake_case (i ) for i in
115
+ get_query_fields (args [0 ][3 ][0 ])[document_field_type ].keys ()]
116
+ return document .objects ().no_dereference ().only (* required_fields ).filter (pk__in = args [0 ][1 ])
117
+
118
+ def get_non_querying_object (* args , ** kwargs ):
119
+ model = get_document (args [0 ][0 ])
120
+ return [model (pk = each ) for each in args [0 ][1 ]]
120
121
121
122
def reference_resolver (root , * args , ** kwargs ):
122
- choice_to_resolve = dict ()
123
123
to_resolve = getattr (root , field .name or field .db_name )
124
124
if to_resolve :
125
+ choice_to_resolve = dict ()
126
+ querying_union_types = list (get_query_fields (args [0 ]).keys ())
127
+ if '__typename' in querying_union_types :
128
+ querying_union_types .remove ('__typename' )
129
+ to_resolve_models = list ()
130
+ for each in querying_union_types :
131
+ to_resolve_models .append (registry ._registry_string_map [each ])
125
132
for each in to_resolve :
126
133
if each ['_cls' ] not in choice_to_resolve :
127
134
choice_to_resolve [each ['_cls' ]] = list ()
128
135
choice_to_resolve [each ['_cls' ]].append (each ["_ref" ].id )
129
-
130
136
pool = ThreadPoolExecutor (5 )
131
137
futures = list ()
132
138
for model , object_id_list in choice_to_resolve .items ():
133
- futures .append (pool .submit (get_reference_objects , (model , object_id_list , registry , args )))
139
+ if model in to_resolve_models :
140
+ futures .append (pool .submit (get_reference_objects , (model , object_id_list , registry , args )))
141
+ else :
142
+ futures .append (
143
+ pool .submit (get_non_querying_object , (model , object_id_list , registry , args )))
134
144
result = list ()
135
145
for x in as_completed (futures ):
136
146
result += x .result ()
137
147
to_resolve_object_ids = [each ["_ref" ].id for each in to_resolve ]
138
- result_to_resolve_object_ids = [each .id for each in result ]
148
+ result_object_ids = list ()
149
+ for each in result :
150
+ result_object_ids .append (each .id )
139
151
ordered_result = list ()
140
152
for each in to_resolve_object_ids :
141
- ordered_result .append (result [result_to_resolve_object_ids .index (each )])
153
+ ordered_result .append (result [result_object_ids .index (each )])
142
154
return ordered_result
143
- return []
155
+ return None
144
156
145
157
return graphene .List (
146
158
base_type ._type ,
@@ -207,17 +219,20 @@ def convert_field_to_union(field, registry=None):
207
219
_union = type (name , (graphene .Union ,), {"Meta" : Meta })
208
220
209
221
def reference_resolver (root , * args , ** kwargs ):
210
- dereferenced = getattr (root , field .name or field .db_name )
211
- if dereferenced :
212
- document = get_document (dereferenced ["_cls" ])
222
+ de_referenced = getattr (root , field .name or field .db_name )
223
+ if de_referenced :
224
+ document = get_document (de_referenced ["_cls" ])
213
225
document_field = mongoengine .ReferenceField (document )
214
226
document_field = convert_mongoengine_field (document_field , registry )
215
227
_type = document_field .get_type ().type
216
- only_fields = _type ._meta .only_fields .split ("," ) if isinstance (_type ._meta .only_fields ,
217
- str ) else list ()
218
- return document .objects ().no_dereference ().only (* list (
219
- set (only_fields + [to_snake_case (i ) for i in get_query_fields (args [0 ])[_type ._meta .name ].keys ()]))).get (
220
- pk = dereferenced ["_ref" ].id )
228
+ querying_types = list (get_query_fields (args [0 ]).keys ())
229
+ _type = document_field .get_type ().type
230
+ if _type .__name__ in querying_types :
231
+ return document .objects ().no_dereference ().only (* list (
232
+ set (list (_type ._meta .required_fields ) + [to_snake_case (i ) for i in
233
+ get_query_fields (args [0 ])[_type ._meta .name ].keys ()]))).get (
234
+ pk = de_referenced ["_ref" ].id )
235
+ return document
221
236
return None
222
237
223
238
if isinstance (field , mongoengine .GenericReferenceField ):
@@ -242,21 +257,19 @@ def reference_resolver(root, *args, **kwargs):
242
257
document = getattr (root , field .name or field .db_name )
243
258
if document :
244
259
_type = registry .get_type_for_model (field .document_type )
245
- required_fields = _type ._meta .required_fields .split ("," ) if isinstance (_type ._meta .required_fields ,
246
- str ) else list ()
247
260
return field .document_type .objects ().no_dereference ().only (
248
- * ((list (set (required_fields + [to_snake_case (i ) for i in get_query_fields (args [0 ]).keys ()]))))).get (
261
+ * ((list (set (list (_type ._meta .required_fields ) + [to_snake_case (i ) for i in
262
+ get_query_fields (args [0 ]).keys ()]))))).get (
249
263
pk = document .id )
250
264
return None
251
265
252
266
def cached_reference_resolver (root , * args , ** kwargs ):
253
267
if field :
254
268
_type = registry .get_type_for_model (field .document_type )
255
- required_fields = _type ._meta .required_fields .split ("," ) if isinstance (_type ._meta .required_fields ,
256
- str ) else list ()
257
269
return field .document_type .objects ().no_dereference ().only (
258
- * (list (set (required_fields + [to_snake_case (i ) for i in get_query_fields (args [0 ]).keys ()]))
259
- )).get (
270
+ * (list (set (
271
+ list (_type ._meta .required_fields ) + [to_snake_case (i ) for i in
272
+ get_query_fields (args [0 ]).keys ()])))).get (
260
273
pk = getattr (root , field .name or field .db_name ))
261
274
return None
262
275
@@ -290,10 +303,9 @@ def lazy_resolver(root, *args, **kwargs):
290
303
document = getattr (root , field .name or field .db_name )
291
304
if document :
292
305
_type = registry .get_type_for_model (document .document_type )
293
- required_fields = _type ._meta .required_fields .split ("," ) if isinstance (_type ._meta .required_fields ,
294
- str ) else list ()
295
306
return document .document_type .objects ().no_dereference ().only (
296
- * (list (set ((required_fields + [to_snake_case (i ) for i in get_query_fields (args [0 ]).keys ()]))))).get (
307
+ * (list (set ((list (_type ._meta .required_fields ) + [to_snake_case (i ) for i in
308
+ get_query_fields (args [0 ]).keys ()]))))).get (
297
309
pk = document .pk )
298
310
return None
299
311
0 commit comments