diff --git a/DependencyInjection.Attributed.sln b/DependencyInjection.Attributed.sln index 6472069..b502d97 100644 --- a/DependencyInjection.Attributed.sln +++ b/DependencyInjection.Attributed.sln @@ -7,12 +7,6 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Attributed", "src\Dependenc EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Attributed.Tests", "src\DependencyInjection.Attributed.Tests\Attributed.Tests.csproj", "{F2E67084-FED3-4E17-A012-0E8948FD3E06}" EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", "{FE116B5E-AE0C-4901-B0FE-BE41EF18EF06}" - ProjectSection(SolutionItems) = preProject - src\Directory.props = src\Directory.props - readme.md = readme.md - EndProjectSection -EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "CodeAnalysis.Tests", "src\CodeAnalysis.Tests\CodeAnalysis.Tests.csproj", "{E512DEBA-FB35-47FD-AF25-3BAECCF667B1}" EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Samples", "Samples", "{3C5A7AC8-E8CC-40D6-B472-A693F742152A}" diff --git a/readme.md b/readme.md index a7482a3..eb1f63b 100644 --- a/readme.md +++ b/readme.md @@ -63,6 +63,31 @@ And that's it. The source generator will discover annotated types in the current project and all its references too. Since the registration code is generated at compile-time, there is no run-time reflection (or dependencies) whatsoever. +You can also avoid attributes entirely by using a convention-based approach, which +is nevertheless still compile-time checked and source-generated. This allows +registering services for which you don't even have the source code to annotate: + +```csharp +var builder = WebApplication.CreateBuilder(args); + +builder.Services.AddServices(typeof(IRepository), ServiceLifetime.Scoped); +// ... +``` + +You can also use a regular expression to match services by name instead: + +```csharp +var builder = WebApplication.CreateBuilder(args); + +builder.Services.AddServices(".*Service$"); // defaults to ServiceLifetime.Singleton +// ... +``` + +Or a combination of both, as needed. In all cases, NO run-time reflection is +ever performed, and the compile-time source generator will evaluate the types +that are assignable to the given type or matching full type names and emit +the typed registrations as needed. + ### Keyed Services [Keyed services](https://learn.microsoft.com/en-us/aspnet/core/fundamentals/dependency-injection?view=aspnetcore-8.0#keyed-services) @@ -105,6 +130,8 @@ right `INotificationService` will be injected, based on the key provided. Note you can also register the same service using multiple keys, as shown in the `EmailNotificationService` above. +> Keyed services are a feature of version 8.0+ of Microsoft.Extensions.DependencyInjection + ## How It Works The generated code that implements the registration looks like the following: diff --git a/src/CodeAnalysis.Tests/AddServicesAnalyzerTests.cs b/src/CodeAnalysis.Tests/AddServicesAnalyzerTests.cs index 7615e18..d307d60 100644 --- a/src/CodeAnalysis.Tests/AddServicesAnalyzerTests.cs +++ b/src/CodeAnalysis.Tests/AddServicesAnalyzerTests.cs @@ -15,7 +15,7 @@ namespace Tests.CodeAnalysis; -public record AddServicesAnalyzerTests(ITestOutputHelper Output) +public class AddServicesAnalyzerTests(ITestOutputHelper Output) { [Fact] public async Task NoWarningIfAddServicesPresent() @@ -41,13 +41,19 @@ public static void Main() """, TestState = { + Sources = + { + ThisAssembly.Resources.AttributedServicesExtension.Text, + ThisAssembly.Resources.ServiceAttribute.Text, + ThisAssembly.Resources.ServiceAttribute_1.Text, + }, ReferenceAssemblies = new ReferenceAssemblies( - "net6.0", + "net8.0", new PackageIdentity( - "Microsoft.NETCore.App.Ref", "6.0.0"), - Path.Combine("ref", "net6.0")) + "Microsoft.NETCore.App.Ref", "8.0.0"), + Path.Combine("ref", "net8.0")) .AddPackages(ImmutableArray.Create( - new PackageIdentity("Microsoft.Extensions.DependencyInjection", "6.0.0"))) + new PackageIdentity("Microsoft.Extensions.DependencyInjection", "8.0.0"))) }, }; @@ -81,13 +87,19 @@ public static void Main() """, TestState = { + Sources = + { + ThisAssembly.Resources.AttributedServicesExtension.Text, + ThisAssembly.Resources.ServiceAttribute.Text, + ThisAssembly.Resources.ServiceAttribute_1.Text, + }, ReferenceAssemblies = new ReferenceAssemblies( - "net6.0", + "net8.0", new PackageIdentity( - "Microsoft.NETCore.App.Ref", "6.0.0"), - Path.Combine("ref", "net6.0")) + "Microsoft.NETCore.App.Ref", "8.0.0"), + Path.Combine("ref", "net8.0")) .AddPackages(ImmutableArray.Create( - new PackageIdentity("Microsoft.Extensions.DependencyInjection.Abstractions", "6.0.0"))) + new PackageIdentity("Microsoft.Extensions.DependencyInjection", "8.0.0"))) }, }; @@ -116,6 +128,12 @@ public static void Main() """, TestState = { + Sources = + { + ThisAssembly.Resources.AttributedServicesExtension.Text, + ThisAssembly.Resources.ServiceAttribute.Text, + ThisAssembly.Resources.ServiceAttribute_1.Text, + }, ReferenceAssemblies = new ReferenceAssemblies( "net8.0", new PackageIdentity( @@ -157,6 +175,12 @@ public static void Main() """, TestState = { + Sources = + { + ThisAssembly.Resources.AttributedServicesExtension.Text, + ThisAssembly.Resources.ServiceAttribute.Text, + ThisAssembly.Resources.ServiceAttribute_1.Text, + }, ReferenceAssemblies = new ReferenceAssemblies( "net8.0", new PackageIdentity( @@ -173,8 +197,8 @@ public static void Main() await test.RunAsync(); } - class GeneratorsTest : CSharpSourceGeneratorTest + class GeneratorsTest : CSharpSourceGeneratorTest { - protected override IEnumerable GetSourceGenerators() => base.GetSourceGenerators().Concat([typeof(IncrementalGenerator)]); + //protected override IEnumerable GetSourceGenerators() => base.GetSourceGenerators().Concat([typeof(IncrementalGenerator)]); } } diff --git a/src/CodeAnalysis.Tests/CodeAnalysis.Tests.csproj b/src/CodeAnalysis.Tests/CodeAnalysis.Tests.csproj index 11fa2a4..f01ce1f 100644 --- a/src/CodeAnalysis.Tests/CodeAnalysis.Tests.csproj +++ b/src/CodeAnalysis.Tests/CodeAnalysis.Tests.csproj @@ -1,4 +1,4 @@ - + net6.0 @@ -13,10 +13,13 @@ + + + diff --git a/src/CodeAnalysis.Tests/ConventionAnalyzerTests.cs b/src/CodeAnalysis.Tests/ConventionAnalyzerTests.cs new file mode 100644 index 0000000..7628b7d --- /dev/null +++ b/src/CodeAnalysis.Tests/ConventionAnalyzerTests.cs @@ -0,0 +1,156 @@ +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.IO; +using System.Linq; +using System.Threading.Tasks; +using Devlooped.Extensions.DependencyInjection.Attributed; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Testing; +using Microsoft.CodeAnalysis.Testing; +using Xunit; +using Xunit.Abstractions; +using AnalyzerTest = Microsoft.CodeAnalysis.CSharp.Testing.CSharpAnalyzerTest; +using Verifier = Microsoft.CodeAnalysis.CSharp.Testing.CSharpAnalyzerVerifier; + +namespace Tests.CodeAnalysis; + +public class ConventionAnalyzerTests(ITestOutputHelper Output) +{ + [Fact] + public async Task ErrorIfNonTypeOf() + { + var test = new AnalyzerTest + { + TestBehaviors = TestBehaviors.SkipGeneratedSourcesCheck, + TestCode = + """ + using System; + using Microsoft.Extensions.DependencyInjection; + + public static class Program + { + public static void Main() + { + var services = new ServiceCollection(); + var type = typeof(IDisposable); + services.AddServices({|#0:type|}); + } + } + """, + TestState = + { + Sources = + { + ThisAssembly.Resources.AttributedServicesExtension.Text, + ThisAssembly.Resources.ServiceAttribute.Text, + ThisAssembly.Resources.ServiceAttribute_1.Text, + }, + ReferenceAssemblies = new ReferenceAssemblies( + "net8.0", + new PackageIdentity( + "Microsoft.NETCore.App.Ref", "8.0.0"), + Path.Combine("ref", "net8.0")) + .AddPackages(ImmutableArray.Create( + new PackageIdentity("Microsoft.Extensions.DependencyInjection", "8.0.0"))) + }, + }; + + var expected = Verifier.Diagnostic(ConventionsAnalyzer.AssignableTypeOfRequired).WithLocation(0); + test.ExpectedDiagnostics.Add(expected); + + await test.RunAsync(); + } + + [Fact] + public async Task NoErrorOnTypeOfAndLifetime() + { + var test = new AnalyzerTest + { + TestBehaviors = TestBehaviors.SkipGeneratedSourcesCheck, + TestCode = + """ + using System; + using Microsoft.Extensions.DependencyInjection; + + public static class Program + { + public static void Main() + { + var services = new ServiceCollection(); + services.AddServices(typeof(IDisposable), ServiceLifetime.Scoped); + } + } + """, + TestState = + { + Sources = + { + ThisAssembly.Resources.AttributedServicesExtension.Text, + ThisAssembly.Resources.ServiceAttribute.Text, + ThisAssembly.Resources.ServiceAttribute_1.Text, + }, + ReferenceAssemblies = new ReferenceAssemblies( + "net8.0", + new PackageIdentity( + "Microsoft.NETCore.App.Ref", "8.0.0"), + Path.Combine("ref", "net8.0")) + .AddPackages(ImmutableArray.Create( + new PackageIdentity("Microsoft.Extensions.DependencyInjection", "8.0.0"))) + }, + }; + + //var expected = Verifier.Diagnostic(ConventionsAnalyzer.AssignableTypeOfRequired).WithLocation(0); + //test.ExpectedDiagnostics.Add(expected); + + await test.RunAsync(); + } + + [Fact] + public async Task WarnIfOpenGeneric() + { + var test = new AnalyzerTest + { + TestBehaviors = TestBehaviors.SkipGeneratedSourcesCheck, + TestCode = + """ + using System; + using Microsoft.Extensions.DependencyInjection; + + public interface IRepository { } + public class Repository : IRepository { } + + public static class Program + { + public static void Main() + { + var services = new ServiceCollection(); + services.AddServices({|#0:typeof(Repository<>)|}, ServiceLifetime.Scoped); + } + } + """, + TestState = + { + Sources = + { + ThisAssembly.Resources.AttributedServicesExtension.Text, + ThisAssembly.Resources.ServiceAttribute.Text, + ThisAssembly.Resources.ServiceAttribute_1.Text, + }, + ReferenceAssemblies = new ReferenceAssemblies( + "net8.0", + new PackageIdentity( + "Microsoft.NETCore.App.Ref", "8.0.0"), + Path.Combine("ref", "net8.0")) + .AddPackages(ImmutableArray.Create( + new PackageIdentity("Microsoft.Extensions.DependencyInjection", "8.0.0"))) + }, + }; + + var expected = Verifier.Diagnostic(ConventionsAnalyzer.OpenGenericType).WithLocation(0); + test.ExpectedDiagnostics.Add(expected); + + await test.RunAsync(); + } + +} diff --git a/src/DependencyInjection.Attributed.Tests/Attributed.Tests.csproj b/src/DependencyInjection.Attributed.Tests/Attributed.Tests.csproj index ff877b3..2578277 100644 --- a/src/DependencyInjection.Attributed.Tests/Attributed.Tests.csproj +++ b/src/DependencyInjection.Attributed.Tests/Attributed.Tests.csproj @@ -1,4 +1,4 @@ - + @@ -8,6 +8,18 @@ Tests + + + + + + + + + + + + @@ -28,6 +40,7 @@ + diff --git a/src/DependencyInjection.Attributed.Tests/ContentFiles.targets b/src/DependencyInjection.Attributed.Tests/ContentFiles.targets new file mode 100644 index 0000000..a651d7f --- /dev/null +++ b/src/DependencyInjection.Attributed.Tests/ContentFiles.targets @@ -0,0 +1,16 @@ + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/src/DependencyInjection.Attributed.Tests/ConventionsTests.cs b/src/DependencyInjection.Attributed.Tests/ConventionsTests.cs new file mode 100644 index 0000000..653e2f8 --- /dev/null +++ b/src/DependencyInjection.Attributed.Tests/ConventionsTests.cs @@ -0,0 +1,68 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel.Composition; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; + +namespace Tests.DependencyInjection; + +public class ConventionsTests(ITestOutputHelper Output) +{ + [Fact] + public void RegisterRepositoryServices() + { + var conventions = new ServiceCollection(); + conventions.AddSingleton(Output); + conventions.AddServices(typeof(IRepository)); + var services = conventions.BuildServiceProvider(); + + var instance = services.GetServices().ToList(); + + Assert.Equal(2, instance.Count); + } + + [Fact] + public void RegisterServiceByRegex() + { + var conventions = new ServiceCollection(); + conventions.AddSingleton(Output); + conventions.AddServices(nameof(ConventionsTests), ServiceLifetime.Transient); + var services = conventions.BuildServiceProvider(); + + var instance = services.GetRequiredService(); + var instance2 = services.GetRequiredService(); + + Assert.NotSame(instance, instance2); + } + + [Fact] + public void RegisterGenericServices() + { + var conventions = new ServiceCollection(); + + conventions.AddServices(typeof(IGenericRepository<>), ServiceLifetime.Scoped); + + var services = conventions.BuildServiceProvider(); + + var scope = services.CreateScope(); + var instance = scope.ServiceProvider.GetRequiredService>(); + var instance2 = scope.ServiceProvider.GetRequiredService>(); + + Assert.NotNull(instance); + Assert.NotNull(instance2); + + Assert.Same(instance, scope.ServiceProvider.GetRequiredService>()); + Assert.Same(instance2, scope.ServiceProvider.GetRequiredService>()); + } +} + +public interface IRepository { } +public class FooRepository : IRepository { } +public class BarRepository : IRepository { } + +public interface IGenericRepository { } +public class FooGenericRepository : IGenericRepository { } +public class BarGenericRepository : IGenericRepository { } diff --git a/src/DependencyInjection.Attributed/AddServicesAnalyzer.cs b/src/DependencyInjection.Attributed/AddServicesAnalyzer.cs index a7450ec..2d4a394 100644 --- a/src/DependencyInjection.Attributed/AddServicesAnalyzer.cs +++ b/src/DependencyInjection.Attributed/AddServicesAnalyzer.cs @@ -48,7 +48,9 @@ public override void Initialize(AnalysisContext context) .DescendantNodes() .OfType() .Select(invocation => new { Invocation = invocation, semantic.GetSymbolInfo(invocation, semanticContext.CancellationToken).Symbol }) - .Where(x => x.Symbol is IMethodSymbol) + // We don't consider invocations from methods that have the DDIAddServicesAttribute as user-provided, since we do that + // in our type/regex overloads. Users need to invoke those methods in turn. + .Where(x => x.Symbol is IMethodSymbol method && !method.GetAttributes().Any(attr => attr.AttributeClass?.Name == "DDIAddServicesAttribute")) .Select(x => new { x.Invocation, Method = (IMethodSymbol)x.Symbol! }); bool IsServiceCollectionExtension(IMethodSymbol method) => method.IsExtensionMethod && diff --git a/src/DependencyInjection.Attributed/AddServicesExtension.cs b/src/DependencyInjection.Attributed/AddServicesExtension.cs deleted file mode 100644 index 50d29aa..0000000 --- a/src/DependencyInjection.Attributed/AddServicesExtension.cs +++ /dev/null @@ -1,66 +0,0 @@ -using System; -using System.ComponentModel; -using Microsoft.Extensions.DependencyInjection; - -namespace Devlooped.Extensions.DependencyInjection.Attributed -{ - /// - /// Contains the extension methods to register - /// compile-time discovered services to an . - /// - [EditorBrowsable(EditorBrowsableState.Never)] - static partial class AddServicesExtension - { - /// - /// Adds the automatically discovered services that were annotated with a . - /// - /// The to add the services to. - /// The so that additional calls can be chained. - [DDIAddServices] - public static IServiceCollection AddServices(this IServiceCollection services) - { - AddScopedServices(services); - AddSingletonServices(services); - AddTransientServices(services); - - AddKeyedScopedServices(services); - AddKeyedSingletonServices(services); - AddKeyedTransientServices(services); - - return services; - } - - /// - /// Adds discovered scoped services to the collection. - /// - static partial void AddScopedServices(IServiceCollection services); - - /// - /// Adds discovered singleton services to the collection. - /// - static partial void AddSingletonServices(IServiceCollection services); - - /// - /// Adds discovered transient services to the collection. - /// - static partial void AddTransientServices(IServiceCollection services); - - /// - /// Adds discovered keyed scoped services to the collection. - /// - static partial void AddKeyedScopedServices(IServiceCollection services); - - /// - /// Adds discovered keyed singleton services to the collection. - /// - static partial void AddKeyedSingletonServices(IServiceCollection services); - - /// - /// Adds discovered keyed transient services to the collection. - /// - static partial void AddKeyedTransientServices(IServiceCollection services); - - [AttributeUsage(AttributeTargets.Method)] - class DDIAddServicesAttribute : Attribute { } - } -} \ No newline at end of file diff --git a/src/DependencyInjection.Attributed/Attributed.csproj b/src/DependencyInjection.Attributed/Attributed.csproj index 74fce66..2882298 100644 --- a/src/DependencyInjection.Attributed/Attributed.csproj +++ b/src/DependencyInjection.Attributed/Attributed.csproj @@ -11,11 +11,22 @@ true true $(DefineConstants);DDI_ADDSERVICE + false + + + + + + + + + + @@ -23,15 +34,8 @@ - - - - - - - diff --git a/src/DependencyInjection.Attributed/AttributedServicesExtension.cs b/src/DependencyInjection.Attributed/AttributedServicesExtension.cs new file mode 100644 index 0000000..99fc7e8 --- /dev/null +++ b/src/DependencyInjection.Attributed/AttributedServicesExtension.cs @@ -0,0 +1,124 @@ +using System; +using System.ComponentModel; + +namespace Microsoft.Extensions.DependencyInjection +{ + /// + /// Contains the extension methods to register + /// compile-time discovered services to an . + /// + [EditorBrowsable(EditorBrowsableState.Never)] + static partial class AttributedServicesExtension + { + static readonly ServiceDescriptor servicesAddedDescriptor = new ServiceDescriptor(typeof(DDIAddServicesAttribute), _ => new DDIAddServicesAttribute(), ServiceLifetime.Singleton); + + /// + /// Adds the services that are assignable to to the collection, + /// in addition to the discovered services that were annotated with a . + /// + /// + /// Note that NO runtime reflection is performed when using this method. A compile-time source + /// generator will emit the relevant registration methods for all matching types (in the current + /// assembly or any referenced assemblies) at build time, resulting in maximum startup performance + /// as well as AOT-safety. + /// + /// The to add the services to. + /// The type that services must be assignable to in order to be registered. + /// The service lifetime to register. + /// The so that additional calls can be chained. + [DDIAddServices] + public static IServiceCollection AddServices(this IServiceCollection services, Type assignableTo, ServiceLifetime lifetime = ServiceLifetime.Singleton) => services.AddServices(); + + /// + /// Adds the services that are assignable to to the collection, + /// in addition to the discovered services that were annotated with a . + /// + /// + /// Note that NO runtime reflection is performed when using this method. A compile-time source + /// generator will emit the relevant registration methods for all matching types (in the current + /// assembly or any referenced assemblies) at build time, resulting in maximum startup performance + /// as well as AOT-safety. + /// + /// The to add the services to. + /// Regular expression to match against the full name of the type to determine if it should be registered as a service. + /// The service lifetime to register. + /// The so that additional calls can be chained. + [DDIAddServices] + public static IServiceCollection AddServices(this IServiceCollection services, string fullNameExpression, ServiceLifetime lifetime = ServiceLifetime.Singleton) => services.AddServices(); + + /// + /// Adds the services that are assignable to to the collection, + /// in addition to the discovered services that were annotated with a . + /// + /// + /// Note that NO runtime reflection is performed when using this method. A compile-time source + /// generator will emit the relevant registration methods for all matching types (in the current + /// assembly or any referenced assemblies) at build time, resulting in maximum startup performance + /// as well as AOT-safety. + /// + /// The to add the services to. + /// The type that services must be assignable to in order to be registered. + /// Regular expression to match against the full name of the type to determine if it should be registered as a service, in addition to being assignable to . + /// The service lifetime to register. + /// The so that additional calls can be chained. + [DDIAddServices] + public static IServiceCollection AddServices(this IServiceCollection services, Type assignableTo, string fullNameExpression, ServiceLifetime lifetime = ServiceLifetime.Singleton) => services.AddServices(); + + /// + /// Adds the automatically discovered services that were annotated with a . + /// + /// The to add the services to. + /// The so that additional calls can be chained. + [DDIAddServices] + public static IServiceCollection AddServices(this IServiceCollection services) + { + if (services.Contains(servicesAddedDescriptor)) + return services; + + AddScopedServices(services); + AddSingletonServices(services); + AddTransientServices(services); + + AddKeyedScopedServices(services); + AddKeyedSingletonServices(services); + AddKeyedTransientServices(services); + + services.Add(servicesAddedDescriptor); + + return services; + } + + /// + /// Adds discovered scoped services to the collection. + /// + static partial void AddScopedServices(IServiceCollection services); + + /// + /// Adds discovered singleton services to the collection. + /// + static partial void AddSingletonServices(IServiceCollection services); + + /// + /// Adds discovered transient services to the collection. + /// + static partial void AddTransientServices(IServiceCollection services); + + /// + /// Adds discovered keyed scoped services to the collection. + /// + static partial void AddKeyedScopedServices(IServiceCollection services); + + /// + /// Adds discovered keyed singleton services to the collection. + /// + static partial void AddKeyedSingletonServices(IServiceCollection services); + + /// + /// Adds discovered keyed transient services to the collection. + /// + static partial void AddKeyedTransientServices(IServiceCollection services); + + [AttributeUsage(AttributeTargets.Method)] + class DDIAddServicesAttribute : Attribute { } + } +} \ No newline at end of file diff --git a/src/DependencyInjection.Attributed/ConventionsAnalyzer.cs b/src/DependencyInjection.Attributed/ConventionsAnalyzer.cs new file mode 100644 index 0000000..9b0b248 --- /dev/null +++ b/src/DependencyInjection.Attributed/ConventionsAnalyzer.cs @@ -0,0 +1,79 @@ +using System.Collections.Immutable; +using System.Diagnostics; +using System.Linq; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Diagnostics; + +namespace Devlooped.Extensions.DependencyInjection.Attributed; + +[DiagnosticAnalyzer(LanguageNames.CSharp, LanguageNames.VisualBasic)] +public class ConventionsAnalyzer : DiagnosticAnalyzer +{ + public static DiagnosticDescriptor AssignableTypeOfRequired { get; } = + new DiagnosticDescriptor( + "DDI002", + "The convention-based registration requires a typeof() expression.", + "When registering services by type, typeof() must be used exclusively to avoid run-time reflection.", + "Build", + DiagnosticSeverity.Error, + isEnabledByDefault: true); + + public static DiagnosticDescriptor OpenGenericType { get; } = + new DiagnosticDescriptor( + "DDI003", + "Open generic service implementations are not supported for convention-based registration.", + "Only the concrete (closed) implementations of the open generic interface will be registered Register open generic services explicitly using the built-in service collection methods.", + "Build", + DiagnosticSeverity.Warning, + isEnabledByDefault: true); + + public override ImmutableArray SupportedDiagnostics { get; } = + ImmutableArray.Create(AssignableTypeOfRequired, OpenGenericType); + + public override void Initialize(AnalysisContext context) + { + if (!Debugger.IsAttached) + context.EnableConcurrentExecution(); + + context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.None); + + context.RegisterCompilationStartAction(startContext => + { + var servicesCollection = startContext.Compilation.GetTypeByMetadataName("Microsoft.Extensions.DependencyInjection.IServiceCollection"); + if (servicesCollection == null) + return; + + startContext.RegisterSemanticModelAction(semanticContext => + { + var semantic = semanticContext.SemanticModel; + var invocations = semantic.SyntaxTree + .GetRoot(semanticContext.CancellationToken) + .DescendantNodes() + .OfType() + .Select(invocation => new { Invocation = invocation, semantic.GetSymbolInfo(invocation, semanticContext.CancellationToken).Symbol }) + .Where(x => x.Symbol is IMethodSymbol method && + method.GetAttributes().Any(attr => attr.AttributeClass?.Name == "DDIAddServicesAttribute") && + // This signals the convention overloads that take a type, regex and lifetime. + method.Parameters.Length > 1) + .Select(x => new { x.Invocation, Method = (IMethodSymbol)x.Symbol! }); + + foreach (var invocation in invocations) + { + for (var i = 0; i < invocation.Invocation.ArgumentList.Arguments.Count; i++) + { + var arg = invocation.Invocation.ArgumentList.Arguments[i]; + var prm = invocation.Method.Parameters[i]; + if (prm.Type.Name == "Type" && prm.Type.ContainingNamespace.Name == "System") + { + if (arg.Expression is not TypeOfExpressionSyntax typeExpr) + semanticContext.ReportDiagnostic(Diagnostic.Create(AssignableTypeOfRequired, arg.GetLocation())); + else if (semantic.GetSymbolInfo(typeExpr.Type).Symbol is INamedTypeSymbol argType && argType.IsGenericType && argType.IsUnboundGenericType) + semanticContext.ReportDiagnostic(Diagnostic.Create(OpenGenericType, arg.GetLocation())); + } + } + } + }); + }); + } +} \ No newline at end of file diff --git a/src/DependencyInjection.Attributed/Devlooped.Extensions.DependencyInjection.Attributed.props b/src/DependencyInjection.Attributed/Devlooped.Extensions.DependencyInjection.Attributed.props index 7209539..3540b84 100644 --- a/src/DependencyInjection.Attributed/Devlooped.Extensions.DependencyInjection.Attributed.props +++ b/src/DependencyInjection.Attributed/Devlooped.Extensions.DependencyInjection.Attributed.props @@ -7,8 +7,6 @@ - - \ No newline at end of file diff --git a/src/DependencyInjection.Attributed/IncrementalGenerator.cs b/src/DependencyInjection.Attributed/IncrementalGenerator.cs index cd1e93d..59ef674 100644 --- a/src/DependencyInjection.Attributed/IncrementalGenerator.cs +++ b/src/DependencyInjection.Attributed/IncrementalGenerator.cs @@ -3,9 +3,11 @@ using System.Collections.Immutable; using System.Linq; using System.Text; +using System.Text.RegularExpressions; using System.Threading; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Diagnostics; using KeyedService = (Microsoft.CodeAnalysis.INamedTypeSymbol Type, Microsoft.CodeAnalysis.TypedConstant? Key); @@ -18,7 +20,13 @@ namespace Devlooped.Extensions.DependencyInjection.Attributed; [Generator(LanguageNames.CSharp)] public class IncrementalGenerator : IIncrementalGenerator { - record ServiceSymbol(INamedTypeSymbol Type, TypedConstant? Key, int Lifetime); + record ServiceSymbol(INamedTypeSymbol Type, int Lifetime, TypedConstant? Key); + record ServiceRegistration(int Lifetime, INamedTypeSymbol? AssignableTo, string? FullNameExpression) + { + Regex? regex; + + public Regex Regex => (regex ??= FullNameExpression is not null ? new(FullNameExpression) : new(".*")); + } public void Initialize(IncrementalGeneratorInitializationContext context) { @@ -34,7 +42,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) symbol?.Accept(visitor); } - return visitor.TypeSymbols; + return visitor.TypeSymbols.Where(t => !t.IsAbstract && t.TypeKind == TypeKind.Class); }); bool IsService(AttributeData attr) => @@ -59,7 +67,8 @@ bool IsExport(AttributeData attr) // NOTE: we recognize the attribute by name, not precise type. This makes the generator // more flexible and avoids requiring any sort of run-time dependency. - var services = types + + var attributedServices = types .SelectMany((x, _) => { var name = x.Name; @@ -115,7 +124,7 @@ bool IsExport(AttributeData attr) } } - services.Add(new(x, key, lifetime)); + services.Add(new(x, lifetime, key)); } return services.ToImmutableArray(); @@ -127,16 +136,57 @@ bool IsExport(AttributeData attr) // Only requisite is that we define Scoped = 0, Singleton = 1 and Transient = 2. // This matches https://learn.microsoft.com/en-us/dotnet/api/microsoft.extensions.dependencyinjection.servicelifetime?view=dotnet-plat-ext-6.0#fields + // Add conventional registrations. + + // First get all AddServices(type, regex, lifetime) invocations. + var methodInvocations = context.SyntaxProvider + .CreateSyntaxProvider( + predicate: static (node, _) => node is InvocationExpressionSyntax, + transform: static (ctx, _) => GetServiceRegistration((InvocationExpressionSyntax)ctx.Node, ctx.SemanticModel)) + .Where(details => details != null) + .Collect(); + + // Project matching service types to register with the given lifetime. + var conventionServices = types.Combine(methodInvocations.Combine(context.CompilationProvider)).SelectMany((pair, cancellationToken) => + { + var (typeSymbol, (registrations, compilation)) = pair; + var results = ImmutableArray.CreateBuilder(); + + foreach (var registration in registrations) + { + // check of typeSymbol is assignable (is the same type, inherits from it or implements if its an interface) to registration.AssignableTo + if (registration!.AssignableTo is not null && !typeSymbol.Is(registration.AssignableTo)) + continue; + + if (registration!.FullNameExpression != null && !registration.Regex.IsMatch(typeSymbol.ToFullName(compilation))) + continue; + + results.Add(new ServiceSymbol(typeSymbol, registration.Lifetime, null)); + } + + return results.ToImmutable(); + }); + + // Flatten and remove duplicates + var finalServices = attributedServices.Collect().Combine(conventionServices.Collect()) + .SelectMany((tuple, _) => ImmutableArray.CreateRange([tuple.Item1, tuple.Item2])) + .SelectMany((items, _) => items.Distinct().ToImmutableArray()); + + RegisterServicesOutput(context, finalServices, options); + } + + void RegisterServicesOutput(IncrementalGeneratorInitializationContext context, IncrementalValuesProvider services, IncrementalValueProvider<(AnalyzerConfigOptionsProvider Left, Compilation Right)> options) + { context.RegisterImplementationSourceOutput( - services.Where(x => x!.Lifetime == 0 && x.Key is null).Select((x, _) => new KeyedService(x!.Type, x.Key!)).Collect().Combine(options), + services.Where(x => x!.Lifetime == 0 && x.Key is null).Select((x, _) => new KeyedService(x!.Type, null)).Collect().Combine(options), (ctx, data) => AddPartial("AddSingleton", ctx, data)); context.RegisterImplementationSourceOutput( - services.Where(x => x!.Lifetime == 1 && x.Key is null).Select((x, _) => new KeyedService(x!.Type, x.Key!)).Collect().Combine(options), + services.Where(x => x!.Lifetime == 1 && x.Key is null).Select((x, _) => new KeyedService(x!.Type, null)).Collect().Combine(options), (ctx, data) => AddPartial("AddScoped", ctx, data)); context.RegisterImplementationSourceOutput( - services.Where(x => x!.Lifetime == 2 && x.Key is null).Select((x, _) => new KeyedService(x!.Type, x.Key!)).Collect().Combine(options), + services.Where(x => x!.Lifetime == 2 && x.Key is null).Select((x, _) => new KeyedService(x!.Type, null)).Collect().Combine(options), (ctx, data) => AddPartial("AddTransient", ctx, data)); context.RegisterImplementationSourceOutput( @@ -152,18 +202,55 @@ bool IsExport(AttributeData attr) (ctx, data) => AddPartial("AddKeyedTransient", ctx, data)); } + static ServiceRegistration? GetServiceRegistration(InvocationExpressionSyntax invocation, SemanticModel semanticModel) + { + var symbolInfo = semanticModel.GetSymbolInfo(invocation); + if (symbolInfo.Symbol is IMethodSymbol methodSymbol && + methodSymbol.GetAttributes().Any(attr => attr.AttributeClass?.Name == "DDIAddServicesAttribute") && + methodSymbol.Parameters.Length >= 2) + { + var defaultLifetime = methodSymbol.Parameters.FirstOrDefault(x => x.Type.Name == "ServiceLifetime" && x.HasExplicitDefaultValue)?.ExplicitDefaultValue; + // This allows us to change the API-provided default without having to change the source generator to match, if needed. + var lifetime = defaultLifetime is int value ? value : 0; + INamedTypeSymbol? assignableTo = null; + string? fullNameExpression = null; + + foreach (var argument in invocation.ArgumentList.Arguments) + { + var typeInfo = semanticModel.GetTypeInfo(argument.Expression).Type; + + if (typeInfo is INamedTypeSymbol namedType) + { + if (namedType.Name == "ServiceLifetime") + { + lifetime = (int?)semanticModel.GetConstantValue(argument.Expression).Value ?? 0; + } + else if (namedType.Name == "Type" && argument.Expression is TypeOfExpressionSyntax typeOf && + semanticModel.GetSymbolInfo(typeOf.Type).Symbol is INamedTypeSymbol typeSymbol) + { + // TODO: analyzer error if argument is not typeof(T) + assignableTo = typeSymbol; + } + else if (namedType.SpecialType == SpecialType.System_String) + { + fullNameExpression = semanticModel.GetConstantValue(argument.Expression).Value as string; + } + } + } + + if (assignableTo != null || fullNameExpression != null) + { + return new ServiceRegistration(lifetime, assignableTo, fullNameExpression); + } + } + return null; + } + void AddPartial(string methodName, SourceProductionContext ctx, (ImmutableArray Types, (AnalyzerConfigOptionsProvider Config, Compilation Compilation) Options) data) { var builder = new StringBuilder() .AppendLine("// "); - var rootNs = data.Options.Config.GlobalOptions.TryGetValue("build_property.AddServicesNamespace", out var value) && !string.IsNullOrEmpty(value) - ? value - : "Microsoft.Extensions.DependencyInjection"; - - var className = data.Options.Config.GlobalOptions.TryGetValue("build_property.AddServicesClassName", out value) && !string.IsNullOrEmpty(value) ? - value : "AddServicesExtension"; - foreach (var alias in data.Options.Compilation.References.SelectMany(r => r.Properties.Aliases)) { builder.AppendLine($"extern alias {alias};"); @@ -171,12 +258,12 @@ void AddPartial(string methodName, SourceProductionContext ctx, (ImmutableArray< builder.AppendLine( $$""" - using Microsoft.Extensions.DependencyInjection; + using Microsoft.Extensions.DependencyInjection.Extensions; using System; - namespace {{rootNs}} + namespace Microsoft.Extensions.DependencyInjection { - static partial class {{className}} + static partial class AttributedServicesExtension { static partial void {{methodName}}Services(IServiceCollection services) { @@ -224,11 +311,11 @@ void AddServices(IEnumerable services, Compilation compilation return $"s.GetRequiredService<{p.Type.ToFullName(compilation)}>()"; })); - output.AppendLine($" services.{methodName}(s => new {impl}({args}));"); + output.AppendLine($" services.Try{methodName}(s => new {impl}({args}));"); } else { - output.AppendLine($" services.{methodName}(s => new {impl}());"); + output.AppendLine($" services.Try{methodName}(s => new {impl}());"); } output.AppendLine($" services.AddTransient>(s => s.GetRequiredService<{impl}>);"); @@ -371,7 +458,6 @@ bool IsFromKeyed(AttributeData attr) attr.ConstructorArguments.Length > 0 && attr.ConstructorArguments[0].Kind == TypedConstantKind.Primitive); } - class TypesVisitor : SymbolVisitor { Func isAccessible; diff --git a/src/DependencyInjection.Attributed/SymbolExtensions.cs b/src/DependencyInjection.Attributed/SymbolExtensions.cs index eabae64..de04aa6 100644 --- a/src/DependencyInjection.Attributed/SymbolExtensions.cs +++ b/src/DependencyInjection.Attributed/SymbolExtensions.cs @@ -88,7 +88,8 @@ public static bool Is(this ITypeSymbol? @this, ITypeSymbol? baseTypeOrInterface) if (baseTypeOrInterface is INamedTypeSymbol namedExpected && @this is INamedTypeSymbol namedActual && namedActual.IsGenericType && - namedActual.ConstructedFrom.Equals(namedExpected, SymbolEqualityComparer.Default)) + (namedActual.ConstructedFrom.Equals(namedExpected, SymbolEqualityComparer.Default) || + namedActual.ConstructedFrom.Equals(namedExpected.OriginalDefinition, SymbolEqualityComparer.Default))) return true; foreach (var iface in @this.AllInterfaces) diff --git a/src/Directory.props b/src/Directory.props index a164d6e..63073d1 100644 --- a/src/Directory.props +++ b/src/Directory.props @@ -1,9 +1,9 @@  + enable false https://github.com/devlooped/DependencyInjection.Attributed - $(RestoreSources);https://pkg.kzu.app/index.json;https://api.nuget.org/v3/index.json \ No newline at end of file diff --git a/src/Directory.targets b/src/Directory.targets new file mode 100644 index 0000000..68c106c --- /dev/null +++ b/src/Directory.targets @@ -0,0 +1,33 @@ + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/Samples/ConsoleApp/ConsoleApp.csproj b/src/Samples/ConsoleApp/ConsoleApp.csproj index 6d19233..1dc9831 100644 --- a/src/Samples/ConsoleApp/ConsoleApp.csproj +++ b/src/Samples/ConsoleApp/ConsoleApp.csproj @@ -1,4 +1,5 @@  + Exe @@ -19,4 +20,6 @@ + + diff --git a/src/Samples/Library1/Library1.csproj b/src/Samples/Library1/Library1.csproj index f69070f..d83ae9a 100644 --- a/src/Samples/Library1/Library1.csproj +++ b/src/Samples/Library1/Library1.csproj @@ -1,4 +1,5 @@ - + + netstandard2.0 @@ -18,4 +19,6 @@ + +