Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions ServiceScan.SourceGenerator.Tests/CustomHandlerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,66 @@ public partial void ProcessServices( global::Microsoft.Extensions.DependencyInje
Assert.Equal(expected, results.GeneratedTrees[1].ToString());
}

[Fact]
public void AddServicesWithDecorator()
{
var services = """
namespace GeneratorTests;

public interface ICommandHandler<T> { }
public class CommandHandlerDecorator<T>(ICommandHandler<T> inner) : ICommandHandler<T>;

public class SpecificHandler1 : ICommandHandler<string>;
public class SpecificHandler2 : ICommandHandler<long>;
""";

var source = """
using ServiceScan.SourceGenerator;
using Microsoft.Extensions.DependencyInjection;

namespace GeneratorTests;

public static partial class ServiceCollectionExtensions
{
[GenerateServiceRegistrations(AssignableTo = typeof(ICommandHandler<>), CustomHandler = nameof(AddDecoratedHandler))]
public static partial IServiceCollection AddHandlers(this IServiceCollection services);

private static void AddDecoratedHandler<THandler, TCommand>(this IServiceCollection services)
where THandler : class, ICommandHandler<TCommand>
{
// Add handler itself to DI
services.AddScoped<THandler>();

// Register decorated handler as ICommandHandler
services.AddScoped<ICommandHandler<TCommand>>(s => new CommandHandlerDecorator<TCommand>(s.GetRequiredService<THandler>()));
}
}
""";


var compilation = CreateCompilation(source, services);

var results = CSharpGeneratorDriver
.Create(_generator)
.RunGenerators(compilation)
.GetRunResult();

var expected = $$"""
namespace GeneratorTests;

public static partial class ServiceCollectionExtensions
{
public static partial global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddHandlers(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services)
{
AddDecoratedHandler<global::GeneratorTests.SpecificHandler1, string>(services);
AddDecoratedHandler<global::GeneratorTests.SpecificHandler2, long>(services);
return services;
}
}
""";
Assert.Equal(expected, results.GeneratedTrees[1].ToString());
}

private static Compilation CreateCompilation(params string[] source)
{
var path = Path.GetDirectoryName(typeof(object).Assembly.Location)!;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ public partial class DependencyInjectionGenerator
if (type.IsStatic && attribute.CustomHandlerType != CustomHandlerType.TypeMethod)
continue;

// Cannot use open generics with CustomHandler
if (type.IsGenericType && attribute.CustomHandler != null)
continue;

if (attributeFilterType != null)
{
if (!type.GetAttributes().Any(a => SymbolEqualityComparer.Default.Equals(a.AttributeClass, attributeFilterType)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ .. matchedType.TypeArguments.Select(a => a.ToDisplayString(SymbolDisplayFormat.F
attribute.Lifetime,
serviceTypeName,
implementationTypeName,
false,
true,
ResolveImplementation: false,
IsOpenGeneric: true,
attribute.KeySelector,
attribute.KeySelectorType);

Expand All @@ -105,7 +105,7 @@ .. matchedType.TypeArguments.Select(a => a.ToDisplayString(SymbolDisplayFormat.F
serviceType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat),
implementationType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat),
shouldResolve,
false,
IsOpenGeneric: false,
attribute.KeySelector,
attribute.KeySelectorType);

Expand Down
Loading