Skip to content

Commit 421e178

Browse files
committed
Fix edge cases
1 parent 106b814 commit 421e178

File tree

5 files changed

+113
-25
lines changed

5 files changed

+113
-25
lines changed

ServiceScan.SourceGenerator.Tests/CustomHandlerTests.cs

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,51 @@ public static partial void ProcessServices( string value, decimal number)
105105
Assert.Equal(expected, results.GeneratedTrees[1].ToString());
106106
}
107107

108+
[Fact]
109+
public void CustomHandler_NoTypesFound()
110+
{
111+
var source = $$"""
112+
using ServiceScan.SourceGenerator;
113+
114+
namespace GeneratorTests;
115+
116+
public static partial class ServicesExtensions
117+
{
118+
[GenerateServiceRegistrations(AssignableTo = typeof(IService), CustomHandler = nameof(HandleType))]
119+
public static partial void ProcessServices();
120+
121+
private static void HandleType<T>() => System.Console.WriteLine(typeof(T).Name);
122+
}
123+
""";
124+
125+
var services =
126+
"""
127+
namespace GeneratorTests;
128+
129+
public interface IService { }
130+
""";
131+
132+
var compilation = CreateCompilation(source, services);
133+
134+
var results = CSharpGeneratorDriver
135+
.Create(_generator)
136+
.RunGenerators(compilation)
137+
.GetRunResult();
138+
139+
var expected = $$"""
140+
namespace GeneratorTests;
141+
142+
public static partial class ServicesExtensions
143+
{
144+
public static partial void ProcessServices()
145+
{
146+
147+
}
148+
}
149+
""";
150+
Assert.Equal(expected, results.GeneratedTrees[1].ToString());
151+
}
152+
108153
[Fact]
109154
public void CustomHandlerExtensionMethod()
110155
{
@@ -986,6 +1031,57 @@ public static partial void AddHandlers()
9861031
Assert.Equal(expected, results.GeneratedTrees[1].ToString());
9871032
}
9881033

1034+
[Fact]
1035+
public void CustomHandler_HandlesRecursiveConstraints()
1036+
{
1037+
var source = """
1038+
using ServiceScan.SourceGenerator;
1039+
1040+
namespace GeneratorTests;
1041+
1042+
public static partial class ServicesExtensions
1043+
{
1044+
[GenerateServiceRegistrations(TypeNameFilter = "*Smth*", CustomHandler = nameof(HandleType))]
1045+
public static partial void ProcessServices();
1046+
1047+
private static void HandleType<X, Y>()
1048+
where X : ISmth<Y>
1049+
where Y : ISmth<X>
1050+
=> System.Console.WriteLine(typeof(X).Name);
1051+
}
1052+
""";
1053+
1054+
var services = """
1055+
namespace GeneratorTests;
1056+
1057+
interface ISmth<T>;
1058+
class SmthX: ISmth<SmthY>;
1059+
class SmthY: ISmth<SmthX>;
1060+
class SmthString: ISmth<string>;
1061+
""";
1062+
1063+
var compilation = CreateCompilation(source, services);
1064+
1065+
var results = CSharpGeneratorDriver
1066+
.Create(_generator)
1067+
.RunGenerators(compilation)
1068+
.GetRunResult();
1069+
1070+
var expected = """
1071+
namespace GeneratorTests;
1072+
1073+
public static partial class ServicesExtensions
1074+
{
1075+
public static partial void ProcessServices()
1076+
{
1077+
HandleType<global::GeneratorTests.SmthX>();
1078+
HandleType<global::GeneratorTests.SmthY>();
1079+
}
1080+
}
1081+
""";
1082+
Assert.Equal(expected, results.GeneratedTrees[1].ToString());
1083+
}
1084+
9891085
private static Compilation CreateCompilation(params string[] source)
9901086
{
9911087
var path = Path.GetDirectoryName(typeof(object).Assembly.Location)!;

ServiceScan.SourceGenerator.Tests/DiagnosticTests.cs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -161,13 +161,11 @@ public static partial class ServicesExtensions
161161
Assert.Equal(results.Diagnostics.Single().Descriptor, DiagnosticDescriptors.NoMatchingTypesFound);
162162

163163
var expectedFile = """
164-
using Microsoft.Extensions.DependencyInjection;
165-
166164
namespace GeneratorTests;
167165
168166
public static partial class ServicesExtensions
169167
{
170-
public static partial IServiceCollection AddServices(this IServiceCollection services)
168+
public static partial global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddServices(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services)
171169
{
172170
return services;
173171
}
@@ -203,13 +201,11 @@ public static partial class ServicesExtensions
203201
Assert.Equal(results.Diagnostics.Single().Descriptor, DiagnosticDescriptors.NoMatchingTypesFound);
204202

205203
var expectedFile = """
206-
using Microsoft.Extensions.DependencyInjection;
207-
208204
namespace GeneratorTests;
209205
210206
public static partial class ServicesExtensions
211207
{
212-
public static partial void AddServices(this IServiceCollection services)
208+
public static partial void AddServices(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services)
213209
{
214210
215211
}

ServiceScan.SourceGenerator/DependencyInjectionGenerator.FilterTypes.cs

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -225,11 +225,16 @@ private static bool SatisfiesGenericConstraints(INamedTypeSymbol type, IMethodSy
225225
// (Other type parameters could be checked recursively from the first type parameter)
226226
var typeParameter = customHandlerMethod.TypeParameters[0];
227227

228-
return SatisfiesGenericConstraints(type, typeParameter, customHandlerMethod);
228+
var visitedTypeParameters = new HashSet<ITypeParameterSymbol>(SymbolEqualityComparer.Default);
229+
return SatisfiesGenericConstraints(type, typeParameter, customHandlerMethod, visitedTypeParameters);
229230
}
230231

231-
private static bool SatisfiesGenericConstraints(INamedTypeSymbol type, ITypeParameterSymbol typeParameter, IMethodSymbol customHandlerMethod)
232+
private static bool SatisfiesGenericConstraints(INamedTypeSymbol type, ITypeParameterSymbol typeParameter, IMethodSymbol customHandlerMethod, HashSet<ITypeParameterSymbol> visitedTypeParameters)
232233
{
234+
// Prevent infinite recursion in circular constraint scenarios (e.g., X : ISmth<Y>, Y : ISmth<X>)
235+
if (!visitedTypeParameters.Add(typeParameter))
236+
return true;
237+
233238
// Check reference type constraint
234239
if (typeParameter.HasReferenceTypeConstraint && type.IsValueType)
235240
return false;
@@ -259,15 +264,15 @@ private static bool SatisfiesGenericConstraints(INamedTypeSymbol type, ITypePara
259264
{
260265
if (constraintType is INamedTypeSymbol namedConstraintType)
261266
{
262-
if (!SatisfiesConstraintType(type, namedConstraintType, customHandlerMethod))
267+
if (!SatisfiesConstraintType(type, namedConstraintType, customHandlerMethod, visitedTypeParameters))
263268
return false;
264269
}
265270
}
266271

267272
return true;
268273
}
269274

270-
private static bool SatisfiesConstraintType(INamedTypeSymbol candidateType, INamedTypeSymbol constraintType, IMethodSymbol customHandlerMethod)
275+
private static bool SatisfiesConstraintType(INamedTypeSymbol candidateType, INamedTypeSymbol constraintType, IMethodSymbol customHandlerMethod, HashSet<ITypeParameterSymbol> visitedTypeParameters)
271276
{
272277
var constraintHasTypeParameters = constraintType.TypeArguments.OfType<ITypeParameterSymbol>().Any();
273278

@@ -289,10 +294,10 @@ private static bool SatisfiesConstraintType(INamedTypeSymbol candidateType, INam
289294

290295
// Then we need to check if any matched interfaces (let's say MyHandlerImplementation implements ICommandHandler<string> and ICommandHandler<MySpecificCommand>)
291296
// have matching type parameters (e.g. string does not implement ISpecificCommand, but MySpecificCommand - does).
292-
return matchedTypes.Any(matchedType => MatchedTypeSatisfiesConstraints(constraintType, customHandlerMethod, matchedType));
297+
return matchedTypes.Any(matchedType => MatchedTypeSatisfiesConstraints(constraintType, customHandlerMethod, matchedType, visitedTypeParameters));
293298
}
294299

295-
static bool MatchedTypeSatisfiesConstraints(INamedTypeSymbol constraintType, IMethodSymbol customHandlerMethod, INamedTypeSymbol matchedType)
300+
static bool MatchedTypeSatisfiesConstraints(INamedTypeSymbol constraintType, IMethodSymbol customHandlerMethod, INamedTypeSymbol matchedType, HashSet<ITypeParameterSymbol> visitedTypeParameters)
296301
{
297302
if (constraintType.TypeArguments.Length != matchedType.TypeArguments.Length)
298303
return false;
@@ -304,7 +309,7 @@ static bool MatchedTypeSatisfiesConstraints(INamedTypeSymbol constraintType, IMe
304309

305310
if (constraintType.TypeArguments[i] is ITypeParameterSymbol typeParameter)
306311
{
307-
if (!SatisfiesGenericConstraints(candidateTypeArgument, typeParameter, customHandlerMethod))
312+
if (!SatisfiesGenericConstraints(candidateTypeArgument, typeParameter, customHandlerMethod, visitedTypeParameters))
308313
return false;
309314
}
310315
else

ServiceScan.SourceGenerator/DependencyInjectionGenerator.ParseMethodModel.cs

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -65,15 +65,6 @@ public partial class DependencyInjectionGenerator
6565

6666
if (!typesMatch)
6767
return Diagnostic.Create(CustomHandlerMethodHasIncorrectSignature, attribute.Location);
68-
69-
// If CustomHandler has more than 1 type parameters, we try to resolve them from
70-
// matched assignableTo type arguments.
71-
// e.g. ApplyConfiguration<T, TEntity>(ModelBuilder modelBuilder) where T : IEntityTypeConfiguration<TEntity>
72-
if (customHandlerMethod.TypeParameters.Length > 1
73-
&& customHandlerMethod.TypeParameters.Length != attribute.AssignableToTypeParametersCount + 1)
74-
{
75-
return Diagnostic.Create(CustomHandlerMethodHasIncorrectSignature, attribute.Location);
76-
}
7768
}
7869
}
7970

ServiceScan.SourceGenerator/DependencyInjectionGenerator.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,9 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
4040
return;
4141

4242
var (method, registrations, customHandling) = src.Model;
43-
string source = customHandling.Count > 0
44-
? GenerateCustomHandlingSource(method, customHandling)
45-
: GenerateRegistrationsSource(method, registrations);
43+
string source = registrations.Count > 0
44+
? GenerateRegistrationsSource(method, registrations)
45+
: GenerateCustomHandlingSource(method, customHandling);
4646

4747
source = source.ReplaceLineEndings();
4848

0 commit comments

Comments
 (0)