Skip to content

Commit a79dd6b

Browse files
authored
Find CustomHandler or KeySelector methods if they are defined in base types (#42)
1 parent 9fadc54 commit a79dd6b

File tree

4 files changed

+100
-17
lines changed

4 files changed

+100
-17
lines changed

ServiceScan.SourceGenerator.Tests/CustomHandlerTests.cs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,57 @@ public partial void ProcessServices()
445445
Assert.Equal(expected, results.GeneratedTrees[1].ToString());
446446
}
447447

448+
[Fact]
449+
public void UseInstanceCustomHandlerMethod_FromParentType()
450+
{
451+
var source = $$"""
452+
using ServiceScan.SourceGenerator;
453+
454+
namespace GeneratorTests;
455+
456+
public abstract class AbstractServiceProcessor
457+
{
458+
protected void HandleType<T>() => System.Console.WriteLine(typeof(T).Name);
459+
}
460+
461+
public partial class ServicesProcessor : AbstractServiceProcessor
462+
{
463+
[GenerateServiceRegistrations(AssignableTo = typeof(IService), CustomHandler = nameof(HandleType))]
464+
public partial void ProcessServices();
465+
}
466+
""";
467+
468+
var services =
469+
"""
470+
namespace GeneratorTests;
471+
472+
public interface IService { }
473+
public class MyService1 : IService { }
474+
public class MyService2 : IService { }
475+
""";
476+
477+
var compilation = CreateCompilation(source, services);
478+
479+
var results = CSharpGeneratorDriver
480+
.Create(_generator)
481+
.RunGenerators(compilation)
482+
.GetRunResult();
483+
484+
var expected = $$"""
485+
namespace GeneratorTests;
486+
487+
public partial class ServicesProcessor
488+
{
489+
public partial void ProcessServices()
490+
{
491+
HandleType<global::GeneratorTests.MyService1>();
492+
HandleType<global::GeneratorTests.MyService2>();
493+
}
494+
}
495+
""";
496+
Assert.Equal(expected, results.GeneratedTrees[1].ToString());
497+
}
498+
448499
[Fact]
449500
public void UseStaticMethodFromMatchedClassAsCustomHandler_WithoutParameters()
450501
{

ServiceScan.SourceGenerator/DependencyInjectionGenerator.ParseMethodModel.cs

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System;
22
using System.Linq;
33
using Microsoft.CodeAnalysis;
4+
using ServiceScan.SourceGenerator.Extensions;
45
using ServiceScan.SourceGenerator.Model;
56
using static ServiceScan.SourceGenerator.DiagnosticDescriptors;
67

@@ -16,14 +17,12 @@ public partial class DependencyInjectionGenerator
1617
if (!method.IsPartialDefinition)
1718
return Diagnostic.Create(NotPartialDefinition, method.Locations[0]);
1819

19-
var attributeData = context.Attributes.Select(a => AttributeModel.Create(a, method)).ToArray();
20+
var position = context.TargetNode.SpanStart;
21+
var attributeData = context.Attributes.Select(a => AttributeModel.Create(a, method, context.SemanticModel)).ToArray();
2022
var hasCustomHandlers = attributeData.Any(a => a.CustomHandler != null);
2123

22-
for (var i = 0; i < context.Attributes.Length; i++)
24+
foreach (var attribute in attributeData)
2325
{
24-
var attribute = AttributeModel.Create(context.Attributes[i], method);
25-
attributeData[i] = attribute;
26-
2726
if (!attribute.HasSearchCriteria)
2827
return Diagnostic.Create(MissingSearchCriteria, attribute.Location);
2928

@@ -35,8 +34,7 @@ public partial class DependencyInjectionGenerator
3534

3635
if (attribute.KeySelector != null)
3736
{
38-
var keySelectorMethod = method.ContainingType.GetMembers().OfType<IMethodSymbol>()
39-
.FirstOrDefault(m => m.IsStatic && m.Name == attribute.KeySelector);
37+
var keySelectorMethod = method.ContainingType.GetMethod(attribute.KeySelector, context.SemanticModel, position, isStatic: true);
4038

4139
if (keySelectorMethod is not null)
4240
{
@@ -53,8 +51,7 @@ public partial class DependencyInjectionGenerator
5351

5452
if (attribute.CustomHandler != null)
5553
{
56-
var customHandlerMethod = method.ContainingType.GetMembers().OfType<IMethodSymbol>()
57-
.FirstOrDefault(m => m.Name == attribute.CustomHandler);
54+
var customHandlerMethod = method.ContainingType.GetMethod(attribute.CustomHandler, context.SemanticModel, position);
5855

5956
if (customHandlerMethod != null)
6057
{
@@ -80,7 +77,7 @@ public partial class DependencyInjectionGenerator
8077
}
8178
}
8279

83-
if (attributeData[i].HasErrors)
80+
if (attribute.HasErrors)
8481
return null;
8582
}
8683

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
using System.Linq;
2+
using Microsoft.CodeAnalysis;
3+
4+
namespace ServiceScan.SourceGenerator.Extensions;
5+
6+
internal static class TypeSymbolExtensions
7+
{
8+
/// <summary>
9+
/// Retrieves a method symbol from the specified type by name, considering accessibility, static context, and
10+
/// inheritance.
11+
/// </summary>
12+
/// <remarks>This method searches the specified type and its base types for a method with the given name
13+
/// that matches the specified accessibility and static context. If no matching method is found, the method returns
14+
/// <see langword="null"/>.</remarks>
15+
public static IMethodSymbol? GetMethod(this ITypeSymbol type, string methodName, SemanticModel semanticModel, int position, bool? isStatic = null)
16+
{
17+
var currentType = type;
18+
19+
while (currentType != null)
20+
{
21+
var method = currentType.GetMembers()
22+
.OfType<IMethodSymbol>()
23+
.Where(m => m.Name == methodName
24+
&& (isStatic == null || m.IsStatic == isStatic)
25+
&& semanticModel.IsAccessible(position, m))
26+
.FirstOrDefault();
27+
28+
if (method != null)
29+
return method;
30+
31+
currentType = currentType.BaseType;
32+
}
33+
34+
return null;
35+
}
36+
}

ServiceScan.SourceGenerator/Model/AttributeModel.cs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System.Linq;
22
using Microsoft.CodeAnalysis;
3+
using ServiceScan.SourceGenerator.Extensions;
34

45
namespace ServiceScan.SourceGenerator.Model;
56

@@ -31,8 +32,10 @@ record AttributeModel(
3132
{
3233
public bool HasSearchCriteria => TypeNameFilter != null || AssignableToTypeName != null || AttributeFilterTypeName != null;
3334

34-
public static AttributeModel Create(AttributeData attribute, IMethodSymbol method)
35+
public static AttributeModel Create(AttributeData attribute, IMethodSymbol method, SemanticModel semanticModel)
3536
{
37+
var position = attribute.ApplicationSyntaxReference?.Span.Start ?? 0;
38+
3639
var assemblyType = attribute.NamedArguments.FirstOrDefault(a => a.Key == "FromAssemblyOf").Value.Value as INamedTypeSymbol;
3740
var assemblyNameFilter = attribute.NamedArguments.FirstOrDefault(a => a.Key == "AssemblyNameFilter").Value.Value as string;
3841
var assignableTo = attribute.NamedArguments.FirstOrDefault(a => a.Key == "AssignableTo").Value.Value as INamedTypeSymbol;
@@ -51,9 +54,7 @@ public static AttributeModel Create(AttributeData attribute, IMethodSymbol metho
5154
KeySelectorType? keySelectorType = null;
5255
if (keySelector != null)
5356
{
54-
var keySelectorMethod = method.ContainingType.GetMembers()
55-
.OfType<IMethodSymbol>()
56-
.FirstOrDefault(m => m.IsStatic && m.Name == keySelector);
57+
var keySelectorMethod = method.ContainingType.GetMethod(keySelector, semanticModel, position, isStatic: true);
5758

5859
if (keySelectorMethod != null)
5960
{
@@ -69,9 +70,7 @@ public static AttributeModel Create(AttributeData attribute, IMethodSymbol metho
6970
var customHandlerGenericParameters = 0;
7071
if (customHandler != null)
7172
{
72-
var customHandlerMethod = method.ContainingType.GetMembers()
73-
.OfType<IMethodSymbol>()
74-
.FirstOrDefault(m => m.Name == customHandler);
73+
var customHandlerMethod = method.ContainingType.GetMethod(customHandler, semanticModel, position);
7574

7675
customHandlerType = customHandlerMethod != null ? Model.CustomHandlerType.Method : Model.CustomHandlerType.TypeMethod;
7776
customHandlerGenericParameters = customHandlerMethod?.TypeParameters.Length ?? 0;

0 commit comments

Comments
 (0)