Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -164,4 +164,4 @@ public static partial class ModelBuilderExtensions
| **ExcludeByTypeName** | Sets this value to exclude types from being registered by their full name. You can use '*' wildcards. You can also use ',' to separate multiple filters. |
| **ExcludeByAttribute** | Excludes matching types by the specified attribute type being present. |
| **KeySelector** | Sets this property to add types as keyed services. This property should point to one of the following: <br>- The name of a static method in the current type with a string return type. The method should be either generic or have a single parameter of type `Type`. <br>- A constant field or static property in the implementation type. |
| **CustomHandler** | Sets this property to invoke a custom method for each type found instead of regular registration logic. This property should point to one of the following: <br>- Name of a generic method in the current type. <br>- Static method name in found types. <br>This property is incompatible with `Lifetime`, `AsImplementedInterfaces`, `AsSelf`, and `KeySelector` properties. |
| **CustomHandler** | Sets this property to invoke a custom method for each type found instead of regular registration logic. This property should point to one of the following: <br>- Name of a generic method in the current type. <br>- Static method name in found types. <br>This property is incompatible with `Lifetime`, `AsImplementedInterfaces`, `AsSelf`, and `KeySelector` properties. <br>**Note:** When using a generic `CustomHandler` method, types are automatically filtered by the generic constraints defined on the method's type parameters (e.g., `class`, `struct`, `new()`, interface constraints). |
323 changes: 323 additions & 0 deletions ServiceScan.SourceGenerator.Tests/CustomHandlerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,329 @@ public static partial class ServiceCollectionExtensions
Assert.Equal(expected, results.GeneratedTrees[1].ToString());
}

[Fact]
public void CustomHandler_FiltersByNewConstraint()
{
var source = """
using ServiceScan.SourceGenerator;

namespace GeneratorTests;

public static partial class ServicesExtensions
{
[GenerateServiceRegistrations(AssignableTo = typeof(IService), CustomHandler = nameof(HandleType))]
public static partial void ProcessServices();

private static void HandleType<T>() where T : IService, new() => System.Console.WriteLine(typeof(T).Name);
}
""";

var services = """
namespace GeneratorTests;

public interface IService { }
public class ServiceWithParameterlessConstructor : IService { }
public class ServiceWithoutParameterlessConstructor : IService
{
public ServiceWithoutParameterlessConstructor(int value) { }
}
public class ServiceWithPrivateConstructor : IService
{
private ServiceWithPrivateConstructor() { }
}
""";

var compilation = CreateCompilation(source, services);

var results = CSharpGeneratorDriver
.Create(_generator)
.RunGenerators(compilation)
.GetRunResult();

var expected = """
namespace GeneratorTests;

public static partial class ServicesExtensions
{
public static partial void ProcessServices()
{
HandleType<global::GeneratorTests.ServiceWithParameterlessConstructor>();
}
}
""";
Assert.Equal(expected, results.GeneratedTrees[1].ToString());
}

[Fact]
public void CustomHandler_FiltersByClassConstraint()
{
var source = """
using ServiceScan.SourceGenerator;

namespace GeneratorTests;

public static partial class ServicesExtensions
{
[GenerateServiceRegistrations(TypeNameFilter = "*Service", CustomHandler = nameof(HandleType))]
public static partial void ProcessServices();

private static void HandleType<T>() where T : class => System.Console.WriteLine(typeof(T).Name);
}
""";

var services = """
namespace GeneratorTests;

public class ClassService { }
public struct StructService { }
""";

var compilation = CreateCompilation(source, services);

var results = CSharpGeneratorDriver
.Create(_generator)
.RunGenerators(compilation)
.GetRunResult();

var expected = """
namespace GeneratorTests;

public static partial class ServicesExtensions
{
public static partial void ProcessServices()
{
HandleType<global::GeneratorTests.ClassService>();
}
}
""";
Assert.Equal(expected, results.GeneratedTrees[1].ToString());
}

[Fact]
public void CustomHandler_FiltersByNestedTypeParameterConstraints()
{
var source = """
using ServiceScan.SourceGenerator;

namespace GeneratorTests;

public static partial class ServiceCollectionExtensions
{
[GenerateServiceRegistrations(AssignableTo = typeof(ICommandHandler<>), CustomHandler = nameof(AddHandler))]
public static partial void AddHandlers();

private static void AddHandler<THandler, TCommand>()
where THandler : class, ICommandHandler<TCommand>
where TCommand : class, ICommand
{
}
}
""";

var services = """
namespace GeneratorTests;

public interface ICommand { }
public interface ICommandHandler<T> where T : ICommand { }

public class ValidCommand : ICommand { }
public class InvalidCommand { }

public class ValidHandler : ICommandHandler<ValidCommand> { }
public class InvalidHandler : ICommandHandler<InvalidCommand> { }
""";

var compilation = CreateCompilation(source, services);

var results = CSharpGeneratorDriver
.Create(_generator)
.RunGenerators(compilation)
.GetRunResult();

var expected = """
namespace GeneratorTests;

public static partial class ServiceCollectionExtensions
{
public static partial void AddHandlers()
{
AddHandler<global::GeneratorTests.ValidHandler, global::GeneratorTests.ValidCommand>();
}
}
""";
Assert.Equal(expected, results.GeneratedTrees[1].ToString());
}

[Fact]
public void CustomHandler_FiltersByMultipleInterfacesWithDifferentTypeArguments()
{
var source = """
using ServiceScan.SourceGenerator;

namespace GeneratorTests;

public static partial class ServiceCollectionExtensions
{
[GenerateServiceRegistrations(AssignableTo = typeof(IHandler<>), CustomHandler = nameof(AddHandler))]
public static partial void AddHandlers();

private static void AddHandler<THandler, TArg>()
where THandler : class, IHandler<TArg>
where TArg : class
{
}
}
""";

var services = """
namespace GeneratorTests;

public interface IHandler<T> { }

public class Handler1 : IHandler<string> { }
public class Handler2 : IHandler<object> { }
public class Handler3 : IHandler<int> { }
public class MultiHandler : IHandler<string>, IHandler<object> { }
""";

var compilation = CreateCompilation(source, services);

var results = CSharpGeneratorDriver
.Create(_generator)
.RunGenerators(compilation)
.GetRunResult();

var expected = """
namespace GeneratorTests;

public static partial class ServiceCollectionExtensions
{
public static partial void AddHandlers()
{
AddHandler<global::GeneratorTests.Handler1, string>();
AddHandler<global::GeneratorTests.Handler2, object>();
AddHandler<global::GeneratorTests.MultiHandler, string>();
AddHandler<global::GeneratorTests.MultiHandler, object>();
}
}
""";
Assert.Equal(expected, results.GeneratedTrees[1].ToString());
}

[Fact]
public void CustomHandler_FiltersByValueTypeConstraint()
{
var source = """
using ServiceScan.SourceGenerator;

namespace GeneratorTests;

public static partial class ServiceCollectionExtensions
{
[GenerateServiceRegistrations(AssignableTo = typeof(IProcessor<>), CustomHandler = nameof(AddProcessor))]
public static partial void AddProcessors();

private static void AddProcessor<TProcessor, TValue>()
where TProcessor : class, IProcessor<TValue>
where TValue : struct
{
}
}
""";

var services = """
namespace GeneratorTests;

public interface IProcessor<T> { }

public class IntProcessor : IProcessor<int> { }
public class StringProcessor : IProcessor<string> { }
public class GuidProcessor : IProcessor<System.Guid> { }
""";

var compilation = CreateCompilation(source, services);

var results = CSharpGeneratorDriver
.Create(_generator)
.RunGenerators(compilation)
.GetRunResult();

var expected = """
namespace GeneratorTests;

public static partial class ServiceCollectionExtensions
{
public static partial void AddProcessors()
{
AddProcessor<global::GeneratorTests.IntProcessor, int>();
AddProcessor<global::GeneratorTests.GuidProcessor, global::System.Guid>();
}
}
""";
Assert.Equal(expected, results.GeneratedTrees[1].ToString());
}

[Fact]
public void CustomHandler_CombinedConstraints()
{
var source = """
using ServiceScan.SourceGenerator;

namespace GeneratorTests;

public interface IConfigurable { }

public static partial class ServiceCollectionExtensions
{
[GenerateServiceRegistrations(AssignableTo = typeof(IHandler<>), CustomHandler = nameof(AddHandler))]
public static partial void AddHandlers();

private static void AddHandler<THandler, TArg>()
where THandler : class, IHandler<TArg>, IConfigurable, new()
where TArg : class, new()
{
}
}
""";

var services = """
namespace GeneratorTests;

public interface IHandler<T> { }

public class Arg1 { }
public class Arg2 { public Arg2(int x) { } }

public class ValidHandler : IHandler<Arg1>, IConfigurable { }
public class HandlerWithoutConfigurable : IHandler<Arg1> { }
public class HandlerWithoutConstructor : IHandler<Arg1>, IConfigurable
{
public HandlerWithoutConstructor(int x) { }
}
public class HandlerWithNonConstructibleArg : IHandler<Arg2>, IConfigurable { }
""";

var compilation = CreateCompilation(source, services);

var results = CSharpGeneratorDriver
.Create(_generator)
.RunGenerators(compilation)
.GetRunResult();

var expected = """
namespace GeneratorTests;

public static partial class ServiceCollectionExtensions
{
public static partial void AddHandlers()
{
AddHandler<global::GeneratorTests.ValidHandler, global::GeneratorTests.Arg1>();
}
}
""";
Assert.Equal(expected, results.GeneratedTrees[1].ToString());
}

private static Compilation CreateCompilation(params string[] source)
{
var path = Path.GetDirectoryName(typeof(object).Assembly.Location)!;
Expand Down
Loading
Loading