@@ -22,6 +22,8 @@ namespace Yamlify.SourceGenerator;
2222public sealed class YamlSourceGenerator : IIncrementalGenerator
2323{
2424 private const string YamlSerializableAttribute = "Yamlify.Serialization.YamlSerializableAttribute" ;
25+ private const string YamlSerializableAttributeGeneric = "Yamlify.Serialization.YamlSerializableAttribute<T>" ;
26+ private const string YamlDerivedTypeMappingAttributeGeneric = "Yamlify.Serialization.YamlDerivedTypeMappingAttribute<TBase, TDerived>" ;
2527 private const string YamlSerializerContextBase = "Yamlify.Serialization.YamlSerializerContext" ;
2628
2729 public void Initialize ( IncrementalGeneratorInitializationContext context )
@@ -86,66 +88,137 @@ private static bool IsCandidateClass(SyntaxNode node)
8688 var ignoreEmptyObjects = false ;
8789 var discriminatorPosition = DiscriminatorPositionMode . PropertyOrder ;
8890
91+ // First pass: collect YamlDerivedTypeMapping attributes
92+ // Key: base type display string, Value: list of (discriminator, derivedType)
93+ var derivedTypeMappingsFromAttrs = new Dictionary < string , List < ( string Discriminator , INamedTypeSymbol DerivedType ) > > ( ) ;
94+
95+ foreach ( var attributeData in classSymbol . GetAttributes ( ) )
96+ {
97+ var attrOriginalDef = attributeData . AttributeClass ? . OriginalDefinition ? . ToDisplayString ( ) ;
98+
99+ if ( attrOriginalDef == YamlDerivedTypeMappingAttributeGeneric )
100+ {
101+ // [YamlDerivedTypeMapping<TBase, TDerived>("discriminator")]
102+ if ( attributeData . AttributeClass is { IsGenericType : true , TypeArguments . Length : 2 } attrClass &&
103+ attrClass . TypeArguments [ 0 ] is INamedTypeSymbol mappingBaseType &&
104+ attrClass . TypeArguments [ 1 ] is INamedTypeSymbol mappingDerivedType )
105+ {
106+ var baseTypeKey = mappingBaseType . ToDisplayString ( ) ;
107+
108+ // Get discriminator from constructor argument (optional)
109+ string ? discriminator = null ;
110+ if ( attributeData . ConstructorArguments . Length > 0 &&
111+ attributeData . ConstructorArguments [ 0 ] . Value is string discValue )
112+ {
113+ discriminator = discValue ;
114+ }
115+ discriminator ??= mappingDerivedType . Name ;
116+
117+ if ( ! derivedTypeMappingsFromAttrs . TryGetValue ( baseTypeKey , out var mappings ) )
118+ {
119+ mappings = new List < ( string , INamedTypeSymbol ) > ( ) ;
120+ derivedTypeMappingsFromAttrs [ baseTypeKey ] = mappings ;
121+ }
122+ mappings . Add ( ( discriminator , mappingDerivedType ) ) ;
123+ }
124+ }
125+ }
126+
127+ // Second pass: process YamlSerializable attributes
89128 foreach ( var attributeData in classSymbol . GetAttributes ( ) )
90129 {
91130 var attrName = attributeData . AttributeClass ? . ToDisplayString ( ) ;
131+ var attrOriginalDef = attributeData . AttributeClass ? . OriginalDefinition ? . ToDisplayString ( ) ;
132+
133+ // Support both [YamlSerializable(typeof(T))] and [YamlSerializable<T>]
134+ INamedTypeSymbol ? typeArg = null ;
135+ var isYamlSerializableAttribute = false ;
136+
92137 if ( attrName == YamlSerializableAttribute )
93138 {
139+ // Non-generic: [YamlSerializable(typeof(T))]
94140 if ( attributeData . ConstructorArguments . Length > 0 &&
95- attributeData . ConstructorArguments [ 0 ] . Value is INamedTypeSymbol typeArg )
141+ attributeData . ConstructorArguments [ 0 ] . Value is INamedTypeSymbol ctorArg )
96142 {
97- // Check for per-type PropertyOrdering override
98- PropertyOrderingMode ? typeOrdering = null ;
99- string ? typeDiscriminatorPropertyName = null ;
100- List < INamedTypeSymbol > ? derivedTypes = null ;
101- List < string > ? derivedTypeDiscriminators = null ;
102-
103- foreach ( var namedArg in attributeData . NamedArguments )
143+ typeArg = ctorArg ;
144+ isYamlSerializableAttribute = true ;
145+ }
146+ }
147+ else if ( attrOriginalDef == YamlSerializableAttributeGeneric )
148+ {
149+ // Generic: [YamlSerializable<T>]
150+ if ( attributeData . AttributeClass is { IsGenericType : true , TypeArguments . Length : > 0 } attrClass &&
151+ attrClass . TypeArguments [ 0 ] is INamedTypeSymbol genericArg )
152+ {
153+ typeArg = genericArg ;
154+ isYamlSerializableAttribute = true ;
155+ }
156+ }
157+
158+ if ( isYamlSerializableAttribute && typeArg is not null )
159+ {
160+ // Check for per-type PropertyOrdering override
161+ PropertyOrderingMode ? typeOrdering = null ;
162+ string ? typeDiscriminatorPropertyName = null ;
163+ List < INamedTypeSymbol > ? derivedTypes = null ;
164+ List < string > ? derivedTypeDiscriminators = null ;
165+
166+ foreach ( var namedArg in attributeData . NamedArguments )
167+ {
168+ if ( namedArg . Key == "PropertyOrdering" && namedArg . Value . Value is int orderingValue && orderingValue >= 0 )
104169 {
105- if ( namedArg . Key == "PropertyOrdering" && namedArg . Value . Value is int orderingValue && orderingValue >= 0 )
106- {
107- // Only set if not Inherit (-1)
108- typeOrdering = ( PropertyOrderingMode ) orderingValue ;
109- }
110- else if ( namedArg . Key == "TypeDiscriminatorPropertyName" && namedArg . Value . Value is string discPropName )
111- {
112- typeDiscriminatorPropertyName = discPropName ;
113- }
114- else if ( namedArg . Key == "DerivedTypes" && ! namedArg . Value . IsNull )
115- {
116- derivedTypes = namedArg . Value . Values
117- . Where ( v => v . Value is INamedTypeSymbol )
118- . Select ( v => ( INamedTypeSymbol ) v . Value ! )
119- . ToList ( ) ;
120- }
121- else if ( namedArg . Key == "DerivedTypeDiscriminators" && ! namedArg . Value . IsNull )
122- {
123- derivedTypeDiscriminators = namedArg . Value . Values
124- . Where ( v => v . Value is string )
125- . Select ( v => ( string ) v . Value ! )
126- . ToList ( ) ;
127- }
170+ // Only set if not Inherit (-1)
171+ typeOrdering = ( PropertyOrderingMode ) orderingValue ;
128172 }
129-
130- // Build PolymorphicInfo if polymorphic configuration is specified
131- PolymorphicInfo ? polymorphicConfig = null ;
132- if ( typeDiscriminatorPropertyName is not null && derivedTypes is not null && derivedTypes . Count > 0 )
173+ else if ( namedArg . Key == "TypeDiscriminatorPropertyName" && namedArg . Value . Value is string discPropName )
133174 {
134- var derivedTypeMappings = new List < ( string Discriminator , INamedTypeSymbol DerivedType ) > ( ) ;
135- for ( int i = 0 ; i < derivedTypes . Count ; i ++ )
136- {
137- var derivedType = derivedTypes [ i ] ;
138- // Use explicit discriminator if provided, otherwise use type name
139- var discriminator = ( derivedTypeDiscriminators is not null && i < derivedTypeDiscriminators . Count )
140- ? derivedTypeDiscriminators [ i ]
141- : derivedType . Name ;
142- derivedTypeMappings . Add ( ( discriminator , derivedType ) ) ;
143- }
144- polymorphicConfig = new PolymorphicInfo ( typeDiscriminatorPropertyName , derivedTypeMappings ) ;
175+ typeDiscriminatorPropertyName = discPropName ;
145176 }
146-
147- typesToGenerate . Add ( new TypeToGenerate ( typeArg , typeOrdering , polymorphicConfig ) ) ;
177+ else if ( namedArg . Key == "DerivedTypes" && ! namedArg . Value . IsNull )
178+ {
179+ derivedTypes = namedArg . Value . Values
180+ . Where ( v => v . Value is INamedTypeSymbol )
181+ . Select ( v => ( INamedTypeSymbol ) v . Value ! )
182+ . ToList ( ) ;
183+ }
184+ else if ( namedArg . Key == "DerivedTypeDiscriminators" && ! namedArg . Value . IsNull )
185+ {
186+ derivedTypeDiscriminators = namedArg . Value . Values
187+ . Where ( v => v . Value is string )
188+ . Select ( v => ( string ) v . Value ! )
189+ . ToList ( ) ;
190+ }
191+ }
192+
193+ // Build PolymorphicInfo if polymorphic configuration is specified
194+ PolymorphicInfo ? polymorphicConfig = null ;
195+ var typeKey = typeArg . ToDisplayString ( ) ;
196+
197+ // Check for derived type mappings from YamlDerivedTypeMappingAttribute
198+ if ( typeDiscriminatorPropertyName is not null &&
199+ derivedTypeMappingsFromAttrs . TryGetValue ( typeKey , out var mappingsFromAttr ) &&
200+ mappingsFromAttr . Count > 0 )
201+ {
202+ // Use mappings from YamlDerivedTypeMappingAttribute
203+ polymorphicConfig = new PolymorphicInfo ( typeDiscriminatorPropertyName , mappingsFromAttr ) ;
148204 }
205+ else if ( typeDiscriminatorPropertyName is not null && derivedTypes is not null && derivedTypes . Count > 0 )
206+ {
207+ // Use inline DerivedTypes/DerivedTypeDiscriminators arrays
208+ var derivedTypeMappings = new List < ( string Discriminator , INamedTypeSymbol DerivedType ) > ( ) ;
209+ for ( int i = 0 ; i < derivedTypes . Count ; i ++ )
210+ {
211+ var derivedType = derivedTypes [ i ] ;
212+ // Use explicit discriminator if provided, otherwise use type name
213+ var discriminator = ( derivedTypeDiscriminators is not null && i < derivedTypeDiscriminators . Count )
214+ ? derivedTypeDiscriminators [ i ]
215+ : derivedType . Name ;
216+ derivedTypeMappings . Add ( ( discriminator , derivedType ) ) ;
217+ }
218+ polymorphicConfig = new PolymorphicInfo ( typeDiscriminatorPropertyName , derivedTypeMappings ) ;
219+ }
220+
221+ typesToGenerate . Add ( new TypeToGenerate ( typeArg , typeOrdering , polymorphicConfig ) ) ;
149222 }
150223 else if ( attrName == "Yamlify.Serialization.YamlSourceGenerationOptionsAttribute" )
151224 {
0 commit comments