1111from .utils import (
1212 get_field_description ,
1313 get_query_fields ,
14+ get_queried_union_types ,
1415 get_field_is_required ,
1516 ExecutorEnum ,
1617 sync_to_async ,
@@ -154,7 +155,7 @@ def convert_field_to_list(field, registry=None, executor: ExecutorEnum = Executo
154155 if isinstance (field .field , mongoengine .GenericReferenceField ):
155156
156157 def get_reference_objects (* args , ** kwargs ):
157- document = get_document (args [0 ][ 0 ] )
158+ document = get_document (args [0 ])
158159 document_field = mongoengine .ReferenceField (document )
159160 document_field = convert_mongoengine_field (document_field , registry )
160161 document_field_type = document_field .get_type ().type
@@ -164,75 +165,70 @@ def get_reference_objects(*args, **kwargs):
164165 for key , values in document_field_type ._meta .filter_fields .items ():
165166 for each in values :
166167 filter_args .append (key + "__" + each )
167- for each in get_query_fields ( args [0 ][ 3 ][ 0 ])[ document_field_type . _meta . name ]. keys () :
168+ for each in args [4 ] :
168169 item = to_snake_case (each )
169170 if item in document ._fields_ordered + tuple (filter_args ):
170171 queried_fields .append (item )
171172 return (
172173 document .objects ()
173174 .no_dereference ()
174175 .only (* set (list (document_field_type ._meta .required_fields ) + queried_fields ))
175- .filter (pk__in = args [0 ][ 1 ])
176+ .filter (pk__in = args [1 ])
176177 )
177178
178179 def get_non_querying_object (* args , ** kwargs ):
179- model = get_document (args [0 ][ 0 ] )
180- return [model (pk = each ) for each in args [0 ][ 1 ]]
180+ model = get_document (args [0 ])
181+ return [model (pk = each ) for each in args [1 ]]
181182
182183 def reference_resolver (root , * args , ** kwargs ):
183184 to_resolve = getattr (root , field .name or field .db_name )
184- if to_resolve :
185- choice_to_resolve = dict ()
186- querying_union_types = list (get_query_fields (args [0 ]).keys ())
187- if "__typename" in querying_union_types :
188- querying_union_types .remove ("__typename" )
189- to_resolve_models = list ()
190- for each in querying_union_types :
191- if executor == ExecutorEnum .SYNC :
192- to_resolve_models .append (registry ._registry_string_map [each ])
193- else :
194- to_resolve_models .append (registry ._registry_async_string_map [each ])
195- to_resolve_object_ids = list ()
196- for each in to_resolve :
197- if isinstance (each , LazyReference ):
198- to_resolve_object_ids .append (each .pk )
199- model = each .document_type ._class_name
200- if model not in choice_to_resolve :
201- choice_to_resolve [model ] = list ()
202- choice_to_resolve [model ].append (each .pk )
203- else :
204- to_resolve_object_ids .append (each ["_ref" ].id )
205- if each ["_cls" ] not in choice_to_resolve :
206- choice_to_resolve [each ["_cls" ]] = list ()
207- choice_to_resolve [each ["_cls" ]].append (each ["_ref" ].id )
208- pool = ThreadPoolExecutor (5 )
209- futures = list ()
210- for model , object_id_list in choice_to_resolve .items ():
211- if model in to_resolve_models :
212- futures .append (
213- pool .submit (
214- get_reference_objects ,
215- (model , object_id_list , registry , args ),
216- )
185+ if not to_resolve :
186+ return None
187+
188+ choice_to_resolve = dict ()
189+ querying_union_types = get_queried_union_types (args [0 ])
190+ to_resolve_models = dict ()
191+ for each , queried_fields in querying_union_types .items ():
192+ to_resolve_models [registry ._registry_string_map [each ]] = queried_fields
193+ to_resolve_object_ids = list ()
194+ for each in to_resolve :
195+ if isinstance (each , LazyReference ):
196+ to_resolve_object_ids .append (each .pk )
197+ model = each .document_type ._class_name
198+ if model not in choice_to_resolve :
199+ choice_to_resolve [model ] = list ()
200+ choice_to_resolve [model ].append (each .pk )
201+ else :
202+ to_resolve_object_ids .append (each ["_ref" ].id )
203+ if each ["_cls" ] not in choice_to_resolve :
204+ choice_to_resolve [each ["_cls" ]] = list ()
205+ choice_to_resolve [each ["_cls" ]].append (each ["_ref" ].id )
206+ pool = ThreadPoolExecutor (5 )
207+ futures = list ()
208+ for model , object_id_list in choice_to_resolve .items ():
209+ if model in to_resolve_models :
210+ queried_fields = to_resolve_models [model ]
211+ futures .append (
212+ pool .submit (
213+ get_reference_objects ,
214+ * (model , object_id_list , registry , args , queried_fields ),
217215 )
218- else :
219- futures . append (
220- pool . submit (
221- get_non_querying_object ,
222- ( model , object_id_list , registry , args ) ,
223- )
216+ )
217+ else :
218+ futures . append (
219+ pool . submit (
220+ get_non_querying_object ,
221+ * ( model , object_id_list , registry , args ),
224222 )
225- result = list ()
226- for x in as_completed (futures ):
227- result += x .result ()
228- result_object_ids = list ()
229- for each in result :
230- result_object_ids .append (each .id )
231- ordered_result = list ()
232- for each in to_resolve_object_ids :
233- ordered_result .append (result [result_object_ids .index (each )])
234- return ordered_result
235- return None
223+ )
224+ result = list ()
225+ for x in as_completed (futures ):
226+ result += x .result ()
227+ result_object_ids = [each .id for each in result ]
228+ ordered_result = [
229+ result [result_object_ids .index (each )] for each in to_resolve_object_ids
230+ ]
231+ return ordered_result
236232
237233 async def get_reference_objects_async (* args , ** kwargs ):
238234 document = get_document (args [0 ])
@@ -247,7 +243,7 @@ async def get_reference_objects_async(*args, **kwargs):
247243 for key , values in document_field_type ._meta .filter_fields .items ():
248244 for each in values :
249245 filter_args .append (key + "__" + each )
250- for each in get_query_fields ( args [3 ][ 0 ])[ document_field_type . _meta . name ]. keys () :
246+ for each in args [4 ] :
251247 item = to_snake_case (each )
252248 if item in document ._fields_ordered + tuple (filter_args ):
253249 queried_fields .append (item )
@@ -259,57 +255,53 @@ async def get_reference_objects_async(*args, **kwargs):
259255 )
260256
261257 async def get_non_querying_object_async (* args , ** kwargs ):
262- model = get_document (args [0 ])
263- return [model (pk = each ) for each in args [1 ]]
258+ return get_non_querying_object (* args , ** kwargs )
264259
265260 async def reference_resolver_async (root , * args , ** kwargs ):
266261 to_resolve = getattr (root , field .name or field .db_name )
267- if to_resolve :
268- choice_to_resolve = dict ()
269- querying_union_types = list (get_query_fields (args [0 ]).keys ())
270- if "__typename" in querying_union_types :
271- querying_union_types .remove ("__typename" )
272- to_resolve_models = list ()
273- for each in querying_union_types :
274- if executor == ExecutorEnum .SYNC :
275- to_resolve_models .append (registry ._registry_string_map [each ])
276- else :
277- to_resolve_models .append (registry ._registry_async_string_map [each ])
278- to_resolve_object_ids = list ()
279- for each in to_resolve :
280- if isinstance (each , LazyReference ):
281- to_resolve_object_ids .append (each .pk )
282- model = each .document_type ._class_name
283- if model not in choice_to_resolve :
284- choice_to_resolve [model ] = list ()
285- choice_to_resolve [model ].append (each .pk )
286- else :
287- to_resolve_object_ids .append (each ["_ref" ].id )
288- if each ["_cls" ] not in choice_to_resolve :
289- choice_to_resolve [each ["_cls" ]] = list ()
290- choice_to_resolve [each ["_cls" ]].append (each ["_ref" ].id )
291- loop = asyncio .get_event_loop ()
292- tasks = []
293- for model , object_id_list in choice_to_resolve .items ():
294- if model in to_resolve_models :
295- task = loop .create_task (
296- get_reference_objects_async (model , object_id_list , registry , args )
297- )
298- else :
299- task = loop .create_task (
300- get_non_querying_object_async (model , object_id_list , registry , args )
262+ if not to_resolve :
263+ return None
264+
265+ choice_to_resolve = dict ()
266+ querying_union_types = get_queried_union_types (args [0 ])
267+ to_resolve_models = dict ()
268+ for each , queried_fields in querying_union_types .items ():
269+ to_resolve_models [registry ._registry_async_string_map [each ]] = queried_fields
270+ to_resolve_object_ids = list ()
271+ for each in to_resolve :
272+ if isinstance (each , LazyReference ):
273+ to_resolve_object_ids .append (each .pk )
274+ model = each .document_type ._class_name
275+ if model not in choice_to_resolve :
276+ choice_to_resolve [model ] = list ()
277+ choice_to_resolve [model ].append (each .pk )
278+ else :
279+ to_resolve_object_ids .append (each ["_ref" ].id )
280+ if each ["_cls" ] not in choice_to_resolve :
281+ choice_to_resolve [each ["_cls" ]] = list ()
282+ choice_to_resolve [each ["_cls" ]].append (each ["_ref" ].id )
283+ loop = asyncio .get_event_loop ()
284+ tasks = []
285+ for model , object_id_list in choice_to_resolve .items ():
286+ if model in to_resolve_models :
287+ queried_fields = to_resolve_models [model ]
288+ task = loop .create_task (
289+ get_reference_objects_async (
290+ model , object_id_list , registry , args , queried_fields
301291 )
302- tasks .append (task )
303- result = await asyncio .gather (* tasks )
304- result_object = {}
305- for items in result :
306- for item in items :
307- result_object [item .id ] = item
308- ordered_result = list ()
309- for each in to_resolve_object_ids :
310- ordered_result .append (result_object [each ])
311- return ordered_result
312- return None
292+ )
293+ else :
294+ task = loop .create_task (
295+ get_non_querying_object_async (model , object_id_list , registry , args )
296+ )
297+ tasks .append (task )
298+ result = await asyncio .gather (* tasks )
299+ result_object = {}
300+ for items in result :
301+ for item in items :
302+ result_object [item .id ] = item
303+ ordered_result = [result_object [each ] for each in to_resolve_object_ids ]
304+ return ordered_result
313305
314306 return graphene .List (
315307 base_type ._type ,
0 commit comments