diff --git a/README.md b/README.md index af0b05f..695584d 100644 --- a/README.md +++ b/README.md @@ -133,6 +133,24 @@ public static partial class ServiceCollectionExtensions } ``` +### Apply EF Core IEntityTypeConfiguration types + +```csharp +public static partial class ModelBuilderExtensions +{ + [GenerateServiceRegistrations(AssignableTo = typeof(IEntityTypeConfiguration<>), CustomHandler = nameof(ApplyConfiguration))] + public static partial ModelBuilder ApplyEntityConfigurations(this ModelBuilder modelBuilder); + + private static void ApplyConfiguration(ModelBuilder modelBuilder) + where T : IEntityTypeConfiguration, new() + where TEntity : class + { + modelBuilder.ApplyConfiguration(new T()); + } +} +``` + + ## Parameters diff --git a/ServiceScan.SourceGenerator.Tests/CustomHandlerTests.cs b/ServiceScan.SourceGenerator.Tests/CustomHandlerTests.cs index 5f848a2..f3ddb9d 100644 --- a/ServiceScan.SourceGenerator.Tests/CustomHandlerTests.cs +++ b/ServiceScan.SourceGenerator.Tests/CustomHandlerTests.cs @@ -268,7 +268,6 @@ public static partial void ProcessServices() Assert.Equal(expected, results.GeneratedTrees[1].ToString()); } - [Fact] public void AddMultipleCustomHandlerAttributesWithSameCustomHandler() { @@ -319,6 +318,85 @@ public static partial void ProcessServices() Assert.Equal(expected, results.GeneratedTrees[1].ToString()); } + [Fact] + public void ResolveCustomHandlerGenericArguments() + { + var source = $$""" + using ServiceScan.SourceGenerator; + + namespace GeneratorTests; + + public static partial class ModelBuilderExtensions + { + [GenerateServiceRegistrations(AssignableTo = typeof(IEntityTypeConfiguration<>), CustomHandler = nameof(ApplyConfiguration))] + public static partial ModelBuilder ApplyEntityConfigurations(this ModelBuilder modelBuilder); + + private static void ApplyConfiguration(ModelBuilder modelBuilder) + where T : IEntityTypeConfiguration, new() + where TEntity : class + { + modelBuilder.ApplyConfiguration(new T()); + } + } + """; + + var infra = """ + public interface IEntityTypeConfiguration where TEntity : class + { + void Configure(EntityTypeBuilder builder); + } + + public class EntityTypeBuilder where TEntity : class; + + public class ModelBuilder + { + public ModelBuilder ApplyConfiguration(IEntityTypeConfiguration configuration) where TEntity : class + { + return this; + } + } + """; + + var configurations = """ + namespace GeneratorTests; + + public class EntityA; + public class EntityB; + + public class EntityAConfiguration : IEntityTypeConfiguration + { + public void Configure(EntityTypeBuilder builder) { } + } + + public class EntityBConfiguration : IEntityTypeConfiguration + { + public void Configure(EntityTypeBuilder builder) { } + } + """; + + var compilation = CreateCompilation(source, infra, configurations); + + var results = CSharpGeneratorDriver + .Create(_generator) + .RunGenerators(compilation) + .GetRunResult(); + + var expected = $$""" + namespace GeneratorTests; + + public static partial class ModelBuilderExtensions + { + public static partial global::ModelBuilder ApplyEntityConfigurations(this global::ModelBuilder modelBuilder) + { + ApplyConfiguration(modelBuilder); + ApplyConfiguration(modelBuilder); + return modelBuilder; + } + } + """; + Assert.Equal(expected, results.GeneratedTrees[1].ToString()); + } + private static Compilation CreateCompilation(params string[] source) { diff --git a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FindServicesToRegister.cs b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FindServicesToRegister.cs index c47cfa0..8f37d5a 100644 --- a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FindServicesToRegister.cs +++ b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FindServicesToRegister.cs @@ -37,13 +37,31 @@ private static DiagnosticModel FindServicesToRegister if (attribute.CustomHandler != null) { - customHandlers.Add(new CustomHandlerModel(attribute.CustomHandler, implementationType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat))); + // If CustomHandler method has multiple type parameters, which are resolvable from the first one - we try to provide them. + // e.g. ApplyConfiguration(ModelBuilder modelBuilder) where T : IEntityTypeConfiguration + if (attribute.CustomHandlerTypeParametersCount > 1 && matchedTypes != null) + { + foreach (var matchedType in matchedTypes) + { + EquatableArray typeArguments = + [ + implementationType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), + .. matchedType.TypeArguments.Select(a => a.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)) + ]; + + customHandlers.Add(new CustomHandlerModel(attribute.CustomHandler, typeArguments)); + } + } + else + { + customHandlers.Add(new CustomHandlerModel(attribute.CustomHandler, [implementationType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)])); + } } else { var serviceTypes = (attribute.AsSelf, attribute.AsImplementedInterfaces) switch { - (true, true) => new[] { implementationType }.Concat(GetSuitableInterfaces(implementationType)), + (true, true) => [implementationType, .. GetSuitableInterfaces(implementationType)], (false, true) => GetSuitableInterfaces(implementationType), (true, false) => [implementationType], _ => matchedTypes ?? [implementationType] diff --git a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.ParseMethodModel.cs b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.ParseMethodModel.cs index 263b2a4..dae59a4 100644 --- a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.ParseMethodModel.cs +++ b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.ParseMethodModel.cs @@ -69,6 +69,15 @@ public partial class DependencyInjectionGenerator if (!typesMatch) return Diagnostic.Create(CustomHandlerMethodHasIncorrectSignature, attribute.Location); + + // If CustomHandler has more than 1 type parameters, we try to resolve them from + // matched assignableTo type arguments. + // e.g. ApplyConfiguration(ModelBuilder modelBuilder) where T : IEntityTypeConfiguration + if (customHandlerMethod.TypeParameters.Length > 1 + && customHandlerMethod.TypeParameters.Length != attribute.AssignableToTypeParametersCount + 1) + { + return Diagnostic.Create(CustomHandlerMethodHasIncorrectSignature, attribute.Location); + } } if (attributeData[i].HasErrors) diff --git a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.cs b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.cs index 0ac50ee..f966606 100644 --- a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.cs +++ b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.cs @@ -108,7 +108,11 @@ private static string GenerateRegistrationsSource(MethodModel method, EquatableA private static string GenerateCustomHandlingSource(MethodModel method, EquatableArray customHandlers) { var invocations = string.Join("\n", customHandlers.Select(h => - $" {h.HandlerMethodName}<{h.TypeName}>({string.Join(", ", method.Parameters.Select(p => p.Name))});")); + { + var genericArguments = string.Join(", ", h.TypeArguments); + var arguments = string.Join(", ", method.Parameters.Select(p => p.Name)); + return $" {h.HandlerMethodName}<{genericArguments}>({arguments});"; + })); var namespaceDeclaration = method.Namespace is null ? "" : $"namespace {method.Namespace};"; var parameters = string.Join(",", method.Parameters.Select((p, i) => diff --git a/ServiceScan.SourceGenerator/Model/AttributeModel.cs b/ServiceScan.SourceGenerator/Model/AttributeModel.cs index 61c2c40..1188374 100644 --- a/ServiceScan.SourceGenerator/Model/AttributeModel.cs +++ b/ServiceScan.SourceGenerator/Model/AttributeModel.cs @@ -7,6 +7,7 @@ enum KeySelectorType { Method, GenericMethod, TypeMember }; record AttributeModel( string? AssignableToTypeName, + int AssignableToTypeParametersCount, string? AssemblyNameFilter, EquatableArray? AssignableToGenericArguments, string? AssemblyOfTypeName, @@ -20,6 +21,7 @@ record AttributeModel( string? KeySelector, KeySelectorType? KeySelectorType, string? CustomHandler, + int CustomHandlerTypeParametersCount, bool AsImplementedInterfaces, bool AsSelf, Location Location, @@ -42,6 +44,8 @@ public static AttributeModel Create(AttributeData attribute, IMethodSymbol metho var keySelector = attribute.NamedArguments.FirstOrDefault(a => a.Key == "KeySelector").Value.Value as string; var customHandler = attribute.NamedArguments.FirstOrDefault(a => a.Key == "CustomHandler").Value.Value as string; + var assignableToTypeParametersCount = assignableTo?.TypeParameters.Length ?? 0; + KeySelectorType? keySelectorType = null; if (keySelector != null) { @@ -59,6 +63,16 @@ public static AttributeModel Create(AttributeData attribute, IMethodSymbol metho } } + var customHandlerGenericParameters = 0; + if (customHandler != null) + { + var customHandlerMethod = method.ContainingType.GetMembers() + .OfType() + .FirstOrDefault(m => m.IsStatic && m.Name == customHandler); + + customHandlerGenericParameters = customHandlerMethod?.TypeParameters.Length ?? 0; + } + if (string.IsNullOrWhiteSpace(typeNameFilter)) typeNameFilter = null; @@ -97,6 +111,7 @@ public static AttributeModel Create(AttributeData attribute, IMethodSymbol metho return new( assignableToTypeName, + assignableToTypeParametersCount, assemblyNameFilter, assignableToGenericArguments, assemblyOfTypeName, @@ -110,6 +125,7 @@ public static AttributeModel Create(AttributeData attribute, IMethodSymbol metho keySelector, keySelectorType, customHandler, + customHandlerGenericParameters, asImplementedInterfaces, asSelf, location, diff --git a/ServiceScan.SourceGenerator/Model/ServiceRegistrationModel.cs b/ServiceScan.SourceGenerator/Model/ServiceRegistrationModel.cs index ef98acf..09aa36e 100644 --- a/ServiceScan.SourceGenerator/Model/ServiceRegistrationModel.cs +++ b/ServiceScan.SourceGenerator/Model/ServiceRegistrationModel.cs @@ -11,4 +11,4 @@ record ServiceRegistrationModel( record CustomHandlerModel( string HandlerMethodName, - string TypeName); + EquatableArray TypeArguments); diff --git a/version.json b/version.json index 259a0f3..2edee38 100644 --- a/version.json +++ b/version.json @@ -1,6 +1,6 @@ { "$schema": "https://raw.githubusercontent.com/dotnet/Nerdbank.GitVersioning/main/src/NerdBank.GitVersioning/version.schema.json", - "version": "2.1", + "version": "2.2", "publicReleaseRefSpec": [ "^refs/heads/main", "^refs/heads/v\\d+(?:\\.\\d+)?$"