@@ -214,6 +214,42 @@ def get_queried_union_types(info, valid_gql_types):
214214 dict[union_type_name, queried_fields(dict)]
215215 """
216216
217+ def collect_query_fields_for_union (node , fragments , variables ):
218+ """
219+ Similar to collect_query_fields(...)
220+
221+ fragment_spread - logic is different for union
222+ """
223+
224+ field = {}
225+ selection_set = node .get ("selection_set" ) if isinstance (node , dict ) else node .selection_set
226+ if selection_set :
227+ for leaf in selection_set .selections :
228+ if leaf .kind == "field" :
229+ if include_field_by_directives (leaf , variables ):
230+ field .update (
231+ {leaf .name .value : collect_query_fields (leaf , fragments , variables )}
232+ )
233+ elif leaf .kind == "fragment_spread" : # This is different
234+ fragment = fragments [leaf .name .value ]
235+ field .update (
236+ {
237+ fragment .type_condition .name .value : collect_query_fields (
238+ fragment , fragments , variables
239+ )
240+ }
241+ )
242+ elif leaf .kind == "inline_fragment" :
243+ field .update (
244+ {
245+ leaf .type_condition .name .value : collect_query_fields (
246+ leaf , fragments , variables
247+ )
248+ }
249+ )
250+
251+ return field
252+
217253 fragments = {}
218254 node = ast_to_dict (info .field_nodes [0 ])
219255 variables = info .variable_values
@@ -228,7 +264,7 @@ def get_queried_union_types(info, valid_gql_types):
228264 for leaf in selection_set .selections :
229265 if leaf .kind == "fragment_spread" :
230266 fragment_name = fragments [leaf .name .value ].type_condition .name .value
231- sub_query_fields = collect_query_fields (
267+ sub_query_fields = collect_query_fields_for_union (
232268 fragments [leaf .name .value ], fragments , variables
233269 )
234270 if fragment_name not in valid_gql_types :
@@ -242,7 +278,9 @@ def get_queried_union_types(info, valid_gql_types):
242278 fragments_queries [fragment_name ] = sub_query_fields
243279 elif leaf .kind == "inline_fragment" :
244280 fragment_name = leaf .type_condition .name .value
245- fragments_queries [fragment_name ] = collect_query_fields (leaf , fragments , variables )
281+ fragments_queries [fragment_name ] = collect_query_fields_for_union (
282+ leaf , fragments , variables
283+ )
246284
247285 return fragments_queries
248286
0 commit comments