diff --git a/InterfaceReference.cs b/InterfaceReference.cs new file mode 100644 index 0000000..4547943 --- /dev/null +++ b/InterfaceReference.cs @@ -0,0 +1,87 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Operations; +using Monkeymoto.GeneratorUtils; +using System; +using System.Linq; +using System.Threading; + +namespace Monkeymoto.NativeGenericDelegates +{ + internal sealed class InterfaceReference : IEquatable + { + private readonly int hashCode; + + public InterfaceDescriptor Interface { get; } + public int InvocationArgumentCount { get; } + public bool IsInterfaceOrMethodOpenGeneric { get; } + public MethodDescriptor Method { get; } + public IInvocationOperation MethodInvocation { get; } + + public static bool operator ==(InterfaceReference? left, InterfaceReference? right) => + left?.Equals(right) ?? right is null; + public static bool operator !=(InterfaceReference? left, InterfaceReference? right) => + !(left == right); + + public static InterfaceReference? GetReference + ( + GenericSymbolReference reference, + CancellationToken cancellationToken + ) + { + if ((reference.Node.Parent?.Parent is not InvocationExpressionSyntax invocationExpression) || + (reference.Symbol is not INamedTypeSymbol interfaceSymbol)) + { + return null; + } + var semanticModel = reference.SemanticModel!; + var methodInvocation = semanticModel + .GetOperation(invocationExpression, cancellationToken) as IInvocationOperation; + if (methodInvocation is not null) + { + var isInterfaceOrMethodOpenGeneric = !reference.IsSyntaxReferenceClosedTypeOrMethod || + methodInvocation.TargetMethod.TypeArguments.Any(static x => x is not INamedTypeSymbol); + return new InterfaceReference + ( + interfaceSymbol, + methodInvocation, + isInterfaceOrMethodOpenGeneric + ); + } + return null; + } + + public static InterfaceReference? GetReference(IInvocationOperation invocation) + { + if ((invocation is null) || invocation.TargetMethod.IsGenericMethod || + (invocation.TargetMethod.ContainingType is not INamedTypeSymbol interfaceSymbol) || + interfaceSymbol.IsGenericType) + { + return null; + } + return new InterfaceReference(interfaceSymbol, invocation, false); + } + + private InterfaceReference + ( + INamedTypeSymbol interfaceSymbol, + IInvocationOperation methodInvocation, + bool isInterfaceOrMethodOpenGeneric + ) + { + Interface = new InterfaceDescriptor(interfaceSymbol); + InvocationArgumentCount = methodInvocation.Arguments + .Where(static x => x.ArgumentKind == ArgumentKind.Explicit) + .Count(); + IsInterfaceOrMethodOpenGeneric = isInterfaceOrMethodOpenGeneric; + Method = new MethodDescriptor(Interface, methodInvocation.TargetMethod); + MethodInvocation = methodInvocation; + hashCode = Hash.Combine(Interface, MethodInvocation.Syntax); + } + + public override bool Equals(object? obj) => obj is InterfaceReference other && Equals(other); + public bool Equals(InterfaceReference? other) => (other is not null) && (Interface == other.Interface) && + MethodInvocation.Syntax.IsEquivalentTo(other.MethodInvocation.Syntax); + public override int GetHashCode() => hashCode; + } +} diff --git a/InterfaceReferenceCollection.cs b/InterfaceReferenceCollection.cs index ecac6e3..3b1aefb 100644 --- a/InterfaceReferenceCollection.cs +++ b/InterfaceReferenceCollection.cs @@ -1,5 +1,6 @@ using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Operations; using Monkeymoto.GeneratorUtils; using System; using System.Collections; @@ -12,10 +13,10 @@ namespace Monkeymoto.NativeGenericDelegates { internal readonly struct InterfaceReferenceCollection : IEquatable, - IEnumerable + IEnumerable { private readonly int hashCode; - private readonly ImmutableHashSet interfaceReferences; + private readonly ImmutableHashSet interfaceReferences; private readonly ImmutableHashSet methodReferences; public static bool operator ==(InterfaceReferenceCollection left, InterfaceReferenceCollection right) => @@ -29,38 +30,75 @@ public static IncrementalValueProvider GetReferenc IncrementalValueProvider symbolsProvider ) { + var nonGenericInterfaceReferenceProvider = context.SyntaxProvider.CreateSyntaxProvider + ( + (node, _) => + { + if ((node is not MemberAccessExpressionSyntax memberAccessExpression) || + (memberAccessExpression.Expression is not IdentifierNameSyntax identifierName) || + (node.Parent is not InvocationExpressionSyntax)) + { + return false; + } + string memberName = memberAccessExpression.Name.Identifier.ValueText; + string parentName = identifierName.Identifier.ValueText; + return ((memberName == "FromAction") || (memberName == "FromFunctionPointer")) && + ((parentName == "INativeAction") || (parentName == "IUnmanagedAction")); + }, + (context, cancellationToken) => (IInvocationOperation)context.SemanticModel + .GetOperation(context.Node.Parent!, cancellationToken)! + ).Collect(); var treeProvider = GenericSymbolReferenceTree.FromIncrementalGeneratorInitializationContext(context); - return symbolsProvider.Combine(treeProvider).Select(static (x, cancellationToken) => - { - var symbols = x.Left; - using var tree = x.Right; // Dispose tree after we extract the symbol references we need - var interfaceReferences = ImmutableHashSet.CreateBuilder(); - var methodReferences = ImmutableHashSet.CreateBuilder(); - foreach (var symbol in symbols) + return symbolsProvider.Combine(nonGenericInterfaceReferenceProvider).Combine(treeProvider).Select + ( + static (x, cancellationToken) => { - switch (symbol) + var (symbols, nonGenericInterfaceReferences) = x.Left; + using var tree = x.Right; // Dispose tree after we extract the symbol references we need + var interfaceReferences = ImmutableHashSet.CreateBuilder(); + var methodReferences = ImmutableHashSet.CreateBuilder(); + foreach (var symbol in symbols) + { + switch (symbol) + { + case INamedTypeSymbol { IsGenericType: true }: + interfaceReferences.UnionWith + ( + tree.GetBranchesBySymbol(symbol, cancellationToken) + .Select(x => InterfaceReference.GetReference(x, cancellationToken)) + .Where(static x => x is not null)! + ); + break; + case INamedTypeSymbol { IsGenericType: false }: + break; + case IMethodSymbol methodSymbol: + methodReferences.UnionWith(tree.GetBranchesBySymbol(symbol, cancellationToken)); + break; + default: + throw new UnreachableException(); + } + } + foreach + ( + var reference in nonGenericInterfaceReferences + .Select(static x => InterfaceReference.GetReference(x)) + .Where(static x => x is not null) + ) { - case INamedTypeSymbol: - interfaceReferences.UnionWith(tree.GetBranchesBySymbol(symbol, cancellationToken)); - break; - case IMethodSymbol methodSymbol: - methodReferences.UnionWith(tree.GetBranchesBySymbol(symbol, cancellationToken)); - break; - default: - throw new UnreachableException(); + _ = interfaceReferences.Add(reference!); } + return new InterfaceReferenceCollection + ( + interfaceReferences.ToImmutable(), + methodReferences.ToImmutable() + ); } - return new InterfaceReferenceCollection - ( - interfaceReferences.ToImmutable(), - methodReferences.ToImmutable() - ); - }); + ); } private InterfaceReferenceCollection ( - ImmutableHashSet interfaceReferences, + ImmutableHashSet interfaceReferences, ImmutableHashSet methodReferences ) { @@ -73,21 +111,21 @@ ImmutableHashSet methodReferences public bool Equals(InterfaceReferenceCollection other) => interfaceReferences.SetEquals(other.interfaceReferences) && methodReferences.SetEquals(other.methodReferences); - public IEnumerator GetEnumerator() => interfaceReferences.GetEnumerator(); + public IEnumerator GetEnumerator() => interfaceReferences.GetEnumerator(); IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); public override int GetHashCode() => hashCode; public IReadOnlyCollection GetGenericMethodReferences ( - IMethodSymbol methodSymbol, - InvocationExpressionSyntax invocationExpression + InterfaceReference interfaceReference ) { - methodSymbol = methodSymbol.OriginalDefinition; + var methodSymbol = interfaceReference.MethodInvocation.TargetMethod.OriginalDefinition; + var node = interfaceReference.MethodInvocation.Syntax; return methodReferences.Where ( x => SymbolEqualityComparer.Default.Equals(x.Symbol.OriginalDefinition, methodSymbol) && - x.Node.IsEquivalentTo(invocationExpression) + x.Node.IsEquivalentTo(node) ).ToImmutableList(); } } diff --git a/MarshalInfo.cs b/MarshalInfo.cs index 93cc792..a28c134 100644 --- a/MarshalInfo.cs +++ b/MarshalInfo.cs @@ -1,5 +1,4 @@ using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Operations; using System; using System.Collections.Generic; @@ -25,17 +24,14 @@ internal sealed partial class MarshalInfo : IEquatable private static IFieldReferenceOperation? GetCallingConventionOperation ( - InvocationExpressionSyntax invocationExpression, - SemanticModel semanticModel, + IInvocationOperation invocation, CancellationToken cancellationToken, IPropertySymbol? property = null, Compilation? compilation = null ) { - var callingConventionArg = invocationExpression.ArgumentList.Arguments - .Select(x => semanticModel.GetOperation(x, cancellationToken) as IArgumentOperation) - .Where(static x => (x is not null) && (x.Parameter!.Name == "callingConvention")) - .FirstOrDefault(); + var callingConventionArg = invocation.Arguments.Where(static x => x.Parameter?.Name == "callingConvention") + .SingleOrDefault(); return callingConventionArg is not null ? callingConventionArg.Value as IFieldReferenceOperation : GetFieldReferenceOperation(property, compilation, cancellationToken); @@ -50,19 +46,19 @@ CancellationToken cancellationToken public static MarshalInfo GetMarshalInfo ( + InterfaceReference interfaceReference, INamedTypeSymbol? marshaller, - InterfaceDescriptor interfaceDescriptor, - MethodDescriptor methodDescriptor, - InvocationExpressionSyntax invocationExpression, - SemanticModel semanticModel, CancellationToken cancellationToken ) { + var interfaceDescriptor = interfaceReference.Interface; + var methodDescriptor = interfaceReference.Method; + var invocation = interfaceReference.MethodInvocation; if (marshaller is null) { - return new(methodDescriptor, invocationExpression, semanticModel, cancellationToken); + return new(methodDescriptor, invocation, cancellationToken); } - var compilation = semanticModel.Compilation; + var compilation = invocation.SemanticModel!.Compilation; var marshallerInterface = compilation.GetTypeByMetadataName(Constants.IMarshallerMetadataName)!; var properties = marshaller.GetMembers() .OfType() @@ -102,8 +98,7 @@ CancellationToken cancellationToken } var callingConventionOp = GetCallingConventionOperation ( - invocationExpression, - semanticModel, + invocation, cancellationToken, callingConventionProperty, compilation @@ -182,17 +177,13 @@ private static CallingConvention GetUnsafeStaticCallingConvention(MethodDescript private MarshalInfo ( MethodDescriptor methodDescriptor, - InvocationExpressionSyntax invocationExpression, - SemanticModel semanticModel, + IInvocationOperation invocation, CancellationToken cancellationToken ) { StaticCallingConvention = methodDescriptor.IsFromUnsafeFunctionPointer ? GetUnsafeStaticCallingConvention(methodDescriptor) : - GetStaticCallingConvention - ( - GetCallingConventionOperation(invocationExpression, semanticModel, cancellationToken) - ); + GetStaticCallingConvention(GetCallingConventionOperation(invocation, cancellationToken)); hashCode = Hash.Combine ( MarshallerType, diff --git a/MethodReference.cs b/MethodReference.cs index e9cbe6c..fecb1c6 100644 --- a/MethodReference.cs +++ b/MethodReference.cs @@ -1,11 +1,11 @@ using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; -using Microsoft.CodeAnalysis.Operations; using Monkeymoto.GeneratorUtils; using System; using System.Collections.Generic; using System.Collections.Immutable; using System.Linq; +using System.Linq.Expressions; using System.Threading; namespace Monkeymoto.NativeGenericDelegates @@ -25,86 +25,58 @@ internal sealed class MethodReference : IEquatable left?.Equals(right) ?? right is null; public static bool operator !=(MethodReference? left, MethodReference? right) => !(left == right); + private static MethodReference GetReference + ( + InterfaceReference interfaceReference, + CancellationToken cancellationToken, + INamedTypeSymbol? marshaller = null + ) + { + var marshalInfo = MarshalInfo.GetMarshalInfo(interfaceReference, marshaller, cancellationToken); + return new MethodReference(interfaceReference, marshalInfo); + } + public static IReadOnlyList? GetReferences ( - GenericSymbolReference interfaceReference, - Func> - getGenericMethodReferences, + InterfaceReference interfaceReference, + Func> getGenericMethodReferences, CancellationToken cancellationToken ) { - var node = interfaceReference.Node; - var invocationExpression = node.Parent?.Parent as InvocationExpressionSyntax; - if (invocationExpression is null) - { - return null; - } - var semanticModel = interfaceReference.SemanticModel!; - var operation = semanticModel.GetOperation(invocationExpression, cancellationToken); - if (operation is not IInvocationOperation invocation) - { - return null; - } + var invocation = interfaceReference.MethodInvocation; + var invocationExpression = (InvocationExpressionSyntax)invocation.Syntax; var methodSymbol = invocation.TargetMethod; var marshallers = new HashSet(SymbolEqualityComparer.Default); - var isOpenGenericMethod = false; if (methodSymbol.IsGenericMethod) { - if (methodSymbol.TypeArguments.First() is INamedTypeSymbol namedMarshaller) + if (methodSymbol.TypeArguments[0] is INamedTypeSymbol namedMarshaller) { marshallers.Add(namedMarshaller); } else { - isOpenGenericMethod = true; foreach ( - var marshaller in getGenericMethodReferences(methodSymbol, invocationExpression) - .Select(static x => (INamedTypeSymbol)x.TypeArguments.First()) + var marshaller in getGenericMethodReferences(interfaceReference) + .Select(static x => (INamedTypeSymbol)x.TypeArguments[0]) ) { marshallers.Add(marshaller); } } } - var invocationArgumentCount = invocation.Arguments.Length - - invocation.Arguments.Where(static x => x.ArgumentKind != ArgumentKind.Explicit).Count(); - var interfaceSymbol = (INamedTypeSymbol)interfaceReference.Symbol; - var interfaceDescriptor = new InterfaceDescriptor(interfaceSymbol); - var methodDescriptor = new MethodDescriptor(interfaceDescriptor, methodSymbol!); + var interfaceDescriptor = interfaceReference.Interface; + var methodDescriptor = interfaceReference.Method; var methodReferences = ImmutableList.CreateBuilder(); - - MethodReference GetReference(INamedTypeSymbol? marshaller) - { - var marshalInfo = MarshalInfo.GetMarshalInfo - ( - marshaller, - interfaceDescriptor, - methodDescriptor, - invocationExpression, - semanticModel, - cancellationToken - ); - return new MethodReference - ( - interfaceDescriptor, - methodDescriptor, - invocationExpression, - marshalInfo, - !interfaceReference.IsSyntaxReferenceClosedTypeOrMethod || isOpenGenericMethod, - invocationArgumentCount - ); - } - if (marshallers.Count == 0) { - methodReferences.Add(GetReference(null)); + methodReferences.Add(GetReference(interfaceReference, cancellationToken)); } else { foreach (var marshaller in marshallers) { - methodReferences.Add(GetReference(marshaller)); + methodReferences.Add(GetReference(interfaceReference, cancellationToken, marshaller)); } } return methodReferences.ToImmutable(); @@ -112,20 +84,17 @@ MethodReference GetReference(INamedTypeSymbol? marshaller) private MethodReference ( - InterfaceDescriptor interfaceDescriptor, - MethodDescriptor methodDescriptor, - InvocationExpressionSyntax invocationExpression, - MarshalInfo marshalInfo, - bool isInterfaceOrMethodOpenGeneric, - int invocationArgumentCount + InterfaceReference interfaceReference, + MarshalInfo marshalInfo ) { - Interface = interfaceDescriptor; - InvocationArgumentCount = invocationArgumentCount; - IsInterfaceOrMethodOpenGeneric = isInterfaceOrMethodOpenGeneric; + var invocationExpression = (InvocationExpressionSyntax)interfaceReference.MethodInvocation.Syntax; + Interface = interfaceReference.Interface; + InvocationArgumentCount = interfaceReference.InvocationArgumentCount; + IsInterfaceOrMethodOpenGeneric = interfaceReference.IsInterfaceOrMethodOpenGeneric; Location = new InterceptedLocation(invocationExpression); MarshalInfo = marshalInfo; - Method = methodDescriptor; + Method = interfaceReference.Method; hashCode = Hash.Combine(Location, Method, InvocationArgumentCount, MarshalInfo); }