Skip to content

Commit b6b1ad6

Browse files
committed
fix[union-type]: union type resolution caused errors when using
fragment spread inside a UnionType
1 parent ff700b3 commit b6b1ad6

File tree

1 file changed

+40
-2
lines changed

1 file changed

+40
-2
lines changed

graphene_mongo/utils.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,42 @@ def get_queried_union_types(info, valid_gql_types):
214214
dict[union_type_name, queried_fields(dict)]
215215
"""
216216

217+
def collect_query_fields_for_union(node, fragments, variables):
218+
"""
219+
Similar to collect_query_fields(...)
220+
221+
fragment_spread - logic is different for union
222+
"""
223+
224+
field = {}
225+
selection_set = node.get("selection_set") if isinstance(node, dict) else node.selection_set
226+
if selection_set:
227+
for leaf in selection_set.selections:
228+
if leaf.kind == "field":
229+
if include_field_by_directives(leaf, variables):
230+
field.update(
231+
{leaf.name.value: collect_query_fields(leaf, fragments, variables)}
232+
)
233+
elif leaf.kind == "fragment_spread": # This is different
234+
fragment = fragments[leaf.name.value]
235+
field.update(
236+
{
237+
fragment.type_condition.name.value: collect_query_fields(
238+
fragment, fragments, variables
239+
)
240+
}
241+
)
242+
elif leaf.kind == "inline_fragment":
243+
field.update(
244+
{
245+
leaf.type_condition.name.value: collect_query_fields(
246+
leaf, fragments, variables
247+
)
248+
}
249+
)
250+
251+
return field
252+
217253
fragments = {}
218254
node = ast_to_dict(info.field_nodes[0])
219255
variables = info.variable_values
@@ -228,7 +264,7 @@ def get_queried_union_types(info, valid_gql_types):
228264
for leaf in selection_set.selections:
229265
if leaf.kind == "fragment_spread":
230266
fragment_name = fragments[leaf.name.value].type_condition.name.value
231-
sub_query_fields = collect_query_fields(
267+
sub_query_fields = collect_query_fields_for_union(
232268
fragments[leaf.name.value], fragments, variables
233269
)
234270
if fragment_name not in valid_gql_types:
@@ -242,7 +278,9 @@ def get_queried_union_types(info, valid_gql_types):
242278
fragments_queries[fragment_name] = sub_query_fields
243279
elif leaf.kind == "inline_fragment":
244280
fragment_name = leaf.type_condition.name.value
245-
fragments_queries[fragment_name] = collect_query_fields(leaf, fragments, variables)
281+
fragments_queries[fragment_name] = collect_query_fields_for_union(
282+
leaf, fragments, variables
283+
)
246284

247285
return fragments_queries
248286

0 commit comments

Comments
 (0)