diff --git a/Generator.cs b/Generator.cs index 52b0ada..61a89e2 100644 --- a/Generator.cs +++ b/Generator.cs @@ -34,9 +34,8 @@ public void Initialize(IncrementalGeneratorInitializationContext context) context.RegisterPostInitializationOutput(static context => context.AddSource(Constants.DeclarationsSourceFileName, PostInitialization.GetSourceText())); var interfaceOrMethodSymbols = InterfaceOrMethodSymbolCollection.GetSymbols(context.CompilationProvider); - var interfaceOrMethodReferences = - InterfaceOrMethodReferenceCollection.GetReferences(context, interfaceOrMethodSymbols); - var methodReferences = MethodReferenceCollection.GetReferences(interfaceOrMethodReferences); + var interfaceReferences = InterfaceReferenceCollection.GetReferences(context, interfaceOrMethodSymbols); + var methodReferences = MethodReferenceCollection.GetReferences(interfaceReferences); var implementationClasses = ImplementationClassCollection.GetImplementationClasses(methodReferences); context.RegisterImplementationSourceOutput ( diff --git a/InterfaceOrMethodReferenceCollection.cs b/InterfaceOrMethodReferenceCollection.cs deleted file mode 100644 index 286cf29..0000000 --- a/InterfaceOrMethodReferenceCollection.cs +++ /dev/null @@ -1,157 +0,0 @@ -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp.Syntax; -using Monkeymoto.GeneratorUtils; -using System; -using System.Collections; -using System.Collections.Generic; -using System.Collections.Immutable; -using System.Diagnostics; -using System.Linq; -using System.Threading; - -namespace Monkeymoto.NativeGenericDelegates -{ - internal readonly struct InterfaceOrMethodReferenceCollection : - IEquatable, - IEnumerable - { - private readonly int hashCode; - private readonly ImmutableHashSet references; - - public static bool operator == - ( - InterfaceOrMethodReferenceCollection left, - InterfaceOrMethodReferenceCollection right - ) => left.Equals(right); - - public static bool operator != - ( - InterfaceOrMethodReferenceCollection left, - InterfaceOrMethodReferenceCollection right - ) => !(left == right); - - private static void AddMethodReferences - ( - ImmutableHashSet.Builder references, - GenericSymbolReferenceTree tree, - IMethodSymbol methodSymbol, - CancellationToken cancellationToken - ) - { - var methodReferences = tree.GetBranchesBySymbol(methodSymbol, cancellationToken); - if (!methodReferences.Any()) - { - return; - } - var roots = tree.GetBranchesBySymbol(methodSymbol.ContainingType, cancellationToken); - if (!roots.Any()) - { - return; - } - var methodSymbolsToConstruct = new HashSet(SymbolEqualityComparer.Default); - foreach (var root in roots.Select(static x => (INamedTypeSymbol)x.Symbol)) - { - cancellationToken.ThrowIfCancellationRequested(); - var methodSymbolToConstruct = root.GetMembers(methodSymbol.Name) - .Where - ( - x => SymbolEqualityComparer.Default.Equals - ( - x.OriginalDefinition, - methodSymbol.OriginalDefinition - ) - ) - .Cast() - .Single(); - _ = methodSymbolsToConstruct.Add(methodSymbolToConstruct); - } - var constructedMethodSymbols = new HashSet(SymbolEqualityComparer.Default); - constructedMethodSymbols.UnionWith(methodReferences.SelectMany - ( - x => methodSymbolsToConstruct.Select - ( - y => y.Construct([.. ((IMethodSymbol)x.Symbol).TypeArguments]) - ) - )); - foreach (var methodReference in methodReferences) - { - cancellationToken.ThrowIfCancellationRequested(); - switch (methodReference.Node) - { - case IdentifierNameSyntax identifierName: - references.UnionWith - ( - constructedMethodSymbols.Select - ( - x => new GenericSymbolReference - ( - x, - methodReference.SemanticModel, - identifierName - ) - ) - ); - break; - case InvocationExpressionSyntax invocationExpression: - references.UnionWith - ( - constructedMethodSymbols.Select - ( - x => new GenericSymbolReference - ( - x, - methodReference.SemanticModel, - invocationExpression - ) - ) - ); - break; - default: - throw new UnreachableException(); - } - } - } - - public static IncrementalValueProvider GetReferences - ( - IncrementalGeneratorInitializationContext context, - IncrementalValueProvider symbolsProvider - ) - { - var treeProvider = GenericSymbolReferenceTree.FromIncrementalGeneratorInitializationContext(context); - return symbolsProvider.Combine(treeProvider).Select(static (x, cancellationToken) => - { - var symbols = x.Left; - using var tree = x.Right; // Dispose tree after we extract the symbol references we need - var references = ImmutableHashSet.CreateBuilder(); - foreach (var symbol in symbols) - { - switch (symbol) - { - case INamedTypeSymbol: - references.UnionWith(tree.GetBranchesBySymbol(symbol, cancellationToken)); - break; - case IMethodSymbol methodSymbol: - AddMethodReferences(references, tree, methodSymbol, cancellationToken); - break; - default: - throw new UnreachableException(); - } - } - return new InterfaceOrMethodReferenceCollection(references.ToImmutable()); - }); - } - - private InterfaceOrMethodReferenceCollection(ImmutableHashSet references) - { - this.references = references; - hashCode = Hash.Combine(references); - } - - public override bool Equals(object? obj) => obj is InterfaceOrMethodReferenceCollection other && Equals(other); - public bool Equals(InterfaceOrMethodReferenceCollection other) => references.SetEquals(other.references); - public IEnumerator GetEnumerator() => references.GetEnumerator(); - IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); - public override int GetHashCode() => hashCode; - } -} diff --git a/InterfaceReferenceCollection.cs b/InterfaceReferenceCollection.cs new file mode 100644 index 0000000..555bb9e --- /dev/null +++ b/InterfaceReferenceCollection.cs @@ -0,0 +1,94 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Monkeymoto.GeneratorUtils; +using System; +using System.Collections; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Diagnostics; +using System.Linq; + +namespace Monkeymoto.NativeGenericDelegates +{ + internal readonly struct InterfaceReferenceCollection : + IEquatable, + IEnumerable + { + private readonly int hashCode; + private readonly ImmutableHashSet interfaceReferences; + private readonly ImmutableHashSet methodReferences; + + public static bool operator ==(InterfaceReferenceCollection left, InterfaceReferenceCollection right) => + left.Equals(right); + public static bool operator !=(InterfaceReferenceCollection left, InterfaceReferenceCollection right) => + !(left == right); + + public static IncrementalValueProvider GetReferences + ( + IncrementalGeneratorInitializationContext context, + IncrementalValueProvider symbolsProvider + ) + { + var treeProvider = GenericSymbolReferenceTree.FromIncrementalGeneratorInitializationContext(context); + return symbolsProvider.Combine(treeProvider).Select(static (x, cancellationToken) => + { + var symbols = x.Left; + using var tree = x.Right; // Dispose tree after we extract the symbol references we need + var interfaceReferences = ImmutableHashSet.CreateBuilder(); + var methodReferences = ImmutableHashSet.CreateBuilder(); + foreach (var symbol in symbols) + { + switch (symbol) + { + case INamedTypeSymbol: + interfaceReferences.UnionWith(tree.GetBranchesBySymbol(symbol, cancellationToken)); + break; + case IMethodSymbol methodSymbol: + methodReferences.UnionWith(tree.GetBranchesBySymbol(symbol, cancellationToken)); + break; + default: + throw new UnreachableException(); + } + } + return new InterfaceReferenceCollection + ( + interfaceReferences.ToImmutable(), + methodReferences.ToImmutable() + ); + }); + } + + private InterfaceReferenceCollection + ( + ImmutableHashSet interfaceReferences, + ImmutableHashSet methodReferences + ) + { + this.interfaceReferences = interfaceReferences; + this.methodReferences = methodReferences; + hashCode = Hash.Combine(interfaceReferences, methodReferences); + } + + public override bool Equals(object? obj) => obj is InterfaceReferenceCollection other && Equals(other); + public bool Equals(InterfaceReferenceCollection other) => + interfaceReferences.SetEquals(other.interfaceReferences) && + methodReferences.SetEquals(other.methodReferences); + public IEnumerator GetEnumerator() => interfaceReferences.GetEnumerator(); + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + public override int GetHashCode() => hashCode; + + public IReadOnlyCollection GetGenericMethodReferences + ( + IMethodSymbol methodSymbol, + InvocationExpressionSyntax invocationExpression + ) + { + methodSymbol = methodSymbol.OriginalDefinition; + return methodReferences.Where + ( + x => SymbolEqualityComparer.Default.Equals(x.Symbol, methodSymbol) && + x.Node.IsEquivalentTo(invocationExpression) + ).ToImmutableList(); + } + } +} diff --git a/MarshalInfo.cs b/MarshalInfo.cs index 4881f3d..4853dec 100644 --- a/MarshalInfo.cs +++ b/MarshalInfo.cs @@ -49,14 +49,14 @@ CancellationToken cancellationToken public static MarshalInfo GetMarshalInfo ( - IMethodSymbol methodSymbol, + INamedTypeSymbol? marshaller, InterfaceDescriptor interfaceDescriptor, InvocationExpressionSyntax invocationExpression, SemanticModel semanticModel, CancellationToken cancellationToken ) { - if (methodSymbol.TypeArguments.FirstOrDefault() is not INamedTypeSymbol marshaller) + if (marshaller is null) { return new(invocationExpression, semanticModel, cancellationToken); } @@ -160,8 +160,6 @@ CancellationToken cancellationToken return null; } - private MarshalInfo() { } - private MarshalInfo ( InvocationExpressionSyntax invocationExpression, diff --git a/MethodReference.cs b/MethodReference.cs index d952a14..d12b97e 100644 --- a/MethodReference.cs +++ b/MethodReference.cs @@ -3,6 +3,8 @@ using Microsoft.CodeAnalysis.Operations; using Monkeymoto.GeneratorUtils; using System; +using System.Collections.Generic; +using System.Collections.Immutable; using System.Linq; using System.Threading; @@ -26,70 +28,86 @@ internal sealed class MethodReference : IEquatable left?.Equals(right) ?? right is null; public static bool operator !=(MethodReference? left, MethodReference? right) => !(left == right); - public static MethodReference? GetReference + public static IReadOnlyList? GetReferences ( - GenericSymbolReference interfaceOrMethodReference, + GenericSymbolReference interfaceReference, + Func> + getGenericMethodReferences, CancellationToken cancellationToken ) { - INamedTypeSymbol? interfaceSymbol; - IMethodSymbol? methodSymbol; - InvocationExpressionSyntax? invocationExpression; - var node = interfaceOrMethodReference.Node; - if (node is InvocationExpressionSyntax genericMethodInvocationExpression) + var node = interfaceReference.Node; + var invocationExpression = node.Parent?.Parent as InvocationExpressionSyntax; + if (invocationExpression is null) { - methodSymbol = (IMethodSymbol)interfaceOrMethodReference.Symbol; - interfaceSymbol = methodSymbol.ContainingType; - invocationExpression = genericMethodInvocationExpression; + return null; } - else if (node.Parent?.Parent is InvocationExpressionSyntax methodInvocationExpression) + var semanticModel = interfaceReference.SemanticModel!; + var operation = semanticModel.GetOperation(invocationExpression, cancellationToken); + if (operation is not IInvocationOperation invocation) { - var methodNameSyntax = ((MemberAccessExpressionSyntax)node.Parent).Name; - if (methodNameSyntax.Arity != 0) + return null; + } + var methodSymbol = invocation.TargetMethod; + var marshallers = new HashSet(SymbolEqualityComparer.Default); + if (methodSymbol.IsGenericMethod) + { + if (methodSymbol.TypeArguments.First() is INamedTypeSymbol namedMarshaller) { - return null; + marshallers.Add(namedMarshaller); + } + else + { + foreach + ( + var marshaller in getGenericMethodReferences(methodSymbol, invocationExpression) + .Select(static x => (INamedTypeSymbol)x.TypeArguments.First()) + ) + { + marshallers.Add(marshaller); + } } - interfaceSymbol = (INamedTypeSymbol)interfaceOrMethodReference.Symbol; - methodSymbol = interfaceSymbol.GetMembers(methodNameSyntax.Identifier.ToString()).Cast() - .First(x => !x.IsGenericMethod); - invocationExpression = methodInvocationExpression; } - else + var invocationArgumentCount = invocation.Arguments.Length - + invocation.Arguments.Where(static x => x.ArgumentKind != ArgumentKind.Explicit).Count(); + var interfaceSymbol = (INamedTypeSymbol)interfaceReference.Symbol; + var interfaceDescriptor = new InterfaceDescriptor(interfaceSymbol); + var methodDescriptor = new MethodDescriptor(interfaceDescriptor, methodSymbol!); + var methodReferences = ImmutableList.CreateBuilder(); + + MethodReference GetReference(INamedTypeSymbol? marshaller) { - return null; + var marshalInfo = MarshalInfo.GetMarshalInfo + ( + marshaller, + interfaceDescriptor, + invocationExpression, + semanticModel, + cancellationToken + ); + return new MethodReference + ( + interfaceDescriptor, + methodDescriptor, + invocationExpression, + marshalInfo, + !interfaceReference.IsSyntaxReferenceClosedTypeOrMethod, + invocationArgumentCount + ); } - var semanticModel = interfaceOrMethodReference.SemanticModel!; - int invocationArgumentCount = 0; - var invocation = - semanticModel.GetOperation(invocationExpression, cancellationToken) as IInvocationOperation; - if (invocation is not null) + + if (marshallers.Count == 0) { - invocationArgumentCount = invocation.Arguments.Length - - invocation.Arguments.Where(static x => x.ArgumentKind != ArgumentKind.Explicit).Count(); + methodReferences.Add(GetReference(null)); } - var interfaceDescriptor = new InterfaceDescriptor(interfaceSymbol); - var methodDescriptor = new MethodDescriptor - ( - interfaceDescriptor, - methodSymbol! - ); - var marshalInfo = MarshalInfo.GetMarshalInfo - ( - methodSymbol!, - interfaceDescriptor, - invocationExpression, - interfaceOrMethodReference.SemanticModel!, - cancellationToken - ); - return new MethodReference - ( - interfaceDescriptor, - methodDescriptor, - invocationExpression, - marshalInfo, - !interfaceOrMethodReference.IsSyntaxReferenceClosedTypeOrMethod, - invocationArgumentCount - ); + else + { + foreach (var marshaller in marshallers) + { + methodReferences.Add(GetReference(marshaller)); + } + } + return methodReferences.ToImmutable(); } private MethodReference diff --git a/MethodReferenceCollection.cs b/MethodReferenceCollection.cs index 76d1513..1e4ee51 100644 --- a/MethodReferenceCollection.cs +++ b/MethodReferenceCollection.cs @@ -20,16 +20,21 @@ namespace Monkeymoto.NativeGenericDelegates public static IncrementalValueProvider GetReferences ( - IncrementalValueProvider interfaceOrMethodReferencesProvider - ) => interfaceOrMethodReferencesProvider.Select(static (interfaceOrMethodReferences, cancellationToken) => + IncrementalValueProvider interfaceReferencesProvider + ) => interfaceReferencesProvider.Select(static (interfaceReferences, cancellationToken) => { var builder = ImmutableHashSet.CreateBuilder(); - foreach (var interfaceOrMethodReference in interfaceOrMethodReferences) + foreach (var interfaceReference in interfaceReferences) { - var methodReference = MethodReference.GetReference(interfaceOrMethodReference, cancellationToken); - if (methodReference is not null) + var methodReferences = MethodReference.GetReferences + ( + interfaceReference, + interfaceReferences.GetGenericMethodReferences, + cancellationToken + ); + if (methodReferences is not null) { - _ = builder.Add(methodReference); + builder.UnionWith(methodReferences); } } return new MethodReferenceCollection(builder.ToImmutable());