diff --git a/ServiceScan.SourceGenerator.Tests/AddServicesTests.cs b/ServiceScan.SourceGenerator.Tests/AddServicesTests.cs index 8014d2d..a9872d8 100644 --- a/ServiceScan.SourceGenerator.Tests/AddServicesTests.cs +++ b/ServiceScan.SourceGenerator.Tests/AddServicesTests.cs @@ -187,6 +187,62 @@ public class MyStringService : IService { } Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); } + [Fact] + public void AddServicesAssignableToOpenGenericInterface_WithMultipleInterfaces() + { + var attribute = $"[GenerateServiceRegistrations(AssignableTo = typeof(IService<>))]"; + + var compilation = CreateCompilation( + Sources.MethodWithAttribute(attribute), + """ + namespace GeneratorTests; + + public interface IService { } + public interface IOtherInterface { } + public class MyIntAndStringService : IService, IService, IOtherInterface { } + """); + + var results = CSharpGeneratorDriver + .Create(_generator) + .RunGenerators(compilation) + .GetRunResult(); + + var registrations = $""" + return services + .AddTransient, global::GeneratorTests.MyIntAndStringService>() + .AddTransient, global::GeneratorTests.MyIntAndStringService>(); + """; + Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); + } + + [Fact] + public void AddServicesAssignableToOpenGenericInterface_WithMultipleInterfaces_AsSelfAndAsImplementedInterfaces() + { + var attribute = $"[GenerateServiceRegistrations(AssignableTo = typeof(IService<>), AsSelf = true, AsImplementedInterfaces = true, Lifetime = ServiceLifetime.Singleton)]"; + + var compilation = CreateCompilation( + Sources.MethodWithAttribute(attribute), + """ + namespace GeneratorTests; + + public interface IService { } + public class MyIntAndStringService : IService, IService { } + """); + + var results = CSharpGeneratorDriver + .Create(_generator) + .RunGenerators(compilation) + .GetRunResult(); + + var registrations = $""" + return services + .AddSingleton() + .AddSingleton>(s => s.GetRequiredService()) + .AddSingleton>(s => s.GetRequiredService()); + """; + Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); + } + [Fact] public void AddServicesAssignableToClosedGenericInterface() { @@ -808,7 +864,7 @@ public class InterfacelessService {} } [Fact] - public void AddServicesBothAsSelfAndAsImplementedInterfaces() + public void AddServicesBothAsSelfAndAsImplementedInterfaces_InterfacesAreForwardedToSelfRegistration() { var attribute = """ [GenerateServiceRegistrations( diff --git a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FilterTypes.cs b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FilterTypes.cs index f2a79a6..65960a0 100644 --- a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FilterTypes.cs +++ b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FilterTypes.cs @@ -10,7 +10,7 @@ namespace ServiceScan.SourceGenerator; public partial class DependencyInjectionGenerator { - private static IEnumerable<(INamedTypeSymbol Type, INamedTypeSymbol? MatchedAssignableType)> FilterTypes + private static IEnumerable<(INamedTypeSymbol Type, INamedTypeSymbol[]? MatchedAssignableTypes)> FilterTypes (Compilation compilation, AttributeModel attribute, INamedTypeSymbol containingType) { var assemblies = GetAssembliesToScan(compilation, attribute, containingType); @@ -72,19 +72,19 @@ public partial class DependencyInjectionGenerator if (excludeAssignableToType != null && IsAssignableTo(type, excludeAssignableToType, out _)) continue; - INamedTypeSymbol matchedType = null; - if (assignableToType != null && !IsAssignableTo(type, assignableToType, out matchedType)) + INamedTypeSymbol[] matchedTypes = null; + if (assignableToType != null && !IsAssignableTo(type, assignableToType, out matchedTypes)) continue; - yield return (type, matchedType); + yield return (type, matchedTypes); } } - private static bool IsAssignableTo(INamedTypeSymbol type, INamedTypeSymbol assignableTo, out INamedTypeSymbol? matchedType) + private static bool IsAssignableTo(INamedTypeSymbol type, INamedTypeSymbol assignableTo, out INamedTypeSymbol[]? matchedTypes) { if (SymbolEqualityComparer.Default.Equals(type, assignableTo)) { - matchedType = type; + matchedTypes = [type]; return true; } @@ -92,9 +92,11 @@ private static bool IsAssignableTo(INamedTypeSymbol type, INamedTypeSymbol assig { if (assignableTo.TypeKind == TypeKind.Interface) { - var matchingInterface = type.AllInterfaces.FirstOrDefault(i => i.IsGenericType && SymbolEqualityComparer.Default.Equals(i.OriginalDefinition, assignableTo)); - matchedType = matchingInterface; - return matchingInterface != null; + matchedTypes = type.AllInterfaces + .Where(i => i.IsGenericType && SymbolEqualityComparer.Default.Equals(i.OriginalDefinition, assignableTo)) + .ToArray(); + + return matchedTypes.Length > 0; } var baseType = type.BaseType; @@ -102,7 +104,7 @@ private static bool IsAssignableTo(INamedTypeSymbol type, INamedTypeSymbol assig { if (baseType.IsGenericType && SymbolEqualityComparer.Default.Equals(baseType.OriginalDefinition, assignableTo)) { - matchedType = baseType; + matchedTypes = [baseType]; return true; } @@ -113,7 +115,7 @@ private static bool IsAssignableTo(INamedTypeSymbol type, INamedTypeSymbol assig { if (assignableTo.TypeKind == TypeKind.Interface) { - matchedType = assignableTo; + matchedTypes = [assignableTo]; return type.AllInterfaces.Contains(assignableTo, SymbolEqualityComparer.Default); } @@ -122,7 +124,7 @@ private static bool IsAssignableTo(INamedTypeSymbol type, INamedTypeSymbol assig { if (SymbolEqualityComparer.Default.Equals(baseType, assignableTo)) { - matchedType = baseType; + matchedTypes = [baseType]; return true; } @@ -130,7 +132,7 @@ private static bool IsAssignableTo(INamedTypeSymbol type, INamedTypeSymbol assig } } - matchedType = null; + matchedTypes = null; return false; } diff --git a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FindServicesToRegister.cs b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FindServicesToRegister.cs index b0eb81c..c47cfa0 100644 --- a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FindServicesToRegister.cs +++ b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FindServicesToRegister.cs @@ -31,7 +31,7 @@ private static DiagnosticModel FindServicesToRegister { bool typesFound = false; - foreach (var (implementationType, matchedType) in FilterTypes(compilation, attribute, containingType)) + foreach (var (implementationType, matchedTypes) in FilterTypes(compilation, attribute, containingType)) { typesFound = true; @@ -46,7 +46,7 @@ private static DiagnosticModel FindServicesToRegister (true, true) => new[] { implementationType }.Concat(GetSuitableInterfaces(implementationType)), (false, true) => GetSuitableInterfaces(implementationType), (true, false) => [implementationType], - _ => [matchedType ?? implementationType] + _ => matchedTypes ?? [implementationType] }; foreach (var serviceType in serviceTypes)