Skip to content

Commit 1231952

Browse files
committed
fix[union-converter]: get_queried_union_types can now handle union fragments
1 parent 08295c6 commit 1231952

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

graphene_mongo/converter.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,9 @@ def reference_resolver(root, *args, **kwargs):
186186
return None
187187

188188
choice_to_resolve = dict()
189-
querying_union_types = get_queried_union_types(args[0])
189+
querying_union_types = get_queried_union_types(
190+
info=args[0], valid_gql_types=registry._registry_string_map.keys()
191+
)
190192
to_resolve_models = dict()
191193
for each, queried_fields in querying_union_types.items():
192194
to_resolve_models[registry._registry_string_map[each]] = queried_fields
@@ -263,7 +265,9 @@ async def reference_resolver_async(root, *args, **kwargs):
263265
return None
264266

265267
choice_to_resolve = dict()
266-
querying_union_types = get_queried_union_types(args[0])
268+
querying_union_types = get_queried_union_types(
269+
info=args[0], valid_gql_types=registry._registry_async_string_map.keys()
270+
)
267271
to_resolve_models = dict()
268272
for each, queried_fields in querying_union_types.items():
269273
to_resolve_models[registry._registry_async_string_map[each]] = queried_fields

graphene_mongo/utils.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,11 +203,12 @@ def get_query_fields(info):
203203
return query
204204

205205

206-
def get_queried_union_types(info):
206+
def get_queried_union_types(info, valid_gql_types):
207207
"""A convenience function to get queried union types with its fields
208208
209209
Args:
210210
info (ResolveInfo)
211+
valid_gql_types (dict_keys)
211212
212213
Returns:
213214
dict[union_type_name, queried_fields(dict)]
@@ -227,9 +228,16 @@ def get_queried_union_types(info):
227228
for leaf in selection_set.selections:
228229
if leaf.kind == "fragment_spread":
229230
fragment_name = fragments[leaf.name.value].type_condition.name.value
230-
fragments_queries[fragment_name] = collect_query_fields(
231+
sub_query_fields = collect_query_fields(
231232
fragments[leaf.name.value], fragments, variables
232233
)
234+
if fragment_name not in valid_gql_types:
235+
# This is done to avoid UnionFragments coming in fragments_queries as
236+
# we actually need its children types and not the UnionFragments itself
237+
fragments_queries.update(sub_query_fields)
238+
fragments_queries.pop('__typename', None) # cannot resolve __typename for a union type
239+
else:
240+
fragments_queries[fragment_name] = sub_query_fields
233241
elif leaf.kind == "inline_fragment":
234242
fragment_name = leaf.type_condition.name.value
235243
fragments_queries[fragment_name] = collect_query_fields(leaf, fragments, variables)

0 commit comments

Comments
 (0)