Skip to content

Commit 1aa340d

Browse files
authored
Add multiple registrations if service implements open generic interface more than once (#36)
1 parent 1afd16b commit 1aa340d

File tree

3 files changed

+74
-16
lines changed

3 files changed

+74
-16
lines changed

ServiceScan.SourceGenerator.Tests/AddServicesTests.cs

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,62 @@ public class MyStringService : IService<string> { }
187187
Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString());
188188
}
189189

190+
[Fact]
191+
public void AddServicesAssignableToOpenGenericInterface_WithMultipleInterfaces()
192+
{
193+
var attribute = $"[GenerateServiceRegistrations(AssignableTo = typeof(IService<>))]";
194+
195+
var compilation = CreateCompilation(
196+
Sources.MethodWithAttribute(attribute),
197+
"""
198+
namespace GeneratorTests;
199+
200+
public interface IService<T> { }
201+
public interface IOtherInterface { }
202+
public class MyIntAndStringService : IService<int>, IService<string>, IOtherInterface { }
203+
""");
204+
205+
var results = CSharpGeneratorDriver
206+
.Create(_generator)
207+
.RunGenerators(compilation)
208+
.GetRunResult();
209+
210+
var registrations = $"""
211+
return services
212+
.AddTransient<global::GeneratorTests.IService<int>, global::GeneratorTests.MyIntAndStringService>()
213+
.AddTransient<global::GeneratorTests.IService<string>, global::GeneratorTests.MyIntAndStringService>();
214+
""";
215+
Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString());
216+
}
217+
218+
[Fact]
219+
public void AddServicesAssignableToOpenGenericInterface_WithMultipleInterfaces_AsSelfAndAsImplementedInterfaces()
220+
{
221+
var attribute = $"[GenerateServiceRegistrations(AssignableTo = typeof(IService<>), AsSelf = true, AsImplementedInterfaces = true, Lifetime = ServiceLifetime.Singleton)]";
222+
223+
var compilation = CreateCompilation(
224+
Sources.MethodWithAttribute(attribute),
225+
"""
226+
namespace GeneratorTests;
227+
228+
public interface IService<T> { }
229+
public class MyIntAndStringService : IService<int>, IService<string> { }
230+
""");
231+
232+
var results = CSharpGeneratorDriver
233+
.Create(_generator)
234+
.RunGenerators(compilation)
235+
.GetRunResult();
236+
237+
var registrations = $"""
238+
return services
239+
.AddSingleton<global::GeneratorTests.MyIntAndStringService, global::GeneratorTests.MyIntAndStringService>()
240+
.AddSingleton<global::GeneratorTests.IService<int>>(s => s.GetRequiredService<global::GeneratorTests.MyIntAndStringService>())
241+
.AddSingleton<global::GeneratorTests.IService<string>>(s => s.GetRequiredService<global::GeneratorTests.MyIntAndStringService>());
242+
""";
243+
Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString());
244+
}
245+
190246
[Fact]
191247
public void AddServicesAssignableToClosedGenericInterface()
192248
{
@@ -808,7 +864,7 @@ public class InterfacelessService {}
808864
}
809865

810866
[Fact]
811-
public void AddServicesBothAsSelfAndAsImplementedInterfaces()
867+
public void AddServicesBothAsSelfAndAsImplementedInterfaces_InterfacesAreForwardedToSelfRegistration()
812868
{
813869
var attribute = """
814870
[GenerateServiceRegistrations(

ServiceScan.SourceGenerator/DependencyInjectionGenerator.FilterTypes.cs

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ namespace ServiceScan.SourceGenerator;
1010

1111
public partial class DependencyInjectionGenerator
1212
{
13-
private static IEnumerable<(INamedTypeSymbol Type, INamedTypeSymbol? MatchedAssignableType)> FilterTypes
13+
private static IEnumerable<(INamedTypeSymbol Type, INamedTypeSymbol[]? MatchedAssignableTypes)> FilterTypes
1414
(Compilation compilation, AttributeModel attribute, INamedTypeSymbol containingType)
1515
{
1616
var assemblies = GetAssembliesToScan(compilation, attribute, containingType);
@@ -72,37 +72,39 @@ public partial class DependencyInjectionGenerator
7272
if (excludeAssignableToType != null && IsAssignableTo(type, excludeAssignableToType, out _))
7373
continue;
7474

75-
INamedTypeSymbol matchedType = null;
76-
if (assignableToType != null && !IsAssignableTo(type, assignableToType, out matchedType))
75+
INamedTypeSymbol[] matchedTypes = null;
76+
if (assignableToType != null && !IsAssignableTo(type, assignableToType, out matchedTypes))
7777
continue;
7878

79-
yield return (type, matchedType);
79+
yield return (type, matchedTypes);
8080
}
8181
}
8282

83-
private static bool IsAssignableTo(INamedTypeSymbol type, INamedTypeSymbol assignableTo, out INamedTypeSymbol? matchedType)
83+
private static bool IsAssignableTo(INamedTypeSymbol type, INamedTypeSymbol assignableTo, out INamedTypeSymbol[]? matchedTypes)
8484
{
8585
if (SymbolEqualityComparer.Default.Equals(type, assignableTo))
8686
{
87-
matchedType = type;
87+
matchedTypes = [type];
8888
return true;
8989
}
9090

9191
if (assignableTo.IsGenericType && assignableTo.IsDefinition)
9292
{
9393
if (assignableTo.TypeKind == TypeKind.Interface)
9494
{
95-
var matchingInterface = type.AllInterfaces.FirstOrDefault(i => i.IsGenericType && SymbolEqualityComparer.Default.Equals(i.OriginalDefinition, assignableTo));
96-
matchedType = matchingInterface;
97-
return matchingInterface != null;
95+
matchedTypes = type.AllInterfaces
96+
.Where(i => i.IsGenericType && SymbolEqualityComparer.Default.Equals(i.OriginalDefinition, assignableTo))
97+
.ToArray();
98+
99+
return matchedTypes.Length > 0;
98100
}
99101

100102
var baseType = type.BaseType;
101103
while (baseType != null)
102104
{
103105
if (baseType.IsGenericType && SymbolEqualityComparer.Default.Equals(baseType.OriginalDefinition, assignableTo))
104106
{
105-
matchedType = baseType;
107+
matchedTypes = [baseType];
106108
return true;
107109
}
108110

@@ -113,7 +115,7 @@ private static bool IsAssignableTo(INamedTypeSymbol type, INamedTypeSymbol assig
113115
{
114116
if (assignableTo.TypeKind == TypeKind.Interface)
115117
{
116-
matchedType = assignableTo;
118+
matchedTypes = [assignableTo];
117119
return type.AllInterfaces.Contains(assignableTo, SymbolEqualityComparer.Default);
118120
}
119121

@@ -122,15 +124,15 @@ private static bool IsAssignableTo(INamedTypeSymbol type, INamedTypeSymbol assig
122124
{
123125
if (SymbolEqualityComparer.Default.Equals(baseType, assignableTo))
124126
{
125-
matchedType = baseType;
127+
matchedTypes = [baseType];
126128
return true;
127129
}
128130

129131
baseType = baseType.BaseType;
130132
}
131133
}
132134

133-
matchedType = null;
135+
matchedTypes = null;
134136
return false;
135137
}
136138

ServiceScan.SourceGenerator/DependencyInjectionGenerator.FindServicesToRegister.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ private static DiagnosticModel<MethodImplementationModel> FindServicesToRegister
3131
{
3232
bool typesFound = false;
3333

34-
foreach (var (implementationType, matchedType) in FilterTypes(compilation, attribute, containingType))
34+
foreach (var (implementationType, matchedTypes) in FilterTypes(compilation, attribute, containingType))
3535
{
3636
typesFound = true;
3737

@@ -46,7 +46,7 @@ private static DiagnosticModel<MethodImplementationModel> FindServicesToRegister
4646
(true, true) => new[] { implementationType }.Concat(GetSuitableInterfaces(implementationType)),
4747
(false, true) => GetSuitableInterfaces(implementationType),
4848
(true, false) => [implementationType],
49-
_ => [matchedType ?? implementationType]
49+
_ => matchedTypes ?? [implementationType]
5050
};
5151

5252
foreach (var serviceType in serviceTypes)

0 commit comments

Comments
 (0)