diff --git a/Stellar.SourceGenerators/ServiceRegistrationAttribute.cs b/Stellar.SourceGenerators/ServiceRegistrationAttribute.cs index 0540f44..bd9ed94 100644 --- a/Stellar.SourceGenerators/ServiceRegistrationAttribute.cs +++ b/Stellar.SourceGenerators/ServiceRegistrationAttribute.cs @@ -8,21 +8,21 @@ public class ServiceRegistrationAttribute : Attribute public Lifetime ServiceRegistrationType { get; set; } public bool RegisterInterfaces { get; set; } + + public Type? ServiceType { get; set; } + + public string? Key { get; set; } - public ServiceRegistrationAttribute() - { - ServiceRegistrationType = Lifetime.Transient; - } - - public ServiceRegistrationAttribute(Lifetime serviceRegistrationType) - { - ServiceRegistrationType = serviceRegistrationType; - } - - public ServiceRegistrationAttribute(Lifetime serviceRegistrationType = Lifetime.Transient, bool registerInterfaces = false) + public ServiceRegistrationAttribute( + Lifetime serviceRegistrationType = Lifetime.Transient, + bool registerInterfaces = false, + Type? serviceType = null, + string? key = null) { ServiceRegistrationType = serviceRegistrationType; RegisterInterfaces = registerInterfaces; + ServiceType = serviceType; + Key = key; } } diff --git a/Stellar.SourceGenerators/ServiceRegistrationGenerator.cs b/Stellar.SourceGenerators/ServiceRegistrationGenerator.cs index 2c58a3b..56ddbe4 100644 --- a/Stellar.SourceGenerators/ServiceRegistrationGenerator.cs +++ b/Stellar.SourceGenerators/ServiceRegistrationGenerator.cs @@ -1,5 +1,5 @@ -using System; using System.Collections.Generic; +using System.Collections.Immutable; using System.Linq; using System.Text; using Microsoft.CodeAnalysis; @@ -9,161 +9,274 @@ namespace Stellar.SourceGenerators { - [Generator] - public class ServiceRegistrationGenerator : ISourceGenerator + [Generator(LanguageNames.CSharp)] + public sealed class ServiceRegistrationGenerator : IIncrementalGenerator { - public void Initialize(GeneratorInitializationContext context) + public void Initialize(IncrementalGeneratorInitializationContext context) { - context.RegisterForSyntaxNotifications(() => new SyntaxReceiver()); - } - - public void Execute(GeneratorExecutionContext context) - { - // retrieve the populated receiver - if (context.SyntaxReceiver is not SyntaxReceiver receiver) - { - return; - } + var registrationGroups = context.SyntaxProvider + .CreateSyntaxProvider( + static (node, _) => node is ClassDeclarationSyntax { AttributeLists.Count: > 0 }, + static (syntaxContext, _) => GetRegistrationsForClass(syntaxContext)) + .Where(static group => !group.IsDefaultOrEmpty); - var compilation = context.Compilation; - var attributeSymbol = compilation.GetTypeByMetadataName("Stellar.ServiceRegistrationAttribute"); - if (attributeSymbol == null) - { - return; - } - - var registrations = new List(); - - // inspect each class with attributes - foreach (var candidate in receiver.CandidateClasses) - { - var model = compilation.GetSemanticModel(candidate.SyntaxTree); - if (model.GetDeclaredSymbol(candidate) is not INamedTypeSymbol classSymbol) - { - continue; - } + var registrations = registrationGroups + .SelectMany(static (group, _) => group); - // get ServiceRegistrationAttribute instances - var attrs = classSymbol.GetAttributes() - .Where(ad => SymbolEqualityComparer.Default.Equals(ad.AttributeClass, attributeSymbol)); + var combined = context.CompilationProvider.Combine(registrations.Collect()); - foreach (var ad in attrs) + context.RegisterSourceOutput( + combined, + static (spc, source) => { - // default values - int lifetimeValue = 0; // Transient - bool registerInterfaces = false; + Compilation compilation = source.Left; + ImmutableArray regs = source.Right; - // constructor args - if (ad.ConstructorArguments.Length >= 1 && ad.ConstructorArguments[0].Value is int lv) + if (regs.IsDefaultOrEmpty) { - lifetimeValue = lv; + return; } - if (ad.ConstructorArguments.Length == 2 && ad.ConstructorArguments[1].Value is bool ri) - { - registerInterfaces = ri; - } + string asmName = compilation.AssemblyName ?? "Registered"; + string cleanName = new string(asmName.Where(char.IsLetterOrDigit).ToArray()); + string methodName = $"AddRegisteredServicesFor{cleanName}"; + + var sb = new StringBuilder(); + sb.AppendLine("using Microsoft.Extensions.DependencyInjection;") + .AppendLine() + .AppendLine("namespace Stellar") + .AppendLine("{") + .AppendLine(" public static class RegisteredServiceRegistrations") + .AppendLine(" {") + .AppendLine($" public static IServiceCollection {methodName}(this IServiceCollection services)") + .AppendLine(" {"); - // named args - foreach (var named in ad.NamedArguments) + foreach (RegistrationInfo reg in regs) { - if (named.Key == nameof(Stellar.ServiceRegistrationAttribute.ServiceRegistrationType) && - named.Value.Value is int nlv) + string method = reg.LifetimeValue switch { - lifetimeValue = nlv; + (int)Lifetime.Scoped => "AddScoped", + (int)Lifetime.Singleton => "AddSingleton", + _ => "AddTransient", + }; + + string keyedMethod = reg.LifetimeValue switch + { + (int)Lifetime.Scoped => "AddKeyedScoped", + (int)Lifetime.Singleton => "AddKeyedSingleton", + _ => "AddKeyedTransient", + }; + + bool isKeyed = !string.IsNullOrEmpty(reg.Key); + string? keyLiteral = isKeyed ? $"\"{EscapeStringLiteral(reg.Key!)}\"" : null; + + if (reg.ExplicitServiceType is not null) + { + string serviceType = GetTypeSyntax(reg.ExplicitServiceType); + string implType = GetTypeSyntax(reg.ClassSymbol); + + sb.AppendLine( + isKeyed + ? $" services.{keyedMethod}(typeof({serviceType}), {keyLiteral}, typeof({implType}));" + : $" services.{method}(typeof({serviceType}), typeof({implType}));"); } - if (named.Key == nameof(Stellar.ServiceRegistrationAttribute.RegisterInterfaces) && - named.Value.Value is bool nri) + if (reg.RegisterInterfaces) { - registerInterfaces = nri; + foreach (INamedTypeSymbol iface in reg.ClassSymbol.Interfaces) + { + string serviceType = GetTypeSyntax(iface); + string implType = GetTypeSyntax(reg.ClassSymbol); + + sb.AppendLine( + isKeyed + ? $" services.{keyedMethod}(typeof({serviceType}), {keyLiteral}, typeof({implType}));" + : $" services.{method}(typeof({serviceType}), typeof({implType}));"); + } + } + + bool shouldSelfRegister = reg.ExplicitServiceType is null && !reg.RegisterInterfaces; + + if (shouldSelfRegister) + { + string selfType = GetTypeSyntax(reg.ClassSymbol); + + sb.AppendLine( + isKeyed + ? $" services.{keyedMethod}(typeof({selfType}), {keyLiteral}, typeof({selfType}));" + : $" services.{method}(typeof({selfType}), typeof({selfType}));"); } } - registrations.Add(new RegistrationInfo(classSymbol, lifetimeValue, registerInterfaces)); - } + sb.AppendLine(" return services;") + .AppendLine(" }") + .AppendLine(" }") + .AppendLine("}"); + + spc.AddSource($"{methodName}.g.cs", SourceText.From(sb.ToString(), Encoding.UTF8)); + }); + } + + private static ImmutableArray GetRegistrationsForClass(GeneratorSyntaxContext syntaxContext) + { + var classDecl = (ClassDeclarationSyntax)syntaxContext.Node; + + if (syntaxContext.SemanticModel.GetDeclaredSymbol(classDecl) is not INamedTypeSymbol classSymbol) + { + return ImmutableArray.Empty; } - // derive assembly-based extension method name - var asmName = context.Compilation.AssemblyName ?? "Registered"; - var cleanName = new string(asmName.Where(char.IsLetterOrDigit).ToArray()); - var methodName = $"AddRegisteredServicesFor{cleanName}"; + Compilation compilation = syntaxContext.SemanticModel.Compilation; - // build source - var sb = new StringBuilder(); - sb - .AppendLine("using Microsoft.Extensions.DependencyInjection;") - .AppendLine("namespace Stellar") - .AppendLine("{") - .AppendLine(" public static class RegisteredServiceRegistrations") - .AppendLine(" {") - .AppendLine($" public static IServiceCollection {methodName}(this IServiceCollection services)") - .AppendLine(" {"); - - foreach (var reg in registrations) + INamedTypeSymbol? attributeSymbol = + compilation.GetTypeByMetadataName("Stellar.ServiceRegistrationAttribute"); + + if (attributeSymbol is null) { - string method = reg.LifetimeValue switch + return ImmutableArray.Empty; + } + + var builder = ImmutableArray.CreateBuilder(); + + foreach (AttributeData ad in classSymbol.GetAttributes()) + { + if (!SymbolEqualityComparer.Default.Equals(ad.AttributeClass, attributeSymbol)) { - (int)Lifetime.Scoped => "AddScoped", - (int)Lifetime.Singleton => "AddSingleton", - _ => "AddTransient", - }; + continue; + } + + int lifetimeValue = (int)Lifetime.Transient; + bool registerInterfaces = false; + INamedTypeSymbol? explicitServiceType = null; + string? key = null; + + if (ad.ConstructorArguments.Length >= 1 && + ad.ConstructorArguments[0].Value is int lv) + { + lifetimeValue = lv; + } - if (reg.RegisterInterfaces) + if (ad.ConstructorArguments.Length >= 2 && + ad.ConstructorArguments[1].Value is bool ri) { - foreach (var iface in reg.ClassSymbol.Interfaces) + registerInterfaces = ri; + } + + foreach (KeyValuePair named in ad.NamedArguments) + { + switch (named.Key) { - sb.AppendLine( - $" services.{method}(typeof({iface.ToDisplayString()}), typeof({reg.ClassSymbol.ToDisplayString()}));"); + case nameof(ServiceRegistrationAttribute.ServiceRegistrationType) + when named.Value.Value is int nlv: + lifetimeValue = nlv; + break; + + case nameof(ServiceRegistrationAttribute.RegisterInterfaces) + when named.Value.Value is bool nri: + registerInterfaces = nri; + break; + + case nameof(ServiceRegistrationAttribute.ServiceType) + when named.Value.Value is ITypeSymbol typeSymbol: + explicitServiceType = typeSymbol as INamedTypeSymbol; + break; + + case nameof(ServiceRegistrationAttribute.Key) + when named.Value.Value is string s: + key = s; + break; } } - sb.AppendLine($" services.{method}(typeof({reg.ClassSymbol.ToDisplayString()}), typeof({reg.ClassSymbol.ToDisplayString()}));"); + builder.Add(new RegistrationInfo( + classSymbol: classSymbol, + explicitServiceType: explicitServiceType, + lifetimeValue: lifetimeValue, + registerInterfaces: registerInterfaces, + key: key)); } - sb - .AppendLine(" return services;") - .AppendLine(" }") - .AppendLine(" }") - .AppendLine("}"); - - context.AddSource($"{methodName}.g.cs", SourceText.From(sb.ToString(), Encoding.UTF8)); + return builder.ToImmutable(); } - private class SyntaxReceiver : ISyntaxReceiver + private sealed class RegistrationInfo( + INamedTypeSymbol classSymbol, + INamedTypeSymbol? explicitServiceType, + int lifetimeValue, + bool registerInterfaces, + string? key) { - public List CandidateClasses { get; } = new List(); + public INamedTypeSymbol ClassSymbol { get; } = classSymbol; + + public INamedTypeSymbol? ExplicitServiceType { get; } = explicitServiceType; + + public int LifetimeValue { get; } = lifetimeValue; + + public bool RegisterInterfaces { get; } = registerInterfaces; - public void OnVisitSyntaxNode(SyntaxNode syntaxNode) + public string? Key { get; } = key; + + public void Deconstruct(out INamedTypeSymbol classSymbol, out INamedTypeSymbol? explicitServiceType, out int lifetimeValue, out bool registerInterfaces, out string? key) { - // any class with attributes is a candidate - if (syntaxNode is ClassDeclarationSyntax cls && cls.AttributeLists.Count > 0) - { - CandidateClasses.Add(cls); - } + classSymbol = this.ClassSymbol; + explicitServiceType = this.ExplicitServiceType; + lifetimeValue = this.LifetimeValue; + registerInterfaces = this.RegisterInterfaces; + key = this.Key; } } - private class RegistrationInfo + private static string EscapeStringLiteral(string value) + { + return value + .Replace("\\", "\\\\") + .Replace("\"", "\\\""); + } + + private static string GetTypeSyntax(INamedTypeSymbol typeSymbol) { - public RegistrationInfo(INamedTypeSymbol classSymbol, int lifetimeValue, bool registerInterfaces) + var sb = new StringBuilder(); + + if (!typeSymbol.ContainingNamespace.IsGlobalNamespace) { - ClassSymbol = classSymbol; - LifetimeValue = lifetimeValue; - RegisterInterfaces = registerInterfaces; + sb + .Append(typeSymbol.ContainingNamespace.ToDisplayString()) + .Append('.'); } - public INamedTypeSymbol ClassSymbol { get; } + var containingTypes = new Stack(); + INamedTypeSymbol? current = typeSymbol.ContainingType; + while (current is not null) + { + containingTypes.Push(current); + current = current.ContainingType; + } - public int LifetimeValue { get; } + while (containingTypes.Count > 0) + { + AppendTypeName(sb, containingTypes.Pop()); + sb.Append('.'); + } - public bool RegisterInterfaces { get; } + AppendTypeName(sb, typeSymbol); - public void Deconstruct(out INamedTypeSymbol classSymbol, out int lifetimeValue, out bool registerInterfaces) + return sb.ToString(); + } + + private static void AppendTypeName(StringBuilder sb, INamedTypeSymbol typeSymbol) + { + sb.Append(typeSymbol.Name); + + if (typeSymbol.TypeParameters.Length > 0) { - classSymbol = this.ClassSymbol; - lifetimeValue = this.LifetimeValue; - registerInterfaces = this.RegisterInterfaces; + sb.Append('<'); + + if (typeSymbol.TypeParameters.Length > 1) + { + sb.Append(',', typeSymbol.TypeParameters.Length - 1); + } + + sb.Append('>'); } } }