diff --git a/README.md b/README.md index 3a0c62f..1ccfe65 100644 --- a/README.md +++ b/README.md @@ -164,4 +164,4 @@ public static partial class ModelBuilderExtensions | **ExcludeByTypeName** | Sets this value to exclude types from being registered by their full name. You can use '*' wildcards. You can also use ',' to separate multiple filters. | | **ExcludeByAttribute** | Excludes matching types by the specified attribute type being present. | | **KeySelector** | Sets this property to add types as keyed services. This property should point to one of the following:
- The name of a static method in the current type with a string return type. The method should be either generic or have a single parameter of type `Type`.
- A constant field or static property in the implementation type. | -| **CustomHandler** | Sets this property to invoke a custom method for each type found instead of regular registration logic. This property should point to one of the following:
- Name of a generic method in the current type.
- Static method name in found types.
This property is incompatible with `Lifetime`, `AsImplementedInterfaces`, `AsSelf`, and `KeySelector` properties. | \ No newline at end of file +| **CustomHandler** | Sets this property to invoke a custom method for each type found instead of regular registration logic. This property should point to one of the following:
- Name of a generic method in the current type.
- Static method name in found types.
This property is incompatible with `Lifetime`, `AsImplementedInterfaces`, `AsSelf`, and `KeySelector` properties.
**Note:** When using a generic `CustomHandler` method, types are automatically filtered by the generic constraints defined on the method's type parameters (e.g., `class`, `struct`, `new()`, interface constraints). | \ No newline at end of file diff --git a/ServiceScan.SourceGenerator.Tests/CustomHandlerTests.cs b/ServiceScan.SourceGenerator.Tests/CustomHandlerTests.cs index 01bdd1f..3fb9e01 100644 --- a/ServiceScan.SourceGenerator.Tests/CustomHandlerTests.cs +++ b/ServiceScan.SourceGenerator.Tests/CustomHandlerTests.cs @@ -105,6 +105,51 @@ public static partial void ProcessServices( string value, decimal number) Assert.Equal(expected, results.GeneratedTrees[1].ToString()); } + [Fact] + public void CustomHandler_NoTypesFound() + { + var source = $$""" + using ServiceScan.SourceGenerator; + + namespace GeneratorTests; + + public static partial class ServicesExtensions + { + [GenerateServiceRegistrations(AssignableTo = typeof(IService), CustomHandler = nameof(HandleType))] + public static partial void ProcessServices(); + + private static void HandleType() => System.Console.WriteLine(typeof(T).Name); + } + """; + + var services = + """ + namespace GeneratorTests; + + public interface IService { } + """; + + var compilation = CreateCompilation(source, services); + + var results = CSharpGeneratorDriver + .Create(_generator) + .RunGenerators(compilation) + .GetRunResult(); + + var expected = $$""" + namespace GeneratorTests; + + public static partial class ServicesExtensions + { + public static partial void ProcessServices() + { + + } + } + """; + Assert.Equal(expected, results.GeneratedTrees[1].ToString()); + } + [Fact] public void CustomHandlerExtensionMethod() { @@ -663,6 +708,380 @@ public static partial class ServiceCollectionExtensions Assert.Equal(expected, results.GeneratedTrees[1].ToString()); } + [Fact] + public void CustomHandler_FiltersByNewConstraint() + { + var source = """ + using ServiceScan.SourceGenerator; + + namespace GeneratorTests; + + public static partial class ServicesExtensions + { + [GenerateServiceRegistrations(AssignableTo = typeof(IService), CustomHandler = nameof(HandleType))] + public static partial void ProcessServices(); + + private static void HandleType() where T : IService, new() => System.Console.WriteLine(typeof(T).Name); + } + """; + + var services = """ + namespace GeneratorTests; + + public interface IService { } + public class ServiceWithParameterlessConstructor : IService { } + public class ServiceWithoutParameterlessConstructor : IService + { + public ServiceWithoutParameterlessConstructor(int value) { } + } + public class ServiceWithPrivateConstructor : IService + { + private ServiceWithPrivateConstructor() { } + } + """; + + var compilation = CreateCompilation(source, services); + + var results = CSharpGeneratorDriver + .Create(_generator) + .RunGenerators(compilation) + .GetRunResult(); + + var expected = """ + namespace GeneratorTests; + + public static partial class ServicesExtensions + { + public static partial void ProcessServices() + { + HandleType(); + } + } + """; + Assert.Equal(expected, results.GeneratedTrees[1].ToString()); + } + + [Fact] + public void CustomHandler_FiltersByClassConstraint() + { + var source = """ + using ServiceScan.SourceGenerator; + + namespace GeneratorTests; + + public static partial class ServicesExtensions + { + [GenerateServiceRegistrations(TypeNameFilter = "*Service", CustomHandler = nameof(HandleType))] + public static partial void ProcessServices(); + + private static void HandleType() where T : class => System.Console.WriteLine(typeof(T).Name); + } + """; + + var services = """ + namespace GeneratorTests; + + public class ClassService { } + public struct StructService { } + """; + + var compilation = CreateCompilation(source, services); + + var results = CSharpGeneratorDriver + .Create(_generator) + .RunGenerators(compilation) + .GetRunResult(); + + var expected = """ + namespace GeneratorTests; + + public static partial class ServicesExtensions + { + public static partial void ProcessServices() + { + HandleType(); + } + } + """; + Assert.Equal(expected, results.GeneratedTrees[1].ToString()); + } + + [Fact] + public void CustomHandler_FiltersByNestedTypeParameterConstraints() + { + var source = """ + using ServiceScan.SourceGenerator; + + namespace GeneratorTests; + + public static partial class ServiceCollectionExtensions + { + [GenerateServiceRegistrations(AssignableTo = typeof(ICommandHandler<>), CustomHandler = nameof(AddHandler))] + public static partial void AddHandlers(); + + private static void AddHandler() + where THandler : class, ICommandHandler + where TCommand : class, ICommand + { + } + } + """; + + var services = """ + namespace GeneratorTests; + + public interface ICommand { } + public interface ICommandHandler where T : ICommand { } + + public class ValidCommand : ICommand { } + public class InvalidCommand { } + + public class ValidHandler : ICommandHandler { } + public class InvalidHandler : ICommandHandler { } + """; + + var compilation = CreateCompilation(source, services); + + var results = CSharpGeneratorDriver + .Create(_generator) + .RunGenerators(compilation) + .GetRunResult(); + + var expected = """ + namespace GeneratorTests; + + public static partial class ServiceCollectionExtensions + { + public static partial void AddHandlers() + { + AddHandler(); + } + } + """; + Assert.Equal(expected, results.GeneratedTrees[1].ToString()); + } + + [Fact] + public void CustomHandler_FiltersByMultipleInterfacesWithDifferentTypeArguments() + { + var source = """ + using ServiceScan.SourceGenerator; + + namespace GeneratorTests; + + public static partial class ServiceCollectionExtensions + { + [GenerateServiceRegistrations(AssignableTo = typeof(IHandler<>), CustomHandler = nameof(AddHandler))] + public static partial void AddHandlers(); + + private static void AddHandler() + where THandler : class, IHandler + where TArg : class + { + } + } + """; + + var services = """ + namespace GeneratorTests; + + public interface IHandler { } + + public class Handler1 : IHandler { } + public class Handler2 : IHandler { } + public class Handler3 : IHandler { } + public class MultiHandler : IHandler, IHandler { } + """; + + var compilation = CreateCompilation(source, services); + + var results = CSharpGeneratorDriver + .Create(_generator) + .RunGenerators(compilation) + .GetRunResult(); + + var expected = """ + namespace GeneratorTests; + + public static partial class ServiceCollectionExtensions + { + public static partial void AddHandlers() + { + AddHandler(); + AddHandler(); + AddHandler(); + AddHandler(); + } + } + """; + Assert.Equal(expected, results.GeneratedTrees[1].ToString()); + } + + [Fact] + public void CustomHandler_FiltersByValueTypeConstraint() + { + var source = """ + using ServiceScan.SourceGenerator; + + namespace GeneratorTests; + + public static partial class ServiceCollectionExtensions + { + [GenerateServiceRegistrations(AssignableTo = typeof(IProcessor<>), CustomHandler = nameof(AddProcessor))] + public static partial void AddProcessors(); + + private static void AddProcessor() + where TProcessor : class, IProcessor + where TValue : struct + { + } + } + """; + + var services = """ + namespace GeneratorTests; + + public interface IProcessor { } + + public class IntProcessor : IProcessor { } + public class StringProcessor : IProcessor { } + public class GuidProcessor : IProcessor { } + """; + + var compilation = CreateCompilation(source, services); + + var results = CSharpGeneratorDriver + .Create(_generator) + .RunGenerators(compilation) + .GetRunResult(); + + var expected = """ + namespace GeneratorTests; + + public static partial class ServiceCollectionExtensions + { + public static partial void AddProcessors() + { + AddProcessor(); + AddProcessor(); + } + } + """; + Assert.Equal(expected, results.GeneratedTrees[1].ToString()); + } + + [Fact] + public void CustomHandler_CombinedConstraints() + { + var source = """ + using ServiceScan.SourceGenerator; + + namespace GeneratorTests; + + public interface IConfigurable { } + + public static partial class ServiceCollectionExtensions + { + [GenerateServiceRegistrations(AssignableTo = typeof(IHandler<>), CustomHandler = nameof(AddHandler))] + public static partial void AddHandlers(); + + private static void AddHandler() + where THandler : class, IHandler, IConfigurable, new() + where TArg : class, new() + { + } + } + """; + + var services = """ + namespace GeneratorTests; + + public interface IHandler { } + + public class Arg1 { } + public class Arg2 { public Arg2(int x) { } } + + public class ValidHandler : IHandler, IConfigurable { } + public class HandlerWithoutConfigurable : IHandler { } + public class HandlerWithoutConstructor : IHandler, IConfigurable + { + public HandlerWithoutConstructor(int x) { } + } + public class HandlerWithNonConstructibleArg : IHandler, IConfigurable { } + """; + + var compilation = CreateCompilation(source, services); + + var results = CSharpGeneratorDriver + .Create(_generator) + .RunGenerators(compilation) + .GetRunResult(); + + var expected = """ + namespace GeneratorTests; + + public static partial class ServiceCollectionExtensions + { + public static partial void AddHandlers() + { + AddHandler(); + } + } + """; + Assert.Equal(expected, results.GeneratedTrees[1].ToString()); + } + + [Fact] + public void CustomHandler_HandlesRecursiveConstraints() + { + var source = """ + using ServiceScan.SourceGenerator; + + namespace GeneratorTests; + + public static partial class ServicesExtensions + { + [GenerateServiceRegistrations(TypeNameFilter = "*Smth*", CustomHandler = nameof(HandleType))] + public static partial void ProcessServices(); + + private static void HandleType() + where X : ISmth + where Y : ISmth + => System.Console.WriteLine(typeof(X).Name); + } + """; + + var services = """ + namespace GeneratorTests; + + interface ISmth; + class SmthX: ISmth; + class SmthY: ISmth; + class SmthString: ISmth; + """; + + var compilation = CreateCompilation(source, services); + + var results = CSharpGeneratorDriver + .Create(_generator) + .RunGenerators(compilation) + .GetRunResult(); + + var expected = """ + namespace GeneratorTests; + + public static partial class ServicesExtensions + { + public static partial void ProcessServices() + { + HandleType(); + HandleType(); + } + } + """; + Assert.Equal(expected, results.GeneratedTrees[1].ToString()); + } + private static Compilation CreateCompilation(params string[] source) { var path = Path.GetDirectoryName(typeof(object).Assembly.Location)!; diff --git a/ServiceScan.SourceGenerator.Tests/DiagnosticTests.cs b/ServiceScan.SourceGenerator.Tests/DiagnosticTests.cs index 0531dd4..f01284d 100644 --- a/ServiceScan.SourceGenerator.Tests/DiagnosticTests.cs +++ b/ServiceScan.SourceGenerator.Tests/DiagnosticTests.cs @@ -161,13 +161,11 @@ public static partial class ServicesExtensions Assert.Equal(results.Diagnostics.Single().Descriptor, DiagnosticDescriptors.NoMatchingTypesFound); var expectedFile = """ - using Microsoft.Extensions.DependencyInjection; - namespace GeneratorTests; public static partial class ServicesExtensions { - public static partial IServiceCollection AddServices(this IServiceCollection services) + public static partial global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddServices(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services) { return services; } @@ -203,13 +201,11 @@ public static partial class ServicesExtensions Assert.Equal(results.Diagnostics.Single().Descriptor, DiagnosticDescriptors.NoMatchingTypesFound); var expectedFile = """ - using Microsoft.Extensions.DependencyInjection; - namespace GeneratorTests; public static partial class ServicesExtensions { - public static partial void AddServices(this IServiceCollection services) + public static partial void AddServices(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services) { } diff --git a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FilterTypes.cs b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FilterTypes.cs index 7343bd6..678df2c 100644 --- a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FilterTypes.cs +++ b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FilterTypes.cs @@ -49,6 +49,10 @@ public partial class DependencyInjectionGenerator excludeAssignableToType = excludeAssignableToType.Construct(typeArguments); } + var customHandlerMethod = attribute.CustomHandler != null && attribute.CustomHandlerType == CustomHandlerType.Method + ? containingType.GetMembers().OfType().FirstOrDefault(m => m.Name == attribute.CustomHandler) + : null; + foreach (var type in assemblies.SelectMany(GetTypesFromAssembly)) { if (type.IsAbstract || !type.CanBeReferencedByName || type.TypeKind != TypeKind.Class) @@ -87,6 +91,12 @@ public partial class DependencyInjectionGenerator if (assignableToType != null && !IsAssignableTo(type, assignableToType, out matchedTypes)) continue; + // Filter by custom handler method generic constraints + if (customHandlerMethod != null && !SatisfiesGenericConstraints(type, customHandlerMethod)) + { + continue; + } + if (!semanticModel.IsAccessible(position, type)) continue; @@ -173,19 +183,6 @@ private static IEnumerable GetAssembliesToScan(Compilation comp return [containingType.ContainingAssembly]; } - private static IEnumerable GetSolutionAssemblies(Compilation compilation) - { - yield return compilation.Assembly; - - foreach (var reference in compilation.References) - { - if (reference is CompilationReference) - { - yield return (IAssemblySymbol)compilation.GetAssemblyOrModuleSymbol(reference); - } - } - } - private static IEnumerable GetTypesFromAssembly(IAssemblySymbol assemblySymbol) { var @namespace = assemblySymbol.GlobalNamespace; @@ -218,4 +215,111 @@ static IEnumerable GetTypesFromNamespaceOrType(INamespaceOrTyp ? null : new Regex($"^({Regex.Escape(wildcard).Replace(@"\*", ".*").Replace(",", "|")})$"); } + + private static bool SatisfiesGenericConstraints(INamedTypeSymbol type, IMethodSymbol customHandlerMethod) + { + if (customHandlerMethod.TypeParameters.Length == 0) + return true; + + // Check constraints on the first type parameter (which will be the implementation type) + // (Other type parameters could be checked recursively from the first type parameter) + var typeParameter = customHandlerMethod.TypeParameters[0]; + + var visitedTypeParameters = new HashSet(SymbolEqualityComparer.Default); + return SatisfiesGenericConstraints(type, typeParameter, customHandlerMethod, visitedTypeParameters); + } + + private static bool SatisfiesGenericConstraints(INamedTypeSymbol type, ITypeParameterSymbol typeParameter, IMethodSymbol customHandlerMethod, HashSet visitedTypeParameters) + { + // Prevent infinite recursion in circular constraint scenarios (e.g., X : ISmth, Y : ISmth) + if (!visitedTypeParameters.Add(typeParameter)) + return true; + + // Check reference type constraint + if (typeParameter.HasReferenceTypeConstraint && type.IsValueType) + return false; + + // Check value type constraint + if (typeParameter.HasValueTypeConstraint && !type.IsValueType) + return false; + + // Check unmanaged type constraint + if (typeParameter.HasUnmanagedTypeConstraint && !type.IsUnmanagedType) + return false; + + // Check constructor constraint + if (typeParameter.HasConstructorConstraint) + { + var hasPublicParameterlessConstructor = type.Constructors.Any(c => + c.DeclaredAccessibility == Accessibility.Public && + c.Parameters.Length == 0 && + !c.IsStatic); + + if (!hasPublicParameterlessConstructor) + return false; + } + + // Check type constraints + foreach (var constraintType in typeParameter.ConstraintTypes) + { + if (constraintType is INamedTypeSymbol namedConstraintType) + { + if (!SatisfiesConstraintType(type, namedConstraintType, customHandlerMethod, visitedTypeParameters)) + return false; + } + } + + return true; + } + + private static bool SatisfiesConstraintType(INamedTypeSymbol candidateType, INamedTypeSymbol constraintType, IMethodSymbol customHandlerMethod, HashSet visitedTypeParameters) + { + var constraintHasTypeParameters = constraintType.TypeArguments.OfType().Any(); + + if (!constraintHasTypeParameters) + { + return IsAssignableTo(candidateType, constraintType, out _); + } + else + { + // We handle the case when method has multiple type arguments, e.g. + // private static void CustomHandler(this IServiceCollection services) + // where THandler : class, ICommandHandler + // where TCommand : ISpecificCommand + + + // First we check that type definitions match. E.g. if MyHandlerImplementation has interface (one or many) ICommandHandler<>. + if (!IsAssignableTo(candidateType, constraintType.OriginalDefinition, out var matchedTypes)) + return false; + + // Then we need to check if any matched interfaces (let's say MyHandlerImplementation implements ICommandHandler and ICommandHandler) + // have matching type parameters (e.g. string does not implement ISpecificCommand, but MySpecificCommand - does). + return matchedTypes.Any(matchedType => MatchedTypeSatisfiesConstraints(constraintType, customHandlerMethod, matchedType, visitedTypeParameters)); + } + + static bool MatchedTypeSatisfiesConstraints(INamedTypeSymbol constraintType, IMethodSymbol customHandlerMethod, INamedTypeSymbol matchedType, HashSet visitedTypeParameters) + { + if (constraintType.TypeArguments.Length != matchedType.TypeArguments.Length) + return false; + + for (var i = 0; i < constraintType.TypeArguments.Length; i++) + { + if (matchedType.TypeArguments[i] is not INamedTypeSymbol candidateTypeArgument) + return false; + + if (constraintType.TypeArguments[i] is ITypeParameterSymbol typeParameter) + { + if (!SatisfiesGenericConstraints(candidateTypeArgument, typeParameter, customHandlerMethod, visitedTypeParameters)) + return false; + } + else + { + if (!SymbolEqualityComparer.Default.Equals(candidateTypeArgument, constraintType.TypeArguments[i])) + return false; + } + } + + return true; + } + } } diff --git a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.ParseMethodModel.cs b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.ParseMethodModel.cs index b0b991a..27e774a 100644 --- a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.ParseMethodModel.cs +++ b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.ParseMethodModel.cs @@ -65,15 +65,6 @@ public partial class DependencyInjectionGenerator if (!typesMatch) return Diagnostic.Create(CustomHandlerMethodHasIncorrectSignature, attribute.Location); - - // If CustomHandler has more than 1 type parameters, we try to resolve them from - // matched assignableTo type arguments. - // e.g. ApplyConfiguration(ModelBuilder modelBuilder) where T : IEntityTypeConfiguration - if (customHandlerMethod.TypeParameters.Length > 1 - && customHandlerMethod.TypeParameters.Length != attribute.AssignableToTypeParametersCount + 1) - { - return Diagnostic.Create(CustomHandlerMethodHasIncorrectSignature, attribute.Location); - } } } diff --git a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.cs b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.cs index 8d62c0f..f2970f4 100644 --- a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.cs +++ b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.cs @@ -40,9 +40,9 @@ public void Initialize(IncrementalGeneratorInitializationContext context) return; var (method, registrations, customHandling) = src.Model; - string source = customHandling.Count > 0 - ? GenerateCustomHandlingSource(method, customHandling) - : GenerateRegistrationsSource(method, registrations); + string source = registrations.Count > 0 + ? GenerateRegistrationsSource(method, registrations) + : GenerateCustomHandlingSource(method, customHandling); source = source.ReplaceLineEndings(); diff --git a/version.json b/version.json index 6ed9d64..c629a25 100644 --- a/version.json +++ b/version.json @@ -1,6 +1,6 @@ { "$schema": "https://raw.githubusercontent.com/dotnet/Nerdbank.GitVersioning/main/src/NerdBank.GitVersioning/version.schema.json", - "version": "2.3", + "version": "2.4", "publicReleaseRefSpec": [ "^refs/heads/main", "^refs/heads/v\\d+(?:\\.\\d+)?$"