Skip to content

Commit 9a594fe

Browse files
committed
fix[converter]: convert_field_to_list resolver error
new get_query_fields cannot find union types called. Introduced get_queried_union_types to find it
1 parent fb7a4bd commit 9a594fe

File tree

2 files changed

+129
-103
lines changed

2 files changed

+129
-103
lines changed

graphene_mongo/converter.py

Lines changed: 95 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from .utils import (
1212
get_field_description,
1313
get_query_fields,
14+
get_queried_union_types,
1415
get_field_is_required,
1516
ExecutorEnum,
1617
sync_to_async,
@@ -154,7 +155,7 @@ def convert_field_to_list(field, registry=None, executor: ExecutorEnum = Executo
154155
if isinstance(field.field, mongoengine.GenericReferenceField):
155156

156157
def get_reference_objects(*args, **kwargs):
157-
document = get_document(args[0][0])
158+
document = get_document(args[0])
158159
document_field = mongoengine.ReferenceField(document)
159160
document_field = convert_mongoengine_field(document_field, registry)
160161
document_field_type = document_field.get_type().type
@@ -164,75 +165,70 @@ def get_reference_objects(*args, **kwargs):
164165
for key, values in document_field_type._meta.filter_fields.items():
165166
for each in values:
166167
filter_args.append(key + "__" + each)
167-
for each in get_query_fields(args[0][3][0])[document_field_type._meta.name].keys():
168+
for each in args[4]:
168169
item = to_snake_case(each)
169170
if item in document._fields_ordered + tuple(filter_args):
170171
queried_fields.append(item)
171172
return (
172173
document.objects()
173174
.no_dereference()
174175
.only(*set(list(document_field_type._meta.required_fields) + queried_fields))
175-
.filter(pk__in=args[0][1])
176+
.filter(pk__in=args[1])
176177
)
177178

178179
def get_non_querying_object(*args, **kwargs):
179-
model = get_document(args[0][0])
180-
return [model(pk=each) for each in args[0][1]]
180+
model = get_document(args[0])
181+
return [model(pk=each) for each in args[1]]
181182

182183
def reference_resolver(root, *args, **kwargs):
183184
to_resolve = getattr(root, field.name or field.db_name)
184-
if to_resolve:
185-
choice_to_resolve = dict()
186-
querying_union_types = list(get_query_fields(args[0]).keys())
187-
if "__typename" in querying_union_types:
188-
querying_union_types.remove("__typename")
189-
to_resolve_models = list()
190-
for each in querying_union_types:
191-
if executor == ExecutorEnum.SYNC:
192-
to_resolve_models.append(registry._registry_string_map[each])
193-
else:
194-
to_resolve_models.append(registry._registry_async_string_map[each])
195-
to_resolve_object_ids = list()
196-
for each in to_resolve:
197-
if isinstance(each, LazyReference):
198-
to_resolve_object_ids.append(each.pk)
199-
model = each.document_type._class_name
200-
if model not in choice_to_resolve:
201-
choice_to_resolve[model] = list()
202-
choice_to_resolve[model].append(each.pk)
203-
else:
204-
to_resolve_object_ids.append(each["_ref"].id)
205-
if each["_cls"] not in choice_to_resolve:
206-
choice_to_resolve[each["_cls"]] = list()
207-
choice_to_resolve[each["_cls"]].append(each["_ref"].id)
208-
pool = ThreadPoolExecutor(5)
209-
futures = list()
210-
for model, object_id_list in choice_to_resolve.items():
211-
if model in to_resolve_models:
212-
futures.append(
213-
pool.submit(
214-
get_reference_objects,
215-
(model, object_id_list, registry, args),
216-
)
185+
if not to_resolve:
186+
return None
187+
188+
choice_to_resolve = dict()
189+
querying_union_types = get_queried_union_types(args[0])
190+
to_resolve_models = dict()
191+
for each, queried_fields in querying_union_types.items():
192+
to_resolve_models[registry._registry_string_map[each]] = queried_fields
193+
to_resolve_object_ids = list()
194+
for each in to_resolve:
195+
if isinstance(each, LazyReference):
196+
to_resolve_object_ids.append(each.pk)
197+
model = each.document_type._class_name
198+
if model not in choice_to_resolve:
199+
choice_to_resolve[model] = list()
200+
choice_to_resolve[model].append(each.pk)
201+
else:
202+
to_resolve_object_ids.append(each["_ref"].id)
203+
if each["_cls"] not in choice_to_resolve:
204+
choice_to_resolve[each["_cls"]] = list()
205+
choice_to_resolve[each["_cls"]].append(each["_ref"].id)
206+
pool = ThreadPoolExecutor(5)
207+
futures = list()
208+
for model, object_id_list in choice_to_resolve.items():
209+
if model in to_resolve_models:
210+
queried_fields = to_resolve_models[model]
211+
futures.append(
212+
pool.submit(
213+
get_reference_objects,
214+
*(model, object_id_list, registry, args, queried_fields),
217215
)
218-
else:
219-
futures.append(
220-
pool.submit(
221-
get_non_querying_object,
222-
(model, object_id_list, registry, args),
223-
)
216+
)
217+
else:
218+
futures.append(
219+
pool.submit(
220+
get_non_querying_object,
221+
*(model, object_id_list, registry, args),
224222
)
225-
result = list()
226-
for x in as_completed(futures):
227-
result += x.result()
228-
result_object_ids = list()
229-
for each in result:
230-
result_object_ids.append(each.id)
231-
ordered_result = list()
232-
for each in to_resolve_object_ids:
233-
ordered_result.append(result[result_object_ids.index(each)])
234-
return ordered_result
235-
return None
223+
)
224+
result = list()
225+
for x in as_completed(futures):
226+
result += x.result()
227+
result_object_ids = [each.id for each in result]
228+
ordered_result = [
229+
result[result_object_ids.index(each)] for each in to_resolve_object_ids
230+
]
231+
return ordered_result
236232

237233
async def get_reference_objects_async(*args, **kwargs):
238234
document = get_document(args[0])
@@ -247,7 +243,7 @@ async def get_reference_objects_async(*args, **kwargs):
247243
for key, values in document_field_type._meta.filter_fields.items():
248244
for each in values:
249245
filter_args.append(key + "__" + each)
250-
for each in get_query_fields(args[3][0])[document_field_type._meta.name].keys():
246+
for each in args[4]:
251247
item = to_snake_case(each)
252248
if item in document._fields_ordered + tuple(filter_args):
253249
queried_fields.append(item)
@@ -259,57 +255,53 @@ async def get_reference_objects_async(*args, **kwargs):
259255
)
260256

261257
async def get_non_querying_object_async(*args, **kwargs):
262-
model = get_document(args[0])
263-
return [model(pk=each) for each in args[1]]
258+
return get_non_querying_object(*args, **kwargs)
264259

265260
async def reference_resolver_async(root, *args, **kwargs):
266261
to_resolve = getattr(root, field.name or field.db_name)
267-
if to_resolve:
268-
choice_to_resolve = dict()
269-
querying_union_types = list(get_query_fields(args[0]).keys())
270-
if "__typename" in querying_union_types:
271-
querying_union_types.remove("__typename")
272-
to_resolve_models = list()
273-
for each in querying_union_types:
274-
if executor == ExecutorEnum.SYNC:
275-
to_resolve_models.append(registry._registry_string_map[each])
276-
else:
277-
to_resolve_models.append(registry._registry_async_string_map[each])
278-
to_resolve_object_ids = list()
279-
for each in to_resolve:
280-
if isinstance(each, LazyReference):
281-
to_resolve_object_ids.append(each.pk)
282-
model = each.document_type._class_name
283-
if model not in choice_to_resolve:
284-
choice_to_resolve[model] = list()
285-
choice_to_resolve[model].append(each.pk)
286-
else:
287-
to_resolve_object_ids.append(each["_ref"].id)
288-
if each["_cls"] not in choice_to_resolve:
289-
choice_to_resolve[each["_cls"]] = list()
290-
choice_to_resolve[each["_cls"]].append(each["_ref"].id)
291-
loop = asyncio.get_event_loop()
292-
tasks = []
293-
for model, object_id_list in choice_to_resolve.items():
294-
if model in to_resolve_models:
295-
task = loop.create_task(
296-
get_reference_objects_async(model, object_id_list, registry, args)
297-
)
298-
else:
299-
task = loop.create_task(
300-
get_non_querying_object_async(model, object_id_list, registry, args)
262+
if not to_resolve:
263+
return None
264+
265+
choice_to_resolve = dict()
266+
querying_union_types = get_queried_union_types(args[0])
267+
to_resolve_models = dict()
268+
for each, queried_fields in querying_union_types.items():
269+
to_resolve_models[registry._registry_async_string_map[each]] = queried_fields
270+
to_resolve_object_ids = list()
271+
for each in to_resolve:
272+
if isinstance(each, LazyReference):
273+
to_resolve_object_ids.append(each.pk)
274+
model = each.document_type._class_name
275+
if model not in choice_to_resolve:
276+
choice_to_resolve[model] = list()
277+
choice_to_resolve[model].append(each.pk)
278+
else:
279+
to_resolve_object_ids.append(each["_ref"].id)
280+
if each["_cls"] not in choice_to_resolve:
281+
choice_to_resolve[each["_cls"]] = list()
282+
choice_to_resolve[each["_cls"]].append(each["_ref"].id)
283+
loop = asyncio.get_event_loop()
284+
tasks = []
285+
for model, object_id_list in choice_to_resolve.items():
286+
if model in to_resolve_models:
287+
queried_fields = to_resolve_models[model]
288+
task = loop.create_task(
289+
get_reference_objects_async(
290+
model, object_id_list, registry, args, queried_fields
301291
)
302-
tasks.append(task)
303-
result = await asyncio.gather(*tasks)
304-
result_object = {}
305-
for items in result:
306-
for item in items:
307-
result_object[item.id] = item
308-
ordered_result = list()
309-
for each in to_resolve_object_ids:
310-
ordered_result.append(result_object[each])
311-
return ordered_result
312-
return None
292+
)
293+
else:
294+
task = loop.create_task(
295+
get_non_querying_object_async(model, object_id_list, registry, args)
296+
)
297+
tasks.append(task)
298+
result = await asyncio.gather(*tasks)
299+
result_object = {}
300+
for items in result:
301+
for item in items:
302+
result_object[item.id] = item
303+
ordered_result = [result_object[each] for each in to_resolve_object_ids]
304+
return ordered_result
313305

314306
return graphene.List(
315307
base_type._type,

graphene_mongo/utils.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,40 @@ def get_query_fields(info):
203203
return query
204204

205205

206+
def get_queried_union_types(info):
207+
"""A convenience function to get queried union types with its fields
208+
209+
Args:
210+
info (ResolveInfo)
211+
212+
Returns:
213+
dict[union_type_name, queried_fields(dict)]
214+
"""
215+
216+
fragments = {}
217+
node = ast_to_dict(info.field_nodes[0])
218+
variables = info.variable_values
219+
220+
for name, value in info.fragments.items():
221+
fragments[name] = ast_to_dict(value)
222+
223+
fragments_queries: dict[str, dict] = {}
224+
225+
selection_set = node.get("selection_set") if isinstance(node, dict) else node.selection_set
226+
if selection_set:
227+
for leaf in selection_set.selections:
228+
if leaf.kind == "fragment_spread":
229+
fragment_name = fragments[leaf.name.value].type_condition.name.value
230+
fragments_queries[fragment_name] = collect_query_fields(
231+
fragments[leaf.name.value], fragments, variables
232+
)
233+
elif leaf.kind == "inline_fragment":
234+
fragment_name = leaf.type_condition.name.value
235+
fragments_queries[fragment_name] = collect_query_fields(leaf, fragments, variables)
236+
237+
return fragments_queries
238+
239+
206240
def has_page_info(info):
207241
"""A convenience function to call collect_query_fields with info
208242
for retrieving if page_info details are required

0 commit comments

Comments
 (0)