diff --git a/ServiceScan.SourceGenerator.Tests/CustomHandlerTests.cs b/ServiceScan.SourceGenerator.Tests/CustomHandlerTests.cs index e3aae21..01bdd1f 100644 --- a/ServiceScan.SourceGenerator.Tests/CustomHandlerTests.cs +++ b/ServiceScan.SourceGenerator.Tests/CustomHandlerTests.cs @@ -603,6 +603,66 @@ public partial void ProcessServices( global::Microsoft.Extensions.DependencyInje Assert.Equal(expected, results.GeneratedTrees[1].ToString()); } + [Fact] + public void AddServicesWithDecorator() + { + var services = """ + namespace GeneratorTests; + + public interface ICommandHandler { } + public class CommandHandlerDecorator(ICommandHandler inner) : ICommandHandler; + + public class SpecificHandler1 : ICommandHandler; + public class SpecificHandler2 : ICommandHandler; + """; + + var source = """ + using ServiceScan.SourceGenerator; + using Microsoft.Extensions.DependencyInjection; + + namespace GeneratorTests; + + public static partial class ServiceCollectionExtensions + { + [GenerateServiceRegistrations(AssignableTo = typeof(ICommandHandler<>), CustomHandler = nameof(AddDecoratedHandler))] + public static partial IServiceCollection AddHandlers(this IServiceCollection services); + + private static void AddDecoratedHandler(this IServiceCollection services) + where THandler : class, ICommandHandler + { + // Add handler itself to DI + services.AddScoped(); + + // Register decorated handler as ICommandHandler + services.AddScoped>(s => new CommandHandlerDecorator(s.GetRequiredService())); + } + } + """; + + + var compilation = CreateCompilation(source, services); + + var results = CSharpGeneratorDriver + .Create(_generator) + .RunGenerators(compilation) + .GetRunResult(); + + var expected = $$""" + namespace GeneratorTests; + + public static partial class ServiceCollectionExtensions + { + public static partial global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddHandlers(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services) + { + AddDecoratedHandler(services); + AddDecoratedHandler(services); + return services; + } + } + """; + Assert.Equal(expected, results.GeneratedTrees[1].ToString()); + } + private static Compilation CreateCompilation(params string[] source) { var path = Path.GetDirectoryName(typeof(object).Assembly.Location)!; diff --git a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FilterTypes.cs b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FilterTypes.cs index 33f8ac3..7343bd6 100644 --- a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FilterTypes.cs +++ b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FilterTypes.cs @@ -58,6 +58,10 @@ public partial class DependencyInjectionGenerator if (type.IsStatic && attribute.CustomHandlerType != CustomHandlerType.TypeMethod) continue; + // Cannot use open generics with CustomHandler + if (type.IsGenericType && attribute.CustomHandler != null) + continue; + if (attributeFilterType != null) { if (!type.GetAttributes().Any(a => SymbolEqualityComparer.Default.Equals(a.AttributeClass, attributeFilterType))) diff --git a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FindServicesToRegister.cs b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FindServicesToRegister.cs index 3adfe1f..2e2da26 100644 --- a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FindServicesToRegister.cs +++ b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FindServicesToRegister.cs @@ -90,8 +90,8 @@ .. matchedType.TypeArguments.Select(a => a.ToDisplayString(SymbolDisplayFormat.F attribute.Lifetime, serviceTypeName, implementationTypeName, - false, - true, + ResolveImplementation: false, + IsOpenGeneric: true, attribute.KeySelector, attribute.KeySelectorType); @@ -105,7 +105,7 @@ .. matchedType.TypeArguments.Select(a => a.ToDisplayString(SymbolDisplayFormat.F serviceType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), implementationType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), shouldResolve, - false, + IsOpenGeneric: false, attribute.KeySelector, attribute.KeySelectorType);