diff --git a/ServiceScan.SourceGenerator.Tests/CustomHandlerTests.cs b/ServiceScan.SourceGenerator.Tests/CustomHandlerTests.cs index 81f70aa..e3aae21 100644 --- a/ServiceScan.SourceGenerator.Tests/CustomHandlerTests.cs +++ b/ServiceScan.SourceGenerator.Tests/CustomHandlerTests.cs @@ -445,6 +445,57 @@ public partial void ProcessServices() Assert.Equal(expected, results.GeneratedTrees[1].ToString()); } + [Fact] + public void UseInstanceCustomHandlerMethod_FromParentType() + { + var source = $$""" + using ServiceScan.SourceGenerator; + + namespace GeneratorTests; + + public abstract class AbstractServiceProcessor + { + protected void HandleType() => System.Console.WriteLine(typeof(T).Name); + } + + public partial class ServicesProcessor : AbstractServiceProcessor + { + [GenerateServiceRegistrations(AssignableTo = typeof(IService), CustomHandler = nameof(HandleType))] + public partial void ProcessServices(); + } + """; + + var services = + """ + namespace GeneratorTests; + + public interface IService { } + public class MyService1 : IService { } + public class MyService2 : IService { } + """; + + var compilation = CreateCompilation(source, services); + + var results = CSharpGeneratorDriver + .Create(_generator) + .RunGenerators(compilation) + .GetRunResult(); + + var expected = $$""" + namespace GeneratorTests; + + public partial class ServicesProcessor + { + public partial void ProcessServices() + { + HandleType(); + HandleType(); + } + } + """; + Assert.Equal(expected, results.GeneratedTrees[1].ToString()); + } + [Fact] public void UseStaticMethodFromMatchedClassAsCustomHandler_WithoutParameters() { diff --git a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.ParseMethodModel.cs b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.ParseMethodModel.cs index 18b2628..b0b991a 100644 --- a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.ParseMethodModel.cs +++ b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.ParseMethodModel.cs @@ -1,6 +1,7 @@ using System; using System.Linq; using Microsoft.CodeAnalysis; +using ServiceScan.SourceGenerator.Extensions; using ServiceScan.SourceGenerator.Model; using static ServiceScan.SourceGenerator.DiagnosticDescriptors; @@ -16,14 +17,12 @@ public partial class DependencyInjectionGenerator if (!method.IsPartialDefinition) return Diagnostic.Create(NotPartialDefinition, method.Locations[0]); - var attributeData = context.Attributes.Select(a => AttributeModel.Create(a, method)).ToArray(); + var position = context.TargetNode.SpanStart; + var attributeData = context.Attributes.Select(a => AttributeModel.Create(a, method, context.SemanticModel)).ToArray(); var hasCustomHandlers = attributeData.Any(a => a.CustomHandler != null); - for (var i = 0; i < context.Attributes.Length; i++) + foreach (var attribute in attributeData) { - var attribute = AttributeModel.Create(context.Attributes[i], method); - attributeData[i] = attribute; - if (!attribute.HasSearchCriteria) return Diagnostic.Create(MissingSearchCriteria, attribute.Location); @@ -35,8 +34,7 @@ public partial class DependencyInjectionGenerator if (attribute.KeySelector != null) { - var keySelectorMethod = method.ContainingType.GetMembers().OfType() - .FirstOrDefault(m => m.IsStatic && m.Name == attribute.KeySelector); + var keySelectorMethod = method.ContainingType.GetMethod(attribute.KeySelector, context.SemanticModel, position, isStatic: true); if (keySelectorMethod is not null) { @@ -53,8 +51,7 @@ public partial class DependencyInjectionGenerator if (attribute.CustomHandler != null) { - var customHandlerMethod = method.ContainingType.GetMembers().OfType() - .FirstOrDefault(m => m.Name == attribute.CustomHandler); + var customHandlerMethod = method.ContainingType.GetMethod(attribute.CustomHandler, context.SemanticModel, position); if (customHandlerMethod != null) { @@ -80,7 +77,7 @@ public partial class DependencyInjectionGenerator } } - if (attributeData[i].HasErrors) + if (attribute.HasErrors) return null; } diff --git a/ServiceScan.SourceGenerator/Extensions/TypeSymbolExtensions.cs b/ServiceScan.SourceGenerator/Extensions/TypeSymbolExtensions.cs new file mode 100644 index 0000000..b1abcae --- /dev/null +++ b/ServiceScan.SourceGenerator/Extensions/TypeSymbolExtensions.cs @@ -0,0 +1,36 @@ +using System.Linq; +using Microsoft.CodeAnalysis; + +namespace ServiceScan.SourceGenerator.Extensions; + +internal static class TypeSymbolExtensions +{ + /// + /// Retrieves a method symbol from the specified type by name, considering accessibility, static context, and + /// inheritance. + /// + /// This method searches the specified type and its base types for a method with the given name + /// that matches the specified accessibility and static context. If no matching method is found, the method returns + /// . + public static IMethodSymbol? GetMethod(this ITypeSymbol type, string methodName, SemanticModel semanticModel, int position, bool? isStatic = null) + { + var currentType = type; + + while (currentType != null) + { + var method = currentType.GetMembers() + .OfType() + .Where(m => m.Name == methodName + && (isStatic == null || m.IsStatic == isStatic) + && semanticModel.IsAccessible(position, m)) + .FirstOrDefault(); + + if (method != null) + return method; + + currentType = currentType.BaseType; + } + + return null; + } +} diff --git a/ServiceScan.SourceGenerator/Model/AttributeModel.cs b/ServiceScan.SourceGenerator/Model/AttributeModel.cs index 5816ed8..60d8b83 100644 --- a/ServiceScan.SourceGenerator/Model/AttributeModel.cs +++ b/ServiceScan.SourceGenerator/Model/AttributeModel.cs @@ -1,5 +1,6 @@ using System.Linq; using Microsoft.CodeAnalysis; +using ServiceScan.SourceGenerator.Extensions; namespace ServiceScan.SourceGenerator.Model; @@ -31,8 +32,10 @@ record AttributeModel( { public bool HasSearchCriteria => TypeNameFilter != null || AssignableToTypeName != null || AttributeFilterTypeName != null; - public static AttributeModel Create(AttributeData attribute, IMethodSymbol method) + public static AttributeModel Create(AttributeData attribute, IMethodSymbol method, SemanticModel semanticModel) { + var position = attribute.ApplicationSyntaxReference?.Span.Start ?? 0; + var assemblyType = attribute.NamedArguments.FirstOrDefault(a => a.Key == "FromAssemblyOf").Value.Value as INamedTypeSymbol; var assemblyNameFilter = attribute.NamedArguments.FirstOrDefault(a => a.Key == "AssemblyNameFilter").Value.Value as string; var assignableTo = attribute.NamedArguments.FirstOrDefault(a => a.Key == "AssignableTo").Value.Value as INamedTypeSymbol; @@ -51,9 +54,7 @@ public static AttributeModel Create(AttributeData attribute, IMethodSymbol metho KeySelectorType? keySelectorType = null; if (keySelector != null) { - var keySelectorMethod = method.ContainingType.GetMembers() - .OfType() - .FirstOrDefault(m => m.IsStatic && m.Name == keySelector); + var keySelectorMethod = method.ContainingType.GetMethod(keySelector, semanticModel, position, isStatic: true); if (keySelectorMethod != null) { @@ -69,9 +70,7 @@ public static AttributeModel Create(AttributeData attribute, IMethodSymbol metho var customHandlerGenericParameters = 0; if (customHandler != null) { - var customHandlerMethod = method.ContainingType.GetMembers() - .OfType() - .FirstOrDefault(m => m.Name == customHandler); + var customHandlerMethod = method.ContainingType.GetMethod(customHandler, semanticModel, position); customHandlerType = customHandlerMethod != null ? Model.CustomHandlerType.Method : Model.CustomHandlerType.TypeMethod; customHandlerGenericParameters = customHandlerMethod?.TypeParameters.Length ?? 0;