22// Licensed under the MIT License.
33
44using Microsoft . CodeAnalysis ;
5+ using System . Reflection ;
56
67namespace Files . Core . SourceGenerator . Generators
78{
@@ -12,26 +13,60 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
1213 {
1314 var sources = context . SyntaxProvider . ForAttributeWithMetadataName (
1415 "Files.Shared.Attributes.GeneratedVTableFunctionAttribute" ,
15- static ( node , token ) => true ,
16- static ( context , token ) => context )
17- . Collect ( ) ;
16+ static ( node , token ) =>
17+ {
18+ token . ThrowIfCancellationRequested ( ) ;
19+
20+ // Check if the method has partial modifier and is public or internal (and not static)
21+ if ( node is not MethodDeclarationSyntax { AttributeLists . Count : > 0 } method ||
22+ ! method . Modifiers . Any ( SyntaxKind . PartialKeyword ) ||
23+ ! ( method . Modifiers . Any ( SyntaxKind . PublicKeyword ) || method . Modifiers . Any ( SyntaxKind . InternalKeyword ) ) ||
24+ method . Modifiers . Any ( SyntaxKind . StaticKeyword ) )
25+ return false ;
26+
27+ // Check if the type containing the method has partial modifier and is a struct
28+ if ( node . Parent is not TypeDeclarationSyntax { Keyword . RawKind : ( int ) SyntaxKind . StructKeyword , Modifiers : { } modifiers } ||
29+ ! modifiers . Any ( SyntaxKind . PartialKeyword ) )
30+ return false ;
31+
32+ return true ;
33+ } ,
34+ static ( context , token ) =>
35+ {
36+ token . ThrowIfCancellationRequested ( ) ;
37+
38+ var fullyQualifiedParentTypeName = context . TargetSymbol . ContainingType . ToString ( ) ;
39+ var structNamespace = context . TargetSymbol . ContainingType . ContainingNamespace . ToString ( ) ;
40+ var structName = context . TargetSymbol . ContainingType . Name ;
41+ var methodSymbol = ( IMethodSymbol ) context . TargetSymbol ;
42+ var functionName = methodSymbol . Name ;
43+ var returnTypeName = methodSymbol . ReturnType . ToDisplayString ( SymbolDisplayFormat . FullyQualifiedFormat ) ;
44+ var parameters = methodSymbol . Parameters . CastArray < ISymbol > ( ) ;
45+ var index = ( int ) context . Attributes [ 0 ] . NamedArguments . FirstOrDefault ( x => x . Key . Equals ( "Index" ) ) . Value . Value ! ;
46+
47+ return new VTableFunctionInfo ( fullyQualifiedParentTypeName , structNamespace , structName , functionName , returnTypeName , index , new ( parameters ) ) ;
48+ } )
49+ . Where ( static item => item is not null )
50+ . Collect ( )
51+ . Select ( ( items , token ) =>
52+ {
53+ token . ThrowIfCancellationRequested ( ) ;
54+
55+ return items . GroupBy ( source => source . FullyQualifiedParentTypeName , StringComparer . OrdinalIgnoreCase ) ;
56+ } ) ;
57+
1858
1959 context . RegisterSourceOutput ( sources , ( context , sources ) =>
2060 {
21- var vtableFunctionsGroupedByStructs = sources . GroupBy ( source => source . TargetSymbol . ContainingType , SymbolEqualityComparer . Default ) ;
22-
23- foreach ( var vtableFunctions in vtableFunctionsGroupedByStructs )
61+ foreach ( var source in sources )
2462 {
25- if ( vtableFunctions . Key is not INamedTypeSymbol structSymbol || structSymbol . Name is not { } structName )
26- continue ;
27-
28- string vtableFunctionsCode = GenerateVtableFunctionsForStruct ( structSymbol , vtableFunctions ) ;
29- context . AddSource ( $ "{ structName } _VTableFunctions.g.cs", vtableFunctionsCode ) ;
63+ string vtableFunctionsCode = GenerateVtableFunctionsForStruct ( source ) ;
64+ context . AddSource ( $ "{ source . Key } _VTableFunctions.g.cs", vtableFunctionsCode ) ;
3065 }
3166 } ) ;
3267 }
3368
34- private string GenerateVtableFunctionsForStruct ( INamedTypeSymbol structSymbol , IEnumerable < GeneratorAttributeSyntaxContext > sources )
69+ private string GenerateVtableFunctionsForStruct ( IEnumerable < VTableFunctionInfo > sources )
3570 {
3671 StringBuilder builder = new ( ) ;
3772
@@ -42,13 +77,10 @@ private string GenerateVtableFunctionsForStruct(INamedTypeSymbol structSymbol, I
4277 builder . AppendLine ( $ "#pragma warning disable") ;
4378 builder . AppendLine ( ) ;
4479
45- if ( structSymbol . ContainingNamespace is { IsGlobalNamespace : false } )
46- {
47- builder . AppendLine ( $ "namespace { structSymbol . ContainingNamespace } ;") ;
48- builder . AppendLine ( ) ;
49- }
80+ builder . AppendLine ( $ "namespace { sources . ElementAt ( 0 ) . ParentTypeNamespace } ;") ;
81+ builder . AppendLine ( ) ;
5082
51- builder . AppendLine ( $ "public unsafe partial struct { structSymbol . Name } ") ;
83+ builder . AppendLine ( $ "public unsafe partial struct { sources . ElementAt ( 0 ) . ParentTypeName } ") ;
5284 builder . AppendLine ( $ "{{") ;
5385
5486 builder . AppendLine ( $ " private void** lpVtbl;") ;
@@ -59,15 +91,14 @@ private string GenerateVtableFunctionsForStruct(INamedTypeSymbol structSymbol, I
5991
6092 foreach ( var source in sources )
6193 {
62- var vtblIndex = source . Attributes [ 0 ] . NamedArguments . Where ( x => x . Key . Equals ( "Index" ) ) . FirstOrDefault ( ) . Value ;
63- var info = GetVTableFunctionInfo ( ( IMethodSymbol ) source . TargetSymbol ) ;
94+ var parameters = source . Parameters . Cast < IParameterSymbol > ( ) . ToDictionary ( x => x . Type . ToDisplayString ( SymbolDisplayFormat . FullyQualifiedFormat ) , x => x . Name ) ;
6495
6596 builder . AppendLine ( $ " [global::System.Runtime.CompilerServices.MethodImpl(global::System.Runtime.CompilerServices.MethodImplOptions.AggressiveInlining)]") ;
6697
67- builder . AppendLine ( $ " public partial { info . ReturnType } { info . Name } ({ string . Join ( ", " , info . Parameters . Select ( x => $ "{ x . Key } { x . Value } ") ) } )") ;
98+ builder . AppendLine ( $ " public partial { source . ReturnTypeName } { source . Name } ({ string . Join ( ", " , parameters . Select ( x => $ "{ x . Key } { x . Value } ") ) } )") ;
6899 builder . AppendLine ( $ " {{") ;
69- builder . AppendLine ( $ " return ({ info . ReturnType } )((delegate* unmanaged[MemberFunction]<{ structSymbol . Name } *, { string . Join ( ", " , info . Parameters . Select ( x => $ "{ x . Key } ") ) } , int>)(lpVtbl[{ vtblIndex . Value } ]))") ;
70- builder . AppendLine ( $ " (({ structSymbol . Name } *)global::System.Runtime.CompilerServices.Unsafe.AsPointer(ref this), { string . Join ( ", " , info . Parameters . Select ( x => $ "{ x . Value } ") ) } );") ;
100+ builder . AppendLine ( $ " return ({ source . ReturnTypeName } )((delegate* unmanaged[MemberFunction]<{ sources . ElementAt ( 0 ) . ParentTypeName } *, { string . Join ( ", " , parameters . Select ( x => $ "{ x . Key } ") ) } , int>)(lpVtbl[{ source . Index } ]))") ;
101+ builder . AppendLine ( $ " (({ sources . ElementAt ( 0 ) . ParentTypeName } *)global::System.Runtime.CompilerServices.Unsafe.AsPointer(ref this), { string . Join ( ", " , parameters . Select ( x => $ "{ x . Value } ") ) } );") ;
71102 builder . AppendLine ( $ " }}") ;
72103
73104 if ( sourceIndex < sourceCount - 1 )
@@ -80,27 +111,5 @@ private string GenerateVtableFunctionsForStruct(INamedTypeSymbol structSymbol, I
80111
81112 return builder . ToString ( ) ;
82113 }
83-
84- private VTableFunctionInfo GetVTableFunctionInfo ( IMethodSymbol methodSymbol )
85- {
86- string functionName = methodSymbol . Name ;
87- string returnType = methodSymbol . ReturnType . ToDisplayString ( SymbolDisplayFormat . FullyQualifiedFormat ) ;
88-
89- Dictionary < string , string > parameters = [ ] ;
90- foreach ( var param in methodSymbol . Parameters )
91- {
92- var name = param . Name ;
93- var type = param . Type ;
94-
95- parameters . Add ( type . ToDisplayString ( SymbolDisplayFormat . FullyQualifiedFormat ) , name ) ;
96- }
97-
98- return new VTableFunctionInfo ( )
99- {
100- Name = functionName ,
101- ReturnType = returnType ,
102- Parameters = parameters ,
103- } ;
104- }
105114 }
106115}
0 commit comments