@@ -390,6 +390,17 @@ def gather_types(cls, _type):
390390 _ctype = toscan .pop ()
391391 my_node_name = _ctype .__idl_typename__ .replace ('.' , '::' )
392392
393+ if isclass (_ctype ) and issubclass (_ctype , IdlUnion ):
394+ # get_extended_type_hints will not inspect the discriminator, and that can be an enum
395+ discriminator_type = _ctype .__idl_discriminator__
396+ if isclass (discriminator_type ) and issubclass (discriminator_type , IdlEnum ):
397+ scan_node_name = discriminator_type .__idl_typename__ .replace ('.' , '::' )
398+ if scan_node_name not in graph :
399+ graph [scan_node_name ] = set ()
400+ graph_types [scan_node_name ] = discriminator_type
401+
402+ graph [my_node_name ].add (scan_node_name )
403+
393404 for name , fieldtype in get_extended_type_hints (_ctype ).items ():
394405 m , deep = cls ._deep_gather_type (fieldtype )
395406 plain = cls ._impl_xt_is_plain (fieldtype )
@@ -1071,7 +1082,7 @@ def _xt_minimal_discriminator_member(cls, entity: Type[IdlUnion]) -> xt.MinimalD
10711082 @classmethod
10721083 def _xt_complete_discriminator_member (cls , entity : Type [IdlUnion ]) -> xt .CompleteDiscriminatorMember :
10731084 return xt .CompleteDiscriminatorMember (
1074- common = cls ._xt_common_discriminator_member (entity , True ),
1085+ common = cls ._xt_common_discriminator_member (entity , False ),
10751086 ann_builtin = None ,
10761087 ann_custom = None ,
10771088 )
0 commit comments