|
| 1 | +// Copyright (c) Files Community |
| 2 | +// Licensed under the MIT License. |
| 3 | + |
| 4 | +using Microsoft.CodeAnalysis; |
| 5 | + |
| 6 | +namespace Files.Core.SourceGenerator.Generators |
| 7 | +{ |
| 8 | + [Generator(LanguageNames.CSharp)] |
| 9 | + internal class VTableFunctionGenerator : IIncrementalGenerator |
| 10 | + { |
| 11 | + public void Initialize(IncrementalGeneratorInitializationContext context) |
| 12 | + { |
| 13 | + var sources = context.SyntaxProvider.ForAttributeWithMetadataName( |
| 14 | + "Files.Shared.Attributes.GeneratedVTableFunctionAttribute", |
| 15 | + static (node, token) => true, |
| 16 | + static (context, token) => context) |
| 17 | + .Collect(); |
| 18 | + |
| 19 | + context.RegisterSourceOutput(sources, (context, sources) => |
| 20 | + { |
| 21 | + var vtableFunctionsGroupedByStructs = sources.GroupBy(source => source.TargetSymbol.ContainingType, SymbolEqualityComparer.Default); |
| 22 | + |
| 23 | + foreach (var vtableFunctions in vtableFunctionsGroupedByStructs) |
| 24 | + { |
| 25 | + if (vtableFunctions.Key is not INamedTypeSymbol structSymbol || structSymbol.Name is not { } structName) |
| 26 | + continue; |
| 27 | + |
| 28 | + string vtableFunctionsCode = GenerateVtableFunctionsForStruct(structSymbol, vtableFunctions); |
| 29 | + context.AddSource($"{structName}_VTableFunctions.g.cs", vtableFunctionsCode); |
| 30 | + } |
| 31 | + }); |
| 32 | + } |
| 33 | + |
| 34 | + private string GenerateVtableFunctionsForStruct(INamedTypeSymbol structSymbol, IEnumerable<GeneratorAttributeSyntaxContext> sources) |
| 35 | + { |
| 36 | + StringBuilder builder = new(); |
| 37 | + |
| 38 | + builder.AppendLine($"// <auto-generated/>"); |
| 39 | + builder.AppendLine(); |
| 40 | + builder.AppendLine($"using global::System.Runtime.CompilerServices;"); |
| 41 | + builder.AppendLine(); |
| 42 | + builder.AppendLine($"#pragma warning disable"); |
| 43 | + builder.AppendLine(); |
| 44 | + |
| 45 | + if (structSymbol.ContainingNamespace is { IsGlobalNamespace: false }) |
| 46 | + { |
| 47 | + builder.AppendLine($"namespace {structSymbol.ContainingNamespace};"); |
| 48 | + builder.AppendLine(); |
| 49 | + } |
| 50 | + |
| 51 | + builder.AppendLine($"public unsafe partial struct {structSymbol.Name}"); |
| 52 | + builder.AppendLine($"{{"); |
| 53 | + |
| 54 | + builder.AppendLine($" private void** lpVtbl;"); |
| 55 | + builder.AppendLine(); |
| 56 | + |
| 57 | + var sourceIndex = 0; |
| 58 | + var sourceCount = sources.Count(); |
| 59 | + |
| 60 | + foreach (var source in sources) |
| 61 | + { |
| 62 | + var vtblIndex = source.Attributes[0].NamedArguments.Where(x => x.Key.Equals("Index")).FirstOrDefault().Value; |
| 63 | + var info = GetVTableFunctionInfo((IMethodSymbol)source.TargetSymbol); |
| 64 | + |
| 65 | + builder.AppendLine($" [global::System.Runtime.CompilerServices.MethodImpl(global::System.Runtime.CompilerServices.MethodImplOptions.AggressiveInlining)]"); |
| 66 | + |
| 67 | + builder.AppendLine($" public partial {info.ReturnType} {info.Name}({string.Join(", ", info.Parameters.Select(x => $"{x.Key} {x.Value}"))})"); |
| 68 | + builder.AppendLine($" {{"); |
| 69 | + builder.AppendLine($" return ({info.ReturnType})((delegate* unmanaged[MemberFunction]<{structSymbol.Name}*, {string.Join(", ", info.Parameters.Select(x => $"{x.Key}"))}, int>)(lpVtbl[{vtblIndex.Value}]))"); |
| 70 | + builder.AppendLine($" (({structSymbol.Name}*)global::System.Runtime.CompilerServices.Unsafe.AsPointer(ref this), {string.Join(", ", info.Parameters.Select(x => $"{x.Value}"))});"); |
| 71 | + builder.AppendLine($" }}"); |
| 72 | + |
| 73 | + if (sourceIndex < sourceCount - 1) |
| 74 | + builder.AppendLine(); |
| 75 | + |
| 76 | + sourceIndex++; |
| 77 | + } |
| 78 | + |
| 79 | + builder.AppendLine($"}}"); |
| 80 | + |
| 81 | + return builder.ToString(); |
| 82 | + } |
| 83 | + |
| 84 | + private VTableFunctionInfo GetVTableFunctionInfo(IMethodSymbol methodSymbol) |
| 85 | + { |
| 86 | + string functionName = methodSymbol.Name; |
| 87 | + string returnType = methodSymbol.ReturnType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); |
| 88 | + |
| 89 | + Dictionary<string, string> parameters = []; |
| 90 | + foreach (var param in methodSymbol.Parameters) |
| 91 | + { |
| 92 | + var name = param.Name; |
| 93 | + var type = param.Type; |
| 94 | + |
| 95 | + parameters.Add(type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), name); |
| 96 | + } |
| 97 | + |
| 98 | + return new VTableFunctionInfo() |
| 99 | + { |
| 100 | + Name = functionName, |
| 101 | + ReturnType = returnType, |
| 102 | + Parameters = parameters, |
| 103 | + }; |
| 104 | + } |
| 105 | + } |
| 106 | +} |
0 commit comments