44using Microsoft . CodeAnalysis . CSharp . Syntax ;
55using static System . StringComparison ;
66using static Constants ;
7+ using static Microsoft . CodeAnalysis . Accessibility ;
78
89[ Generator ]
910public sealed class UnionStructSourceGenerator : IIncrementalGenerator
@@ -30,7 +31,6 @@ record struct DerivedInput
3031 public string Name { get ; init ; }
3132 public string FQN { get ; init ; }
3233 public EquatableArray < string > Declaration { get ; init ; }
33- public string HeaderFieldName { get ; init ; }
3434 public byte AssignedId { get ; init ; }
3535 public EquatableArray < string > PubliclyImplementedMembers { get ; init ; }
3636 public EquatableArray < string > MemberNames { get ; init ; }
@@ -53,14 +53,13 @@ record struct Derived : IGeneratorTransformOutput
5353 public string DerivedFQN { get ; init ; }
5454 public string DerivedName { get ; init ; }
5555 public EquatableArray < string > Declaration { get ; init ; }
56- public string HeaderFieldName { get ; init ; }
57- public EquatableArray < string > HeaderChainFQNs { get ; init ; }
5856 public byte ? ForcedId { get ; init ; }
5957
58+ public EquatableIgnore < Func < string , bool > ? > HasHeaderInChainFunc { get ; init ; }
6059 public EquatableIgnore < Func < string , bool > ? > ImplementsUnionInterfaceFunc { get ; init ; }
6160 public EquatableIgnore < Func < DerivedDeferredInput > ? > DeferredInputBuilderFunc { get ; init ; }
6261
63- // ReSharper disable once NotAccessedField .Local
62+ // ReSharper disable once UnusedAutoPropertyAccessor .Local
6463 public EquatableArray < byte > DerivedTextCheckSumForCache { get ; init ; }
6564 }
6665
@@ -74,11 +73,9 @@ record struct GeneratorInput : IGeneratorTransformOutput
7473 public EquatableArray < string > BaseDeclaration { get ; init ; }
7574 public string BaseTypeName { get ; init ; }
7675 public string BaseTypeFQN { get ; init ; }
77- public string InterfaceName { get ; init ; }
7876 public string InterfaceFQN { get ; init ; }
7977 public string TypeIDEnumFQN { get ; init ; }
8078 public string TypeIDFieldName { get ; init ; }
81- public string RootTypeName { get ; init ; }
8279 public string RootTypeFQN { get ; init ; }
8380 public string RootInterfaceFQN { get ; init ; }
8481 public string RootTypeIDEnumFQN { get ; init ; }
@@ -92,7 +89,7 @@ record struct GeneratorInput : IGeneratorTransformOutput
9289 public EquatableIgnore < Func < HeaderFieldInput [ ] > ? > HeaderFieldsBuilderFunc { get ; init ; }
9390 public EquatableIgnore < Func < DerivedInput [ ] > ? > DerivedStructsBuilderFunc { get ; init ; }
9491
95- // ReSharper disable once NotAccessedField .Local
92+ // ReSharper disable once UnusedAutoPropertyAccessor .Local
9693 public EquatableArray < byte > BaseTextCheckSumForCache { get ; init ; }
9794 }
9895
@@ -132,18 +129,22 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
132129 static GeneratorInput TransformBase ( GeneratorAttributeSyntaxContext context , CancellationToken ct )
133130 {
134131 if ( context is not { TargetSymbol : ITypeSymbol symbol , TargetNode : StructDeclarationSyntax structDecl } )
132+ {
135133 return new ( )
136134 {
137135 SourceGeneratorError = "Unexpected target shape for [UnionHeader]." ,
138136 SourceGeneratorErrorLocation = new LocationInfo ( context . TargetNode . GetLocation ( ) ) ,
139137 } ;
138+ }
140139
141140 if ( symbol is not INamedTypeSymbol baseSymbol )
141+ {
142142 return new ( )
143143 {
144144 SourceGeneratorError = "Unexpected target symbol for [UnionHeader]." ,
145145 SourceGeneratorErrorLocation = new LocationInfo ( context . TargetNode . GetLocation ( ) ) ,
146146 } ;
147+ }
147148
148149 var symbolMembers = symbol . GetMembers ( ) . AsArray ( ) ;
149150 var symbolTypeMembers = symbol . GetTypeMembers ( ) . AsArray ( ) ;
@@ -152,10 +153,12 @@ static GeneratorInput TransformBase(GeneratorAttributeSyntaxContext context, Can
152153 ?? symbolTypeMembers . FirstOrDefault ( x => x . TypeKind is TypeKind . Interface ) ;
153154
154155 if ( interfaceSymbol is null )
156+ {
155157 return new ( )
156158 {
157159 SourceGeneratorOutputFilename = Utility . GetOutputFilename ( structDecl . SyntaxTree . FilePath , symbol . Name , "Union" ) ,
158160 } ;
161+ }
159162
160163 var typeIDEnumSymbol = symbolTypeMembers . FirstOrDefault ( x => x . Name is "TypeIDs" ) ;
161164 var typeIDField = symbolMembers . FirstOrDefault ( x => x is IFieldSymbol { Name : "TypeID" , Type . Name : "TypeIDs" } ) ;
@@ -188,13 +191,11 @@ parentInterface is not null &&
188191 BaseDeclaration = Utility . DeconstructTypeDeclaration ( structDecl , context . SemanticModel , ct ) ,
189192 BaseTypeName = symbol . Name ,
190193 BaseTypeFQN = symbol . FQN ,
191- InterfaceName = interfaceSymbol . Name ,
192194 InterfaceFQN = interfaceSymbol . FQN ,
193195 TypeIDEnumFQN = isRootTypeIdOwner
194196 ? typeIDEnumSymbol ? . FQN ?? $ "{ symbol . FQN } .TypeIDs"
195197 : rootTypeIDEnumSymbol ? . FQN ?? $ "{ rootHeader . FQN } .TypeIDs",
196198 TypeIDFieldName = typeIDField ? . Name ?? "TypeID" ,
197- RootTypeName = rootHeader . Name ,
198199 RootTypeFQN = rootHeader . FQN ,
199200 RootInterfaceFQN = rootInterface . FQN ,
200201 RootTypeIDEnumFQN = rootTypeIDEnumSymbol ? . FQN ?? $ "{ rootHeader . FQN } .TypeIDs",
@@ -257,24 +258,18 @@ static InterfaceMemberInput[] BuildInterfaceMembers(INamedTypeSymbol interfaceSy
257258
258259 static HeaderFieldInput [ ] BuildHeaderFields ( INamedTypeSymbol headerSymbol )
259260 {
260- static bool IsAccessible ( Accessibility accessibility )
261- => accessibility is
262- Accessibility . Public
263- or Accessibility . Internal
264- or Accessibility . ProtectedOrInternal ;
265-
266261 var fieldsAndProperties = headerSymbol . GetMembers ( )
267262 . Where ( x => x is IFieldSymbol or IPropertySymbol )
268263 . OrderBy ( x => x . Locations . FirstOrDefault ( ) ? . SourceSpan . Start ?? int . MaxValue ) ;
269264
270- var result = new List < HeaderFieldInput > ( ) ;
265+ var result = new List < HeaderFieldInput > ( capacity : 8 ) ;
271266 foreach ( var member in fieldsAndProperties )
272267 {
273268 switch ( member )
274269 {
275270 case IFieldSymbol { IsStatic : false , IsImplicitlyDeclared : false } field :
276271 {
277- if ( ! IsAccessible ( field . DeclaredAccessibility ) )
272+ if ( ! field . IsAccessible )
278273 break ;
279274
280275 result . Add (
@@ -290,8 +285,9 @@ or Accessibility.Internal
290285 }
291286 case IPropertySymbol { IsStatic : false , IsIndexer : false } property :
292287 {
293- bool canGet = property . GetMethod is { } getter && IsAccessible ( getter . DeclaredAccessibility ) ;
294- bool canSet = property . SetMethod is { } setter && IsAccessible ( setter . DeclaredAccessibility ) ;
288+ bool canGet = property . GetMethod is { IsAccessible : true } ;
289+ bool canSet = property . SetMethod is { IsAccessible : true } ;
290+
295291 if ( ! canGet && ! canSet )
296292 break ;
297293
@@ -317,7 +313,7 @@ static Derived TransformDerivedCandidate(GeneratorAttributeSyntaxContext context
317313 if ( context . TargetNode is not StructDeclarationSyntax structDecl )
318314 return default ;
319315
320- if ( context . SemanticModel . GetDeclaredSymbol ( structDecl , ct ) is not INamedTypeSymbol symbol )
316+ if ( context . SemanticModel . GetDeclaredSymbol ( structDecl , ct ) is not { } symbol )
321317 return default ;
322318
323319 byte ? forcedId = context . Attributes
@@ -333,9 +329,8 @@ static Derived TransformDerivedCandidate(GeneratorAttributeSyntaxContext context
333329 DerivedFQN = symbol . FQN ,
334330 DerivedName = symbol . Name ,
335331 Declaration = Utility . DeconstructTypeDeclaration ( structDecl , context . SemanticModel , ct ) ,
336- HeaderFieldName = GetFirstHeaderField ( symbol ) ? . Name ?? "Header" ,
337- HeaderChainFQNs = BuildHeaderChainFQNs ( symbol ) ,
338332 ForcedId = forcedId ,
333+ HasHeaderInChainFunc = new ( headerFQN => HasHeaderInChain ( symbol , headerFQN ) ) ,
339334 ImplementsUnionInterfaceFunc = new ( interfaceFQN => symbol . AllInterfaces . Any ( x => x . FQN == interfaceFQN ) ) ,
340335 DeferredInputBuilderFunc = new ( ( ) => BuildDerivedDeferredInput ( symbol ) ) ,
341336 DerivedTextCheckSumForCache = structDecl . GetText ( ) . GetChecksum ( ) . AsArray ( ) ,
@@ -345,7 +340,7 @@ static Derived TransformDerivedCandidate(GeneratorAttributeSyntaxContext context
345340 static DerivedDeferredInput BuildDerivedDeferredInput ( INamedTypeSymbol symbol )
346341 {
347342 var publicMembers = symbol . GetMembers ( )
348- . Where ( x => x is { DeclaredAccessibility : Accessibility . Public } and not IMethodSymbol { MethodKind : not MethodKind . Ordinary } )
343+ . Where ( x => x is { DeclaredAccessibility : Public } and not IMethodSymbol { MethodKind : not MethodKind . Ordinary } )
349344 . Select ( x => x . Name )
350345 . Distinct ( )
351346 . ToArray ( ) ;
@@ -404,7 +399,7 @@ static DerivedInput[] BuildDerivedStructs(Derived[] candidates, string interface
404399 if ( candidate . ImplementsUnionInterfaceFunc . Value ? . Invoke ( rootInterfaceFQN ) is not true )
405400 continue ;
406401
407- if ( ! HasHeaderInChain ( candidate , rootHeaderFQN ) )
402+ if ( candidate . HasHeaderInChainFunc . Value ? . Invoke ( rootHeaderFQN ) is not true )
408403 continue ;
409404
410405 if ( candidate . ForcedId is { } forcedId )
@@ -453,7 +448,7 @@ static DerivedInput[] BuildDerivedStructs(Derived[] candidates, string interface
453448 if ( candidate . ImplementsUnionInterfaceFunc . Value ? . Invoke ( interfaceFQN ) is not true )
454449 continue ;
455450
456- if ( ! HasHeaderInChain ( candidate , headerFQN ) )
451+ if ( candidate . HasHeaderInChainFunc . Value ? . Invoke ( headerFQN ) is not true )
457452 continue ;
458453
459454 var deferredInput = candidate . DeferredInputBuilderFunc . Value ? . Invoke ( ) ?? default ;
@@ -464,7 +459,6 @@ static DerivedInput[] BuildDerivedStructs(Derived[] candidates, string interface
464459 Name = candidate . DerivedName ,
465460 FQN = candidate . DerivedFQN ,
466461 Declaration = candidate . Declaration ,
467- HeaderFieldName = candidate . HeaderFieldName ,
468462 AssignedId = assigned . AssignedId ,
469463 PubliclyImplementedMembers = deferredInput . PublicMembers ,
470464 MemberNames = deferredInput . MemberNames ,
@@ -489,36 +483,34 @@ static DerivedInput[] BuildDerivedStructs(Derived[] candidates, string interface
489483 return result ;
490484 }
491485
492- static bool HasHeaderInChain ( Derived candidate , string headerFQN )
493- => candidate . HeaderChainFQNs . AsArray ( ) . Any ( x => x . Equals ( headerFQN , Ordinal ) ) ;
494-
495- static byte GetNextAvailableId ( HashSet < byte > usedIds , ref byte nextId )
496- {
497- while ( usedIds . Contains ( nextId ) )
498- nextId ++ ;
499-
500- return nextId ++ ;
501- }
502-
503- static EquatableArray < string > BuildHeaderChainFQNs ( INamedTypeSymbol symbol )
486+ static bool HasHeaderInChain ( INamedTypeSymbol symbol , string headerFQN )
504487 {
505488 var firstHeaderFieldType = GetFirstHeaderFieldType ( symbol ) ;
506489 if ( firstHeaderFieldType is null )
507- return [ ] ;
490+ return false ;
508491
509- var chain = new List < string > ( ) ;
510492 var visited = new HashSet < string > ( StringComparer . Ordinal ) ;
511493 var current = firstHeaderFieldType ;
512494 while ( current is not null && current . HasAttribute ( UnionHeaderStructAttributeFQN ) )
513495 {
514496 if ( ! visited . Add ( current . FQN ) )
515497 break ;
516498
517- chain . Add ( current . FQN ) ;
499+ if ( current . FQN . Equals ( headerFQN , Ordinal ) )
500+ return true ;
501+
518502 current = GetFirstHeaderFieldType ( current ) ;
519503 }
520504
521- return chain . ToArray ( ) ;
505+ return false ;
506+ }
507+
508+ static byte GetNextAvailableId ( HashSet < byte > usedIds , ref byte nextId )
509+ {
510+ while ( usedIds . Contains ( nextId ) )
511+ nextId ++ ;
512+
513+ return nextId ++ ;
522514 }
523515
524516 static IFieldSymbol ? GetFirstHeaderField ( INamedTypeSymbol symbol )
@@ -545,10 +537,10 @@ x is
545537 {
546538 Name : "Interface" ,
547539 TypeKind : TypeKind . Interface ,
548- DeclaredAccessibility : Accessibility . Public ,
540+ DeclaredAccessibility : Public ,
549541 }
550542 ) ??
551- typeMembers . FirstOrDefault ( x => x is { TypeKind : TypeKind . Interface , DeclaredAccessibility : Accessibility . Public } ) ;
543+ typeMembers . FirstOrDefault ( x => x is { TypeKind : TypeKind . Interface , DeclaredAccessibility : Public } ) ;
552544 }
553545
554546 static INamedTypeSymbol GetRootHeader ( INamedTypeSymbol headerSymbol , INamedTypeSymbol headerInterface )
@@ -611,10 +603,10 @@ static void GenerateSource(SourceProductionContext context, SourceWriter src, Ge
611603 using ( src . Braces )
612604 {
613605 if ( headerField . CanGet )
614- src . Line . Write ( $ "get => { derived . HeaderFieldName } .{ headerField . Name } ;") ;
606+ src . Line . Write ( $ "get => { m } UnsafeUtility.As< { derived . FQN } , { input . BaseTypeFQN } >(ref this) .{ headerField . Name } ;") ;
615607
616608 if ( headerField . CanSet )
617- src . Line . Write ( $ "set => { derived . HeaderFieldName } .{ headerField . Name } = value;") ;
609+ src . Line . Write ( $ "set => { m } UnsafeUtility.As< { derived . FQN } , { input . BaseTypeFQN } >(ref this) .{ headerField . Name } = value;") ;
618610 }
619611
620612 src . Linebreak ( ) ;
0 commit comments