Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions Generator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,8 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
context.RegisterPostInitializationOutput(static context =>
context.AddSource(Constants.DeclarationsSourceFileName, PostInitialization.GetSourceText()));
var interfaceOrMethodSymbols = InterfaceOrMethodSymbolCollection.GetSymbols(context.CompilationProvider);
var interfaceOrMethodReferences =
InterfaceOrMethodReferenceCollection.GetReferences(context, interfaceOrMethodSymbols);
var methodReferences = MethodReferenceCollection.GetReferences(interfaceOrMethodReferences);
var interfaceReferences = InterfaceReferenceCollection.GetReferences(context, interfaceOrMethodSymbols);
var methodReferences = MethodReferenceCollection.GetReferences(interfaceReferences);
var implementationClasses = ImplementationClassCollection.GetImplementationClasses(methodReferences);
context.RegisterImplementationSourceOutput
(
Expand Down
157 changes: 0 additions & 157 deletions InterfaceOrMethodReferenceCollection.cs

This file was deleted.

94 changes: 94 additions & 0 deletions InterfaceReferenceCollection.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Monkeymoto.GeneratorUtils;
using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;

namespace Monkeymoto.NativeGenericDelegates
{
internal readonly struct InterfaceReferenceCollection :
IEquatable<InterfaceReferenceCollection>,
IEnumerable<GenericSymbolReference>
{
private readonly int hashCode;
private readonly ImmutableHashSet<GenericSymbolReference> interfaceReferences;
private readonly ImmutableHashSet<GenericSymbolReference> methodReferences;

public static bool operator ==(InterfaceReferenceCollection left, InterfaceReferenceCollection right) =>
left.Equals(right);
public static bool operator !=(InterfaceReferenceCollection left, InterfaceReferenceCollection right) =>
!(left == right);

public static IncrementalValueProvider<InterfaceReferenceCollection> GetReferences
(
IncrementalGeneratorInitializationContext context,
IncrementalValueProvider<InterfaceOrMethodSymbolCollection> symbolsProvider
)
{
var treeProvider = GenericSymbolReferenceTree.FromIncrementalGeneratorInitializationContext(context);
return symbolsProvider.Combine(treeProvider).Select(static (x, cancellationToken) =>
{
var symbols = x.Left;
using var tree = x.Right; // Dispose tree after we extract the symbol references we need
var interfaceReferences = ImmutableHashSet.CreateBuilder<GenericSymbolReference>();
var methodReferences = ImmutableHashSet.CreateBuilder<GenericSymbolReference>();
foreach (var symbol in symbols)
{
switch (symbol)
{
case INamedTypeSymbol:
interfaceReferences.UnionWith(tree.GetBranchesBySymbol(symbol, cancellationToken));
break;
case IMethodSymbol methodSymbol:
methodReferences.UnionWith(tree.GetBranchesBySymbol(symbol, cancellationToken));
break;
default:
throw new UnreachableException();
}
}
return new InterfaceReferenceCollection
(
interfaceReferences.ToImmutable(),
methodReferences.ToImmutable()
);
});
}

private InterfaceReferenceCollection
(
ImmutableHashSet<GenericSymbolReference> interfaceReferences,
ImmutableHashSet<GenericSymbolReference> methodReferences
)
{
this.interfaceReferences = interfaceReferences;
this.methodReferences = methodReferences;
hashCode = Hash.Combine(interfaceReferences, methodReferences);
}

public override bool Equals(object? obj) => obj is InterfaceReferenceCollection other && Equals(other);
public bool Equals(InterfaceReferenceCollection other) =>
interfaceReferences.SetEquals(other.interfaceReferences) &&
methodReferences.SetEquals(other.methodReferences);
public IEnumerator<GenericSymbolReference> GetEnumerator() => interfaceReferences.GetEnumerator();
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
public override int GetHashCode() => hashCode;

public IReadOnlyCollection<GenericSymbolReference> GetGenericMethodReferences
(
IMethodSymbol methodSymbol,
InvocationExpressionSyntax invocationExpression
)
{
methodSymbol = methodSymbol.OriginalDefinition;
return methodReferences.Where
(
x => SymbolEqualityComparer.Default.Equals(x.Symbol, methodSymbol) &&
x.Node.IsEquivalentTo(invocationExpression)
).ToImmutableList();
}
}
}
6 changes: 2 additions & 4 deletions MarshalInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,14 @@ CancellationToken cancellationToken

public static MarshalInfo GetMarshalInfo
(
IMethodSymbol methodSymbol,
INamedTypeSymbol? marshaller,
InterfaceDescriptor interfaceDescriptor,
InvocationExpressionSyntax invocationExpression,
SemanticModel semanticModel,
CancellationToken cancellationToken
)
{
if (methodSymbol.TypeArguments.FirstOrDefault() is not INamedTypeSymbol marshaller)
if (marshaller is null)
{
return new(invocationExpression, semanticModel, cancellationToken);
}
Expand Down Expand Up @@ -160,8 +160,6 @@ CancellationToken cancellationToken
return null;
}

private MarshalInfo() { }

private MarshalInfo
(
InvocationExpressionSyntax invocationExpression,
Expand Down
Loading