Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 57 additions & 1 deletion ServiceScan.SourceGenerator.Tests/AddServicesTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,62 @@ public class MyStringService : IService<string> { }
Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString());
}

[Fact]
public void AddServicesAssignableToOpenGenericInterface_WithMultipleInterfaces()
{
var attribute = $"[GenerateServiceRegistrations(AssignableTo = typeof(IService<>))]";

var compilation = CreateCompilation(
Sources.MethodWithAttribute(attribute),
"""
namespace GeneratorTests;

public interface IService<T> { }
public interface IOtherInterface { }
public class MyIntAndStringService : IService<int>, IService<string>, IOtherInterface { }
""");

var results = CSharpGeneratorDriver
.Create(_generator)
.RunGenerators(compilation)
.GetRunResult();

var registrations = $"""
return services
.AddTransient<global::GeneratorTests.IService<int>, global::GeneratorTests.MyIntAndStringService>()
.AddTransient<global::GeneratorTests.IService<string>, global::GeneratorTests.MyIntAndStringService>();
""";
Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString());
}

[Fact]
public void AddServicesAssignableToOpenGenericInterface_WithMultipleInterfaces_AsSelfAndAsImplementedInterfaces()
{
var attribute = $"[GenerateServiceRegistrations(AssignableTo = typeof(IService<>), AsSelf = true, AsImplementedInterfaces = true, Lifetime = ServiceLifetime.Singleton)]";

var compilation = CreateCompilation(
Sources.MethodWithAttribute(attribute),
"""
namespace GeneratorTests;

public interface IService<T> { }
public class MyIntAndStringService : IService<int>, IService<string> { }
""");

var results = CSharpGeneratorDriver
.Create(_generator)
.RunGenerators(compilation)
.GetRunResult();

var registrations = $"""
return services
.AddSingleton<global::GeneratorTests.MyIntAndStringService, global::GeneratorTests.MyIntAndStringService>()
.AddSingleton<global::GeneratorTests.IService<int>>(s => s.GetRequiredService<global::GeneratorTests.MyIntAndStringService>())
.AddSingleton<global::GeneratorTests.IService<string>>(s => s.GetRequiredService<global::GeneratorTests.MyIntAndStringService>());
""";
Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString());
}

[Fact]
public void AddServicesAssignableToClosedGenericInterface()
{
Expand Down Expand Up @@ -808,7 +864,7 @@ public class InterfacelessService {}
}

[Fact]
public void AddServicesBothAsSelfAndAsImplementedInterfaces()
public void AddServicesBothAsSelfAndAsImplementedInterfaces_InterfacesAreForwardedToSelfRegistration()
{
var attribute = """
[GenerateServiceRegistrations(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace ServiceScan.SourceGenerator;

public partial class DependencyInjectionGenerator
{
private static IEnumerable<(INamedTypeSymbol Type, INamedTypeSymbol? MatchedAssignableType)> FilterTypes
private static IEnumerable<(INamedTypeSymbol Type, INamedTypeSymbol[]? MatchedAssignableTypes)> FilterTypes
(Compilation compilation, AttributeModel attribute, INamedTypeSymbol containingType)
{
var assemblies = GetAssembliesToScan(compilation, attribute, containingType);
Expand Down Expand Up @@ -72,37 +72,39 @@ public partial class DependencyInjectionGenerator
if (excludeAssignableToType != null && IsAssignableTo(type, excludeAssignableToType, out _))
continue;

INamedTypeSymbol matchedType = null;
if (assignableToType != null && !IsAssignableTo(type, assignableToType, out matchedType))
INamedTypeSymbol[] matchedTypes = null;
if (assignableToType != null && !IsAssignableTo(type, assignableToType, out matchedTypes))
continue;

yield return (type, matchedType);
yield return (type, matchedTypes);
}
}

private static bool IsAssignableTo(INamedTypeSymbol type, INamedTypeSymbol assignableTo, out INamedTypeSymbol? matchedType)
private static bool IsAssignableTo(INamedTypeSymbol type, INamedTypeSymbol assignableTo, out INamedTypeSymbol[]? matchedTypes)
{
if (SymbolEqualityComparer.Default.Equals(type, assignableTo))
{
matchedType = type;
matchedTypes = [type];
return true;
}

if (assignableTo.IsGenericType && assignableTo.IsDefinition)
{
if (assignableTo.TypeKind == TypeKind.Interface)
{
var matchingInterface = type.AllInterfaces.FirstOrDefault(i => i.IsGenericType && SymbolEqualityComparer.Default.Equals(i.OriginalDefinition, assignableTo));
matchedType = matchingInterface;
return matchingInterface != null;
matchedTypes = type.AllInterfaces
.Where(i => i.IsGenericType && SymbolEqualityComparer.Default.Equals(i.OriginalDefinition, assignableTo))
.ToArray();

return matchedTypes.Length > 0;
}

var baseType = type.BaseType;
while (baseType != null)
{
if (baseType.IsGenericType && SymbolEqualityComparer.Default.Equals(baseType.OriginalDefinition, assignableTo))
{
matchedType = baseType;
matchedTypes = [baseType];
return true;
}

Expand All @@ -113,7 +115,7 @@ private static bool IsAssignableTo(INamedTypeSymbol type, INamedTypeSymbol assig
{
if (assignableTo.TypeKind == TypeKind.Interface)
{
matchedType = assignableTo;
matchedTypes = [assignableTo];
return type.AllInterfaces.Contains(assignableTo, SymbolEqualityComparer.Default);
}

Expand All @@ -122,15 +124,15 @@ private static bool IsAssignableTo(INamedTypeSymbol type, INamedTypeSymbol assig
{
if (SymbolEqualityComparer.Default.Equals(baseType, assignableTo))
{
matchedType = baseType;
matchedTypes = [baseType];
return true;
}

baseType = baseType.BaseType;
}
}

matchedType = null;
matchedTypes = null;
return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ private static DiagnosticModel<MethodImplementationModel> FindServicesToRegister
{
bool typesFound = false;

foreach (var (implementationType, matchedType) in FilterTypes(compilation, attribute, containingType))
foreach (var (implementationType, matchedTypes) in FilterTypes(compilation, attribute, containingType))
{
typesFound = true;

Expand All @@ -46,7 +46,7 @@ private static DiagnosticModel<MethodImplementationModel> FindServicesToRegister
(true, true) => new[] { implementationType }.Concat(GetSuitableInterfaces(implementationType)),
(false, true) => GetSuitableInterfaces(implementationType),
(true, false) => [implementationType],
_ => [matchedType ?? implementationType]
_ => matchedTypes ?? [implementationType]
};

foreach (var serviceType in serviceTypes)
Expand Down
Loading