Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions ServiceScan.SourceGenerator.Tests/CustomHandlerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,57 @@ public partial void ProcessServices()
Assert.Equal(expected, results.GeneratedTrees[1].ToString());
}

[Fact]
public void UseInstanceCustomHandlerMethod_FromParentType()
{
var source = $$"""
using ServiceScan.SourceGenerator;

namespace GeneratorTests;

public abstract class AbstractServiceProcessor
{
protected void HandleType<T>() => System.Console.WriteLine(typeof(T).Name);
}

public partial class ServicesProcessor : AbstractServiceProcessor
{
[GenerateServiceRegistrations(AssignableTo = typeof(IService), CustomHandler = nameof(HandleType))]
public partial void ProcessServices();
}
""";

var services =
"""
namespace GeneratorTests;

public interface IService { }
public class MyService1 : IService { }
public class MyService2 : IService { }
""";

var compilation = CreateCompilation(source, services);

var results = CSharpGeneratorDriver
.Create(_generator)
.RunGenerators(compilation)
.GetRunResult();

var expected = $$"""
namespace GeneratorTests;

public partial class ServicesProcessor
{
public partial void ProcessServices()
{
HandleType<global::GeneratorTests.MyService1>();
HandleType<global::GeneratorTests.MyService2>();
}
}
""";
Assert.Equal(expected, results.GeneratedTrees[1].ToString());
}

[Fact]
public void UseStaticMethodFromMatchedClassAsCustomHandler_WithoutParameters()
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Linq;
using Microsoft.CodeAnalysis;
using ServiceScan.SourceGenerator.Extensions;
using ServiceScan.SourceGenerator.Model;
using static ServiceScan.SourceGenerator.DiagnosticDescriptors;

Expand All @@ -16,14 +17,12 @@ public partial class DependencyInjectionGenerator
if (!method.IsPartialDefinition)
return Diagnostic.Create(NotPartialDefinition, method.Locations[0]);

var attributeData = context.Attributes.Select(a => AttributeModel.Create(a, method)).ToArray();
var position = context.TargetNode.SpanStart;
var attributeData = context.Attributes.Select(a => AttributeModel.Create(a, method, context.SemanticModel)).ToArray();
var hasCustomHandlers = attributeData.Any(a => a.CustomHandler != null);

for (var i = 0; i < context.Attributes.Length; i++)
foreach (var attribute in attributeData)
{
var attribute = AttributeModel.Create(context.Attributes[i], method);
attributeData[i] = attribute;

if (!attribute.HasSearchCriteria)
return Diagnostic.Create(MissingSearchCriteria, attribute.Location);

Expand All @@ -35,8 +34,7 @@ public partial class DependencyInjectionGenerator

if (attribute.KeySelector != null)
{
var keySelectorMethod = method.ContainingType.GetMembers().OfType<IMethodSymbol>()
.FirstOrDefault(m => m.IsStatic && m.Name == attribute.KeySelector);
var keySelectorMethod = method.ContainingType.GetMethod(attribute.KeySelector, context.SemanticModel, position, isStatic: true);

if (keySelectorMethod is not null)
{
Expand All @@ -53,8 +51,7 @@ public partial class DependencyInjectionGenerator

if (attribute.CustomHandler != null)
{
var customHandlerMethod = method.ContainingType.GetMembers().OfType<IMethodSymbol>()
.FirstOrDefault(m => m.Name == attribute.CustomHandler);
var customHandlerMethod = method.ContainingType.GetMethod(attribute.CustomHandler, context.SemanticModel, position);

if (customHandlerMethod != null)
{
Expand All @@ -80,7 +77,7 @@ public partial class DependencyInjectionGenerator
}
}

if (attributeData[i].HasErrors)
if (attribute.HasErrors)
return null;
}

Expand Down
36 changes: 36 additions & 0 deletions ServiceScan.SourceGenerator/Extensions/TypeSymbolExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
using System.Linq;
using Microsoft.CodeAnalysis;

namespace ServiceScan.SourceGenerator.Extensions;

internal static class TypeSymbolExtensions
{
/// <summary>
/// Retrieves a method symbol from the specified type by name, considering accessibility, static context, and
/// inheritance.
/// </summary>
/// <remarks>This method searches the specified type and its base types for a method with the given name
/// that matches the specified accessibility and static context. If no matching method is found, the method returns
/// <see langword="null"/>.</remarks>
public static IMethodSymbol? GetMethod(this ITypeSymbol type, string methodName, SemanticModel semanticModel, int position, bool? isStatic = null)
{
var currentType = type;

while (currentType != null)
{
var method = currentType.GetMembers()
.OfType<IMethodSymbol>()
.Where(m => m.Name == methodName
&& (isStatic == null || m.IsStatic == isStatic)
&& semanticModel.IsAccessible(position, m))
.FirstOrDefault();

if (method != null)
return method;

currentType = currentType.BaseType;
}

return null;
}
}
13 changes: 6 additions & 7 deletions ServiceScan.SourceGenerator/Model/AttributeModel.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System.Linq;
using Microsoft.CodeAnalysis;
using ServiceScan.SourceGenerator.Extensions;

namespace ServiceScan.SourceGenerator.Model;

Expand Down Expand Up @@ -31,8 +32,10 @@ record AttributeModel(
{
public bool HasSearchCriteria => TypeNameFilter != null || AssignableToTypeName != null || AttributeFilterTypeName != null;

public static AttributeModel Create(AttributeData attribute, IMethodSymbol method)
public static AttributeModel Create(AttributeData attribute, IMethodSymbol method, SemanticModel semanticModel)
{
var position = attribute.ApplicationSyntaxReference?.Span.Start ?? 0;

var assemblyType = attribute.NamedArguments.FirstOrDefault(a => a.Key == "FromAssemblyOf").Value.Value as INamedTypeSymbol;
var assemblyNameFilter = attribute.NamedArguments.FirstOrDefault(a => a.Key == "AssemblyNameFilter").Value.Value as string;
var assignableTo = attribute.NamedArguments.FirstOrDefault(a => a.Key == "AssignableTo").Value.Value as INamedTypeSymbol;
Expand All @@ -51,9 +54,7 @@ public static AttributeModel Create(AttributeData attribute, IMethodSymbol metho
KeySelectorType? keySelectorType = null;
if (keySelector != null)
{
var keySelectorMethod = method.ContainingType.GetMembers()
.OfType<IMethodSymbol>()
.FirstOrDefault(m => m.IsStatic && m.Name == keySelector);
var keySelectorMethod = method.ContainingType.GetMethod(keySelector, semanticModel, position, isStatic: true);

if (keySelectorMethod != null)
{
Expand All @@ -69,9 +70,7 @@ public static AttributeModel Create(AttributeData attribute, IMethodSymbol metho
var customHandlerGenericParameters = 0;
if (customHandler != null)
{
var customHandlerMethod = method.ContainingType.GetMembers()
.OfType<IMethodSymbol>()
.FirstOrDefault(m => m.Name == customHandler);
var customHandlerMethod = method.ContainingType.GetMethod(customHandler, semanticModel, position);

customHandlerType = customHandlerMethod != null ? Model.CustomHandlerType.Method : Model.CustomHandlerType.TypeMethod;
customHandlerGenericParameters = customHandlerMethod?.TypeParameters.Length ?? 0;
Expand Down
Loading