Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions InterfaceReference.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Operations;
using Monkeymoto.GeneratorUtils;
using System;
using System.Linq;
using System.Threading;

namespace Monkeymoto.NativeGenericDelegates
{
internal sealed class InterfaceReference : IEquatable<InterfaceReference>
{
private readonly int hashCode;

public InterfaceDescriptor Interface { get; }
public int InvocationArgumentCount { get; }
public bool IsInterfaceOrMethodOpenGeneric { get; }
public MethodDescriptor Method { get; }
public IInvocationOperation MethodInvocation { get; }

public static bool operator ==(InterfaceReference? left, InterfaceReference? right) =>
left?.Equals(right) ?? right is null;
public static bool operator !=(InterfaceReference? left, InterfaceReference? right) =>
!(left == right);

public static InterfaceReference? GetReference
(
GenericSymbolReference reference,
CancellationToken cancellationToken
)
{
if ((reference.Node.Parent?.Parent is not InvocationExpressionSyntax invocationExpression) ||
(reference.Symbol is not INamedTypeSymbol interfaceSymbol))
{
return null;
}
var semanticModel = reference.SemanticModel!;
var methodInvocation = semanticModel
.GetOperation(invocationExpression, cancellationToken) as IInvocationOperation;
if (methodInvocation is not null)
{
var isInterfaceOrMethodOpenGeneric = !reference.IsSyntaxReferenceClosedTypeOrMethod ||
methodInvocation.TargetMethod.TypeArguments.Any(static x => x is not INamedTypeSymbol);
return new InterfaceReference
(
interfaceSymbol,
methodInvocation,
isInterfaceOrMethodOpenGeneric
);
}
return null;
}

public static InterfaceReference? GetReference(IInvocationOperation invocation)
{
if ((invocation is null) || invocation.TargetMethod.IsGenericMethod ||
(invocation.TargetMethod.ContainingType is not INamedTypeSymbol interfaceSymbol) ||
interfaceSymbol.IsGenericType)
{
return null;
}
return new InterfaceReference(interfaceSymbol, invocation, false);
}

private InterfaceReference
(
INamedTypeSymbol interfaceSymbol,
IInvocationOperation methodInvocation,
bool isInterfaceOrMethodOpenGeneric
)
{
Interface = new InterfaceDescriptor(interfaceSymbol);
InvocationArgumentCount = methodInvocation.Arguments
.Where(static x => x.ArgumentKind == ArgumentKind.Explicit)
.Count();
IsInterfaceOrMethodOpenGeneric = isInterfaceOrMethodOpenGeneric;
Method = new MethodDescriptor(Interface, methodInvocation.TargetMethod);
MethodInvocation = methodInvocation;
hashCode = Hash.Combine(Interface, MethodInvocation.Syntax);
}

public override bool Equals(object? obj) => obj is InterfaceReference other && Equals(other);
public bool Equals(InterfaceReference? other) => (other is not null) && (Interface == other.Interface) &&
MethodInvocation.Syntax.IsEquivalentTo(other.MethodInvocation.Syntax);
public override int GetHashCode() => hashCode;
}
}
98 changes: 68 additions & 30 deletions InterfaceReferenceCollection.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Operations;
using Monkeymoto.GeneratorUtils;
using System;
using System.Collections;
Expand All @@ -12,10 +13,10 @@ namespace Monkeymoto.NativeGenericDelegates
{
internal readonly struct InterfaceReferenceCollection :
IEquatable<InterfaceReferenceCollection>,
IEnumerable<GenericSymbolReference>
IEnumerable<InterfaceReference>
{
private readonly int hashCode;
private readonly ImmutableHashSet<GenericSymbolReference> interfaceReferences;
private readonly ImmutableHashSet<InterfaceReference> interfaceReferences;
private readonly ImmutableHashSet<GenericSymbolReference> methodReferences;

public static bool operator ==(InterfaceReferenceCollection left, InterfaceReferenceCollection right) =>
Expand All @@ -29,38 +30,75 @@ public static IncrementalValueProvider<InterfaceReferenceCollection> GetReferenc
IncrementalValueProvider<InterfaceOrMethodSymbolCollection> symbolsProvider
)
{
var nonGenericInterfaceReferenceProvider = context.SyntaxProvider.CreateSyntaxProvider
(
(node, _) =>
{
if ((node is not MemberAccessExpressionSyntax memberAccessExpression) ||
(memberAccessExpression.Expression is not IdentifierNameSyntax identifierName) ||
(node.Parent is not InvocationExpressionSyntax))
{
return false;
}
string memberName = memberAccessExpression.Name.Identifier.ValueText;
string parentName = identifierName.Identifier.ValueText;
return ((memberName == "FromAction") || (memberName == "FromFunctionPointer")) &&
((parentName == "INativeAction") || (parentName == "IUnmanagedAction"));
},
(context, cancellationToken) => (IInvocationOperation)context.SemanticModel
.GetOperation(context.Node.Parent!, cancellationToken)!
).Collect();
var treeProvider = GenericSymbolReferenceTree.FromIncrementalGeneratorInitializationContext(context);
return symbolsProvider.Combine(treeProvider).Select(static (x, cancellationToken) =>
{
var symbols = x.Left;
using var tree = x.Right; // Dispose tree after we extract the symbol references we need
var interfaceReferences = ImmutableHashSet.CreateBuilder<GenericSymbolReference>();
var methodReferences = ImmutableHashSet.CreateBuilder<GenericSymbolReference>();
foreach (var symbol in symbols)
return symbolsProvider.Combine(nonGenericInterfaceReferenceProvider).Combine(treeProvider).Select
(
static (x, cancellationToken) =>
{
switch (symbol)
var (symbols, nonGenericInterfaceReferences) = x.Left;
using var tree = x.Right; // Dispose tree after we extract the symbol references we need
var interfaceReferences = ImmutableHashSet.CreateBuilder<InterfaceReference>();
var methodReferences = ImmutableHashSet.CreateBuilder<GenericSymbolReference>();
foreach (var symbol in symbols)
{
switch (symbol)
{
case INamedTypeSymbol { IsGenericType: true }:
interfaceReferences.UnionWith
(
tree.GetBranchesBySymbol(symbol, cancellationToken)
.Select(x => InterfaceReference.GetReference(x, cancellationToken))
.Where(static x => x is not null)!
);
break;
case INamedTypeSymbol { IsGenericType: false }:
break;
case IMethodSymbol methodSymbol:
methodReferences.UnionWith(tree.GetBranchesBySymbol(symbol, cancellationToken));
break;
default:
throw new UnreachableException();
}
}
foreach
(
var reference in nonGenericInterfaceReferences
.Select(static x => InterfaceReference.GetReference(x))
.Where(static x => x is not null)
)
{
case INamedTypeSymbol:
interfaceReferences.UnionWith(tree.GetBranchesBySymbol(symbol, cancellationToken));
break;
case IMethodSymbol methodSymbol:
methodReferences.UnionWith(tree.GetBranchesBySymbol(symbol, cancellationToken));
break;
default:
throw new UnreachableException();
_ = interfaceReferences.Add(reference!);
}
return new InterfaceReferenceCollection
(
interfaceReferences.ToImmutable(),
methodReferences.ToImmutable()
);
}
return new InterfaceReferenceCollection
(
interfaceReferences.ToImmutable(),
methodReferences.ToImmutable()
);
});
);
}

private InterfaceReferenceCollection
(
ImmutableHashSet<GenericSymbolReference> interfaceReferences,
ImmutableHashSet<InterfaceReference> interfaceReferences,
ImmutableHashSet<GenericSymbolReference> methodReferences
)
{
Expand All @@ -73,21 +111,21 @@ ImmutableHashSet<GenericSymbolReference> methodReferences
public bool Equals(InterfaceReferenceCollection other) =>
interfaceReferences.SetEquals(other.interfaceReferences) &&
methodReferences.SetEquals(other.methodReferences);
public IEnumerator<GenericSymbolReference> GetEnumerator() => interfaceReferences.GetEnumerator();
public IEnumerator<InterfaceReference> GetEnumerator() => interfaceReferences.GetEnumerator();
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
public override int GetHashCode() => hashCode;

public IReadOnlyCollection<GenericSymbolReference> GetGenericMethodReferences
(
IMethodSymbol methodSymbol,
InvocationExpressionSyntax invocationExpression
InterfaceReference interfaceReference
)
{
methodSymbol = methodSymbol.OriginalDefinition;
var methodSymbol = interfaceReference.MethodInvocation.TargetMethod.OriginalDefinition;
var node = interfaceReference.MethodInvocation.Syntax;
return methodReferences.Where
(
x => SymbolEqualityComparer.Default.Equals(x.Symbol.OriginalDefinition, methodSymbol) &&
x.Node.IsEquivalentTo(invocationExpression)
x.Node.IsEquivalentTo(node)
).ToImmutableList();
}
}
Expand Down
33 changes: 12 additions & 21 deletions MarshalInfo.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Operations;
using System;
using System.Collections.Generic;
Expand All @@ -25,17 +24,14 @@ internal sealed partial class MarshalInfo : IEquatable<MarshalInfo>

private static IFieldReferenceOperation? GetCallingConventionOperation
(
InvocationExpressionSyntax invocationExpression,
SemanticModel semanticModel,
IInvocationOperation invocation,
CancellationToken cancellationToken,
IPropertySymbol? property = null,
Compilation? compilation = null
)
{
var callingConventionArg = invocationExpression.ArgumentList.Arguments
.Select(x => semanticModel.GetOperation(x, cancellationToken) as IArgumentOperation)
.Where(static x => (x is not null) && (x.Parameter!.Name == "callingConvention"))
.FirstOrDefault();
var callingConventionArg = invocation.Arguments.Where(static x => x.Parameter?.Name == "callingConvention")
.SingleOrDefault();
return callingConventionArg is not null ?
callingConventionArg.Value as IFieldReferenceOperation :
GetFieldReferenceOperation(property, compilation, cancellationToken);
Expand All @@ -50,19 +46,19 @@ CancellationToken cancellationToken

public static MarshalInfo GetMarshalInfo
(
InterfaceReference interfaceReference,
INamedTypeSymbol? marshaller,
InterfaceDescriptor interfaceDescriptor,
MethodDescriptor methodDescriptor,
InvocationExpressionSyntax invocationExpression,
SemanticModel semanticModel,
CancellationToken cancellationToken
)
{
var interfaceDescriptor = interfaceReference.Interface;
var methodDescriptor = interfaceReference.Method;
var invocation = interfaceReference.MethodInvocation;
if (marshaller is null)
{
return new(methodDescriptor, invocationExpression, semanticModel, cancellationToken);
return new(methodDescriptor, invocation, cancellationToken);
}
var compilation = semanticModel.Compilation;
var compilation = invocation.SemanticModel!.Compilation;
var marshallerInterface = compilation.GetTypeByMetadataName(Constants.IMarshallerMetadataName)!;
var properties = marshaller.GetMembers()
.OfType<IPropertySymbol>()
Expand Down Expand Up @@ -102,8 +98,7 @@ CancellationToken cancellationToken
}
var callingConventionOp = GetCallingConventionOperation
(
invocationExpression,
semanticModel,
invocation,
cancellationToken,
callingConventionProperty,
compilation
Expand Down Expand Up @@ -182,17 +177,13 @@ private static CallingConvention GetUnsafeStaticCallingConvention(MethodDescript
private MarshalInfo
(
MethodDescriptor methodDescriptor,
InvocationExpressionSyntax invocationExpression,
SemanticModel semanticModel,
IInvocationOperation invocation,
CancellationToken cancellationToken
)
{
StaticCallingConvention = methodDescriptor.IsFromUnsafeFunctionPointer ?
GetUnsafeStaticCallingConvention(methodDescriptor) :
GetStaticCallingConvention
(
GetCallingConventionOperation(invocationExpression, semanticModel, cancellationToken)
);
GetStaticCallingConvention(GetCallingConventionOperation(invocation, cancellationToken));
hashCode = Hash.Combine
(
MarshallerType,
Expand Down
Loading