diff --git a/ClosedGenericInterceptor.cs b/ClosedGenericInterceptor.cs index 89e14ab..643ad55 100644 --- a/ClosedGenericInterceptor.cs +++ b/ClosedGenericInterceptor.cs @@ -48,9 +48,11 @@ private string GetSourceText() } var method = InterceptsMethod; var typeParameters = Constants.InterceptorTypeParameters[method.ContainingInterface.Arity]; - typeParameters = typeParameters.Length != 0 ? - $"<{typeParameters}>" : - typeParameters; + typeParameters = method.Arity != 0 ? + $"<{typeParameters}, XMarshaller>" : + typeParameters.Length != 0 ? + $"<{typeParameters}>" : + typeParameters; _ = sb.Append ( $@"[MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -60,7 +62,7 @@ public static {method.ContainingInterface.FullName} {method.Name}{typeParameters ){Constants.InterceptorAntiConstraints[method.ContainingInterface.Arity]} {{" ); - if (ImplementationClass.Marshalling.StaticCallingConvention is not null) + if (ImplementationClass.MarshalInfo.StaticCallingConvention is not null) { _ = sb.Append ( diff --git a/Constants.Actions.cs b/Constants.Actions.cs index f1c2959..8ae744c 100644 --- a/Constants.Actions.cs +++ b/Constants.Actions.cs @@ -24,7 +24,7 @@ internal static class Actions $"{RootNamespace}.INativeAction`13", $"{RootNamespace}.INativeAction`14", $"{RootNamespace}.INativeAction`15", - $"{RootNamespace}.INativeAction`16", + $"{RootNamespace}.INativeAction`16" ]; public static readonly string[] QualifiedTypeParameters = Constants.QualifiedTypeParameters; diff --git a/Constants.Funcs.cs b/Constants.Funcs.cs index 88e494d..0d4b4bf 100644 --- a/Constants.Funcs.cs +++ b/Constants.Funcs.cs @@ -31,7 +31,7 @@ .. Constants.AntiConstraints.Skip(1) $"{RootNamespace}.INativeFunc`14", $"{RootNamespace}.INativeFunc`15", $"{RootNamespace}.INativeFunc`16", - $"{RootNamespace}.INativeFunc`17", + $"{RootNamespace}.INativeFunc`17" ]; public static readonly string[] QualifiedTypeParameters = diff --git a/Constants.cs b/Constants.cs index cea4273..97f8241 100644 --- a/Constants.cs +++ b/Constants.cs @@ -62,8 +62,9 @@ internal static partial class Constants public static readonly string[] InterceptorAntiConstraints = [ - .. AntiConstraints.Select(static x => x.Replace('T', 'X')), - $"{AntiConstraint_T1_T16.Replace('T', 'X')}{NewLineIndent2}where X17 : allows ref struct" + .. AntiConstraints.Select(static x => x.Replace(" where T", " where X").Replace('T', 'X')), + $"{AntiConstraint_T1_T16.Replace(" where T", " where X") + .Replace('T', 'X')}{NewLineIndent2}where X17 : allows ref struct" ]; public static readonly string[] Arguments = @@ -111,6 +112,9 @@ .. AntiConstraints.Select(static x => x.Replace('T', 'X')), public const string RootNamespace = "Monkeymoto.NativeGenericDelegates"; public const string DeclarationsSourceFileName = RootNamespace + ".Declarations.g.cs"; + public const string IMarshallerInterfaceName = "IMarshaller"; + public const string IMarshallerMetadataName = $"{RootNamespace}.{IMarshallerInterfaceName}`1"; + /// /// Returns the total number of interfaces per category (Action or Func). /// diff --git a/DelegateMarshalling.Parser.cs b/DelegateMarshalling.Parser.cs deleted file mode 100644 index d9f06e8..0000000 --- a/DelegateMarshalling.Parser.cs +++ /dev/null @@ -1,322 +0,0 @@ -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp.Syntax; -using Microsoft.CodeAnalysis.Operations; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading; - -namespace Monkeymoto.NativeGenericDelegates -{ - internal sealed partial class DelegateMarshalling - { - internal static class Parser - { - private static IReadOnlyList GetMarshalAsCollectionFromArrayInitializer - ( - IArrayInitializerOperation initializer, - string parameterName, - int collectionLength, - IList diagnostics, - CancellationToken cancellationToken - ) - { - var results = new List(collectionLength); - foreach (var elementValue in initializer.ElementValues) - { - cancellationToken.ThrowIfCancellationRequested(); - var result = GetMarshalAsFromOperation - ( - elementValue, - parameterName, - diagnostics, - diagnosticTypeSuffix: "[]", - cancellationToken - ); - results.Add(result); - if (results.Count == collectionLength) - { - break; - } - } - return results.AsReadOnly(); - } - - private static IReadOnlyList? GetMarshalAsCollectionFromCollectionExpression - ( - ICollectionExpressionOperation collectionExpression, - string parameterName, - int collectionLength, - IList diagnostics, - CancellationToken cancellationToken - ) - { - if (collectionExpression.Elements.Any(x => x is ISpreadOperation)) - { - diagnostics.Add - ( - Diagnostic.Create - ( - Diagnostics.NGD1003_MarshalAsArgumentSpreadElementNotSupported, - collectionExpression.Syntax.GetLocation(), - parameterName - ) - ); - return null; - } - var elements = collectionExpression.Elements.Select(static x => x switch - { - ILiteralOperation => x, - IObjectCreationOperation => x, - IConversionOperation conversion => - conversion.Operand as IObjectCreationOperation, - _ => null - }).Where(static x => x is not null); - if (!elements.Any()) - { - diagnostics.Add - ( - Diagnostic.Create - ( - Diagnostics.NGD1001_MarshalAsArgumentMustUseObjectCreationSyntaxDescriptor, - collectionExpression.Syntax.GetLocation(), - "[]", - parameterName - ) - ); - } - var results = new List(collectionLength); - foreach (var element in elements) - { - cancellationToken.ThrowIfCancellationRequested(); - var result = GetMarshalAsFromOperation - ( - element!, - parameterName, - diagnostics, - diagnosticTypeSuffix: "[]", - cancellationToken - ); - results.Add(result); - if (results.Count == collectionLength) - { - break; - } - } - return results.AsReadOnly(); - } - - private static IReadOnlyList? GetMarshalAsCollectionFromOperation - ( - IOperation collection, - string parameterName, - int collectionLength, - IList diagnostics, - CancellationToken cancellationToken - ) - { - if ((collectionLength <= 0) || collection.ConstantValue.HasValue) - { - // empty collection or collection operation is `null` literal in source - return null; - } - if (collection is IArrayCreationOperation arrayCreation) - { - // new MarshalAsAttribute[] { ... } - if (arrayCreation.Initializer is not null) - { - return GetMarshalAsCollectionFromArrayInitializer - ( - arrayCreation.Initializer, - parameterName, - collectionLength, - diagnostics, - cancellationToken - ); - } - // else, no initializer or no arguments, default to no marshalling - } - else if (collection is IFieldReferenceOperation fieldReference && fieldReference.Field.IsReadOnly && - fieldReference.Type is IArrayTypeSymbol) - { - // readonly field of MarshalAsAttribute[] - return (IReadOnlyList?)GetMarshalAsFromField - ( - fieldReference, - parameterName, - collectionLength, - diagnostics, - cancellationToken - ); - } - else - { - var collectionExpression = collection switch - { - IConversionOperation conversion => conversion.Operand as ICollectionExpressionOperation, - ICollectionExpressionOperation op => op, - _ => null - }; - if (collectionExpression is not null) - { - return GetMarshalAsCollectionFromCollectionExpression - ( - collectionExpression, - parameterName, - collectionLength, - diagnostics, - cancellationToken - ); - } - else // unknown operation - report diagnostic - { - diagnostics.Add - ( - Diagnostic.Create - ( - Diagnostics.NGD1001_MarshalAsArgumentMustUseObjectCreationSyntaxDescriptor, - collection.Syntax.GetLocation(), - "[]", - parameterName - ) - ); - } - } - return null; - } - - public static string? GetMarshalAsFromOperation - ( - IOperation value, - string parameterName, - IList diagnostics, - string diagnosticTypeSuffix, - CancellationToken cancellationToken - ) - { - if (value.ConstantValue.HasValue) // value in source is `null` literal - { - return null; - } - if (value is IFieldReferenceOperation fieldReference && fieldReference.Field.IsReadOnly && - fieldReference.Type is not IArrayTypeSymbol) - { - return (string?)GetMarshalAsFromField - ( - fieldReference, - parameterName, - collectionLength: 1, - diagnostics, - cancellationToken - ); - } - IObjectCreationOperation? objectCreation = value switch - { - IConversionOperation conversion => - conversion.ChildOperations.OfType().FirstOrDefault(), - _ => value as IObjectCreationOperation - }; - if (objectCreation is null) - { - diagnostics.Add - ( - Diagnostic.Create - ( - Diagnostics.NGD1001_MarshalAsArgumentMustUseObjectCreationSyntaxDescriptor, - value.Syntax.GetLocation(), - diagnosticTypeSuffix, - parameterName - ) - ); - return null; - } - var sb = new StringBuilder(objectCreation.Arguments[0].Syntax.ToString()); - if (objectCreation.Initializer is not null) - { - _ = sb.Append(objectCreation.Initializer.Syntax.ToString()) - .Replace('{', ',') - .Replace("]", string.Empty); - int i = sb.Length - 1; - for ( ; (i >= 0) && char.IsWhiteSpace(sb[i]); --i) { } - sb.Length = i + 1; - } - return sb.ToString(); - } - - private static object? GetMarshalAsFromField - ( - IFieldReferenceOperation fieldReference, - string parameterName, - int collectionLength, - IList diagnostics, - CancellationToken cancellationToken - ) - { - var fieldDeclaration = fieldReference.Field.DeclaringSyntaxReferences[0].GetSyntax(cancellationToken); - var equalsValueClause = fieldDeclaration.ChildNodes().OfType() - .FirstOrDefault(); - SemanticModel? semanticModel = equalsValueClause is not null ? - fieldReference.SemanticModel!.Compilation.GetSemanticModel(equalsValueClause.SyntaxTree) : - null; - if (semanticModel?.GetOperation(equalsValueClause!, cancellationToken) is not - IFieldInitializerOperation fieldInitializer) - { - return null; - } - bool isArray = fieldReference.Field.Type is IArrayTypeSymbol; - if (isArray) - { - return GetMarshalAsCollectionFromOperation - ( - fieldInitializer.Value, - parameterName, - collectionLength, - diagnostics, - cancellationToken - ); - } - return GetMarshalAsFromOperation - ( - fieldInitializer.Value, - parameterName, - diagnostics, - diagnosticTypeSuffix: string.Empty, - cancellationToken - ); - } - - public static IReadOnlyList? GetMarshalParamsAs - ( - IArgumentOperation? marshalParamsAsArgument, - int invokeParamCount, - IList diagnostics, - CancellationToken cancellationToken - ) => marshalParamsAsArgument is not null ? - GetMarshalAsCollectionFromOperation - ( - marshalParamsAsArgument.Value, - marshalParamsAsArgument.Parameter!.Name, - invokeParamCount, - diagnostics, - cancellationToken - ) : - null; - - public static string? GetMarshalReturnAs - ( - IArgumentOperation? marshalReturnAsArgument, - IList diagnostics, - CancellationToken cancellationToken - ) => marshalReturnAsArgument is not null ? - GetMarshalAsFromOperation - ( - marshalReturnAsArgument.Value, - marshalReturnAsArgument.Parameter!.Name, - diagnostics, - diagnosticTypeSuffix: string.Empty, - cancellationToken - ) : - null; - } - } -} diff --git a/DelegateMarshalling.cs b/DelegateMarshalling.cs deleted file mode 100644 index c0d42a1..0000000 --- a/DelegateMarshalling.cs +++ /dev/null @@ -1,139 +0,0 @@ -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp.Syntax; -using Microsoft.CodeAnalysis.Operations; -using System; -using System.Collections.Generic; -using System.Linq; -using System.Runtime.InteropServices; -using System.Threading; - -namespace Monkeymoto.NativeGenericDelegates -{ - internal sealed partial class DelegateMarshalling : IEquatable - { - private readonly int hashCode; - - public IReadOnlyList? MarshalParamsAs { get; } - public string? MarshalReturnAs { get; } - public string? RuntimeCallingConvention { get; } - public CallingConvention? StaticCallingConvention { get; } - - public static bool operator ==(DelegateMarshalling? left, DelegateMarshalling? right) => - left?.Equals(right) ?? right is null; - public static bool operator !=(DelegateMarshalling? left, DelegateMarshalling? right) => !(left == right); - - public DelegateMarshalling - ( - InvocationExpressionSyntax invocationExpression, - SemanticModel semanticModel, - InterfaceDescriptor interfaceDescriptor, - IList diagnostics, - CancellationToken cancellationToken - ) - { - IArgumentOperation? callingConventionArgument = null; - IArgumentOperation? marshalMapArgument = null; - IArgumentOperation? marshalReturnAsArgument = null; - IArgumentOperation? marshalParamsAsArgument = null; - foreach (var argumentNode in invocationExpression.ArgumentList.Arguments) - { - var argumentOp = (IArgumentOperation)semanticModel.GetOperation(argumentNode, cancellationToken)!; - switch (argumentOp.Parameter!.Name) - { - case "callingConvention": - callingConventionArgument = argumentOp; - break; - case "marshalMap": - IConversionOperation? conversion = argumentOp.Value as IConversionOperation; - ILiteralOperation? literal = conversion?.Operand as ILiteralOperation; - if (!argumentOp.ConstantValue.HasValue && - !(conversion?.ConstantValue.HasValue ?? false) && - !(literal?.ConstantValue.HasValue ?? false)) - { - marshalMapArgument = argumentOp; - } - break; - case "marshalReturnAs": - marshalReturnAsArgument = argumentOp; - break; - case "marshalParamsAs": - marshalParamsAsArgument = argumentOp; - break; - default: - break; - } - } - MarshalParamsAs = Parser.GetMarshalParamsAs - ( - marshalParamsAsArgument, - interfaceDescriptor.InvokeParameterCount, - diagnostics, - cancellationToken - ); - MarshalReturnAs = Parser.GetMarshalReturnAs(marshalReturnAsArgument, diagnostics, cancellationToken); - var marshalMap = MarshalMap.Parse(marshalMapArgument, diagnostics, cancellationToken); - if (marshalMap is not null) - { - var marshalParamsList = new List(interfaceDescriptor.InvokeParameterCount); - if (MarshalParamsAs is not null) - { - marshalParamsList.AddRange(MarshalParamsAs); - } - while (marshalParamsList.Count < interfaceDescriptor.InvokeParameterCount) - { - marshalParamsList.Add(null); - } - bool dirty = false; - for (int i = 0; i < interfaceDescriptor.InvokeParameterCount; ++i) - { - if ((marshalParamsList[i] is null) && - marshalMap.TryGetValue(interfaceDescriptor.TypeArguments[i], out var marshalParamAs)) - { - marshalParamsList[i] = marshalParamAs; - dirty = true; - } - } - if (dirty) - { - MarshalParamsAs = marshalParamsList.AsReadOnly(); - } - if ((MarshalReturnAs is null) && !interfaceDescriptor.IsAction && - marshalMap.TryGetValue(interfaceDescriptor.TypeArguments.Last(), out var marshalReturnAs)) - { - MarshalReturnAs = marshalReturnAs; - } - } - if (callingConventionArgument is not null) - { - var value = callingConventionArgument.Value; - var field = (value as IFieldReferenceOperation)?.Field; - if (field is not null && SymbolEqualityComparer.Default.Equals(field.ContainingType, field.Type) && - Enum.TryParse(field?.Name, false, out CallingConvention callingConvention)) - { - RuntimeCallingConvention = null; - StaticCallingConvention = callingConvention; - } - else - { - RuntimeCallingConvention = value.ToString(); - StaticCallingConvention = null; - } - } - hashCode = Hash.Combine - ( - MarshalParamsAs, - MarshalReturnAs, - RuntimeCallingConvention, - StaticCallingConvention - ); - } - - public override bool Equals(object? obj) => obj is DelegateMarshalling other && Equals(other); - public bool Equals(DelegateMarshalling? other) => - (other is not null) && MarshalParamsAs.SequenceEqual(other.MarshalParamsAs) && - (MarshalReturnAs == other.MarshalReturnAs) && - (RuntimeCallingConvention == other.RuntimeCallingConvention) && - (StaticCallingConvention == other.StaticCallingConvention); - public override int GetHashCode() => hashCode; - } -} diff --git a/Generator.cs b/Generator.cs index 841b186..52b0ada 100644 --- a/Generator.cs +++ b/Generator.cs @@ -33,21 +33,18 @@ public void Initialize(IncrementalGeneratorInitializationContext context) { context.RegisterPostInitializationOutput(static context => context.AddSource(Constants.DeclarationsSourceFileName, PostInitialization.GetSourceText())); - var interfaceSymbols = InterfaceSymbolCollection.GetSymbols(context.CompilationProvider); - var interfaceReferences = InterfaceReferenceCollection.GetReferences(context, interfaceSymbols); - var methodReferences = MethodReferenceCollection.GetReferences(interfaceReferences); - context.RegisterSourceOutput(methodReferences, static (context, methodReferences) => - { - foreach (var diagnostic in methodReferences.Diagnostics) - { - context.ReportDiagnostic(diagnostic); - } - }); + var interfaceOrMethodSymbols = InterfaceOrMethodSymbolCollection.GetSymbols(context.CompilationProvider); + var interfaceOrMethodReferences = + InterfaceOrMethodReferenceCollection.GetReferences(context, interfaceOrMethodSymbols); + var methodReferences = MethodReferenceCollection.GetReferences(interfaceOrMethodReferences); var implementationClasses = ImplementationClassCollection.GetImplementationClasses(methodReferences); - context.RegisterImplementationSourceOutput(implementationClasses, static (context, implementationClasses) => - { - var sb = new StringBuilder - ( + context.RegisterImplementationSourceOutput + ( + implementationClasses, + static (context, implementationClasses) => + { + var sb = new StringBuilder + ( $@"// using System; using System.Collections.Generic; @@ -57,18 +54,19 @@ public void Initialize(IncrementalGeneratorInitializationContext context) namespace {Constants.RootNamespace} {{" - ); - foreach (var implementationClass in implementationClasses) - { - _ = sb.Append(Constants.NewLineIndent1).Append(implementationClass.SourceText); + ); + foreach (var implementationClass in implementationClasses) + { + _ = sb.Append(Constants.NewLineIndent1).Append(implementationClass.SourceText); + } + _ = sb.AppendLine(implementationClasses.GetOpenGenericInterceptorsSourceText()); + int i = sb.Length - 1; + for ( ; (i >= 0) && char.IsWhiteSpace(sb[i]); --i) { } + sb.Length = i + 1; + _ = sb.AppendLine().AppendLine("}"); + context.AddSource(Constants.SourceFileName, sb.ToString()); } - _ = sb.AppendLine(implementationClasses.GetOpenGenericInterceptorsSourceText()); - int i = sb.Length - 1; - for (; (i >= 0) && char.IsWhiteSpace(sb[i]); --i) { } - sb.Length = i + 1; - _ = sb.AppendLine().AppendLine("}"); - context.AddSource(Constants.SourceFileName, sb.ToString()); - }); + ); } } } diff --git a/ImplementationClass.ClassID.cs b/ImplementationClass.ClassID.cs index b8e5ff6..81d56cb 100644 --- a/ImplementationClass.ClassID.cs +++ b/ImplementationClass.ClassID.cs @@ -4,9 +4,10 @@ namespace Monkeymoto.NativeGenericDelegates { internal sealed partial class ImplementationClass { - public readonly struct ClassID(MethodDescriptor method, DelegateMarshalling marshalling) : IEquatable + public readonly struct ClassID(MethodDescriptor method, int invocationArgumentCount, MarshalInfo marshalInfo) : + IEquatable { - private readonly int hashCode = Hash.Combine(method, marshalling); + private readonly int hashCode = Hash.Combine(method, invocationArgumentCount, marshalInfo); public static bool operator ==(ClassID left, ClassID right) => left.Equals(right); public static bool operator !=(ClassID left, ClassID right) => !(left == right); diff --git a/ImplementationClass.cs b/ImplementationClass.cs index 24902b5..114207b 100644 --- a/ImplementationClass.cs +++ b/ImplementationClass.cs @@ -1,6 +1,5 @@ using System; using System.Collections.Generic; -using System.Linq; using System.Runtime.InteropServices; using System.Text; @@ -13,7 +12,7 @@ internal sealed partial class ImplementationClass : IEquatable methodReferences ) { var category = method.ContainingInterface.Category; - ID = new(method, marshalling); + ID = new(method, invocationArgumentCount, marshalInfo); ClassName = $"Native{category}_{ID}"; - Marshalling = marshalling; + MarshalInfo = marshalInfo; Method = method; if (!isInterfaceOrMethodOpenGeneric) { @@ -80,7 +80,7 @@ private string GetInvokeParameters() { return "()"; } - var marshalParamsAs = Marshalling.MarshalParamsAs ?? []; + var marshalParamsAs = MarshalInfo.MarshalParamsAs ?? []; var typeArguments = Method.ContainingInterface.TypeArguments; var sb = new StringBuilder($"{Constants.NewLineIndent2}({Constants.NewLineIndent3}"); for (int i = 0, j = 1; i < invokeParameterCount; ++i, ++j) @@ -105,10 +105,9 @@ private string GetInvokeParameters() private string GetSourceText() { - if (Marshalling.StaticCallingConvention is not null) + if (MarshalInfo.StaticCallingConvention is not null) { - return - GetSourceText(Marshalling.StaticCallingConvention.Value); + return GetSourceText(MarshalInfo.StaticCallingConvention.Value); } var interceptor = Interceptor?.SourceText; if (interceptor is not null) @@ -135,23 +134,12 @@ private string GetSourceText(CallingConvention callingConvention, string? classS var interfaceFullName = Method.ContainingInterface.FullName; var invokeParameterCount = Method.ContainingInterface.InvokeParameterCount; var invokeParameters = GetInvokeParameters(); - var returnMarshalAsAttribute = Marshalling.MarshalReturnAs is not null ? - $"[return: MarshalAs({Marshalling.MarshalReturnAs})]{Constants.NewLineIndent2}" : + var returnMarshalAsAttribute = MarshalInfo.MarshalReturnAs is not null ? + $"[return: MarshalAs({MarshalInfo.MarshalReturnAs})]{Constants.NewLineIndent2}" : string.Empty; - string? returnKeyword; - string? returnType; - switch (Method.ContainingInterface.IsAction) - { - case true: - returnKeyword = string.Empty; - returnType = "void"; - break; - default: - returnKeyword = "return "; - returnType = Method.ContainingInterface.TypeArguments.Last(); - break; - } - var interceptor = Marshalling.StaticCallingConvention is not null ? + var returnKeyword = Method.ContainingInterface.ReturnKeyword; + var returnType = Method.ContainingInterface.ReturnType; + var interceptor = MarshalInfo.StaticCallingConvention is not null ? Interceptor?.SourceText ?? string.Empty : string.Empty; if (interceptor != string.Empty) diff --git a/ImplementationClassCollection.Key.cs b/ImplementationClassCollection.Key.cs index 2422220..b85f730 100644 --- a/ImplementationClassCollection.Key.cs +++ b/ImplementationClassCollection.Key.cs @@ -1,5 +1,4 @@ using System; -using System.Diagnostics; namespace Monkeymoto.NativeGenericDelegates { @@ -20,12 +19,18 @@ internal sealed partial class ImplementationClassCollection public Key(MethodReference methodReference) { MethodReference = methodReference; - hashCode = Hash.Combine(MethodReference.Method, MethodReference.Marshalling); + hashCode = Hash.Combine + ( + MethodReference.Method, + MethodReference.InvocationArgumentCount, + MethodReference.MarshalInfo + ); } public override bool Equals(object? obj) => obj is Key other && Equals(other); public bool Equals(Key other) => (MethodReference.Method == other.MethodReference.Method) && - (MethodReference.Marshalling == other.MethodReference.Marshalling); + (MethodReference.InvocationArgumentCount == other.MethodReference.InvocationArgumentCount) && + (MethodReference.MarshalInfo == other.MethodReference.MarshalInfo); public override int GetHashCode() => hashCode; } } diff --git a/ImplementationClassCollection.cs b/ImplementationClassCollection.cs index 0b900c8..507848a 100644 --- a/ImplementationClassCollection.cs +++ b/ImplementationClassCollection.cs @@ -45,7 +45,8 @@ .. dictionary.Select openGenericInterceptorsBuilder, x.Key.MethodReference.Method, x.Key.MethodReference.IsInterfaceOrMethodOpenGeneric, - x.Key.MethodReference.Marshalling, + x.Key.MethodReference.MarshalInfo, + x.Key.MethodReference.InvocationArgumentCount, x.Value.AsReadOnly() ) ) diff --git a/InterceptedMethodReference.cs b/InterceptedMethodReference.cs index 1211e40..4aae0ed 100644 --- a/InterceptedMethodReference.cs +++ b/InterceptedMethodReference.cs @@ -18,13 +18,13 @@ public InterceptedMethodReference(MethodReference methodReference, ClosedGeneric { Interceptor = interceptor; MethodReference = methodReference; - hashCode = Hash.Combine(MethodReference.Method, MethodReference.Marshalling); + hashCode = Hash.Combine(MethodReference.Method, MethodReference.MarshalInfo); } public override bool Equals(object? obj) => obj is InterceptedMethodReference other && Equals(other); public bool Equals(InterceptedMethodReference? other) => (other is not null) && (MethodReference.Method == other.MethodReference.Method) && - (MethodReference.Marshalling == other.MethodReference.Marshalling); + (MethodReference.MarshalInfo == other.MethodReference.MarshalInfo); public override int GetHashCode() => hashCode; } } diff --git a/InterfaceDescriptor.cs b/InterfaceDescriptor.cs index 7337559..ca30aac 100644 --- a/InterfaceDescriptor.cs +++ b/InterfaceDescriptor.cs @@ -15,6 +15,8 @@ internal sealed class InterfaceDescriptor : IEquatable public int InvokeParameterCount { get; } public bool IsAction { get; } public string Name { get; } + public string ReturnKeyword { get; } + public string ReturnType { get; } public IReadOnlyList TypeArguments { get; } public string TypeArgumentList { get; } @@ -26,15 +28,26 @@ public InterfaceDescriptor(INamedTypeSymbol interfaceSymbol) { bool isAction = interfaceSymbol.Name.Contains(Constants.CategoryAction); Arity = interfaceSymbol.Arity; - Category = isAction ? Constants.CategoryAction : Constants.CategoryFunc; - IsAction = isAction; - InvokeParameterCount = interfaceSymbol.Arity - (isAction ? 0 : 1); + InvokeParameterCount = Arity - (isAction ? 0 : 1); Name = interfaceSymbol.Name; - TypeArguments = [.. interfaceSymbol.TypeArguments.Select(x => x.ToDisplayString())]; - TypeArgumentList = - Arity == 0 ? - string.Empty : - $"<{string.Join(", ", TypeArguments)}>"; + TypeArguments = [.. interfaceSymbol.TypeArguments.Select(static x => x.ToDisplayString())]; + if (isAction) + { + Category = Constants.CategoryAction; + IsAction = true; + ReturnKeyword = string.Empty; + ReturnType = "void"; + } + else + { + Category = Constants.CategoryFunc; + IsAction = false; + ReturnKeyword = "return "; + ReturnType = TypeArguments.Last(); + } + TypeArgumentList = Arity != 0 ? + $"<{string.Join(", ", TypeArguments)}>" : + string.Empty; FullName = $"{Name}{TypeArgumentList}"; hashCode = Hash.Combine(Arity, FullName); } diff --git a/InterfaceOrMethodReferenceCollection.cs b/InterfaceOrMethodReferenceCollection.cs new file mode 100644 index 0000000..286cf29 --- /dev/null +++ b/InterfaceOrMethodReferenceCollection.cs @@ -0,0 +1,157 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Monkeymoto.GeneratorUtils; +using System; +using System.Collections; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Diagnostics; +using System.Linq; +using System.Threading; + +namespace Monkeymoto.NativeGenericDelegates +{ + internal readonly struct InterfaceOrMethodReferenceCollection : + IEquatable, + IEnumerable + { + private readonly int hashCode; + private readonly ImmutableHashSet references; + + public static bool operator == + ( + InterfaceOrMethodReferenceCollection left, + InterfaceOrMethodReferenceCollection right + ) => left.Equals(right); + + public static bool operator != + ( + InterfaceOrMethodReferenceCollection left, + InterfaceOrMethodReferenceCollection right + ) => !(left == right); + + private static void AddMethodReferences + ( + ImmutableHashSet.Builder references, + GenericSymbolReferenceTree tree, + IMethodSymbol methodSymbol, + CancellationToken cancellationToken + ) + { + var methodReferences = tree.GetBranchesBySymbol(methodSymbol, cancellationToken); + if (!methodReferences.Any()) + { + return; + } + var roots = tree.GetBranchesBySymbol(methodSymbol.ContainingType, cancellationToken); + if (!roots.Any()) + { + return; + } + var methodSymbolsToConstruct = new HashSet(SymbolEqualityComparer.Default); + foreach (var root in roots.Select(static x => (INamedTypeSymbol)x.Symbol)) + { + cancellationToken.ThrowIfCancellationRequested(); + var methodSymbolToConstruct = root.GetMembers(methodSymbol.Name) + .Where + ( + x => SymbolEqualityComparer.Default.Equals + ( + x.OriginalDefinition, + methodSymbol.OriginalDefinition + ) + ) + .Cast() + .Single(); + _ = methodSymbolsToConstruct.Add(methodSymbolToConstruct); + } + var constructedMethodSymbols = new HashSet(SymbolEqualityComparer.Default); + constructedMethodSymbols.UnionWith(methodReferences.SelectMany + ( + x => methodSymbolsToConstruct.Select + ( + y => y.Construct([.. ((IMethodSymbol)x.Symbol).TypeArguments]) + ) + )); + foreach (var methodReference in methodReferences) + { + cancellationToken.ThrowIfCancellationRequested(); + switch (methodReference.Node) + { + case IdentifierNameSyntax identifierName: + references.UnionWith + ( + constructedMethodSymbols.Select + ( + x => new GenericSymbolReference + ( + x, + methodReference.SemanticModel, + identifierName + ) + ) + ); + break; + case InvocationExpressionSyntax invocationExpression: + references.UnionWith + ( + constructedMethodSymbols.Select + ( + x => new GenericSymbolReference + ( + x, + methodReference.SemanticModel, + invocationExpression + ) + ) + ); + break; + default: + throw new UnreachableException(); + } + } + } + + public static IncrementalValueProvider GetReferences + ( + IncrementalGeneratorInitializationContext context, + IncrementalValueProvider symbolsProvider + ) + { + var treeProvider = GenericSymbolReferenceTree.FromIncrementalGeneratorInitializationContext(context); + return symbolsProvider.Combine(treeProvider).Select(static (x, cancellationToken) => + { + var symbols = x.Left; + using var tree = x.Right; // Dispose tree after we extract the symbol references we need + var references = ImmutableHashSet.CreateBuilder(); + foreach (var symbol in symbols) + { + switch (symbol) + { + case INamedTypeSymbol: + references.UnionWith(tree.GetBranchesBySymbol(symbol, cancellationToken)); + break; + case IMethodSymbol methodSymbol: + AddMethodReferences(references, tree, methodSymbol, cancellationToken); + break; + default: + throw new UnreachableException(); + } + } + return new InterfaceOrMethodReferenceCollection(references.ToImmutable()); + }); + } + + private InterfaceOrMethodReferenceCollection(ImmutableHashSet references) + { + this.references = references; + hashCode = Hash.Combine(references); + } + + public override bool Equals(object? obj) => obj is InterfaceOrMethodReferenceCollection other && Equals(other); + public bool Equals(InterfaceOrMethodReferenceCollection other) => references.SetEquals(other.references); + public IEnumerator GetEnumerator() => references.GetEnumerator(); + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + public override int GetHashCode() => hashCode; + } +} diff --git a/InterfaceSymbolCollection.cs b/InterfaceOrMethodSymbolCollection.cs similarity index 52% rename from InterfaceSymbolCollection.cs rename to InterfaceOrMethodSymbolCollection.cs index 307e307..dd24610 100644 --- a/InterfaceSymbolCollection.cs +++ b/InterfaceOrMethodSymbolCollection.cs @@ -8,44 +8,62 @@ namespace Monkeymoto.NativeGenericDelegates { - internal readonly struct InterfaceSymbolCollection : - IEquatable, + internal readonly struct InterfaceOrMethodSymbolCollection : + IEquatable, IEnumerable { private readonly int hashCode; private readonly ImmutableList symbols; - public static bool operator ==(InterfaceSymbolCollection left, InterfaceSymbolCollection right) => - left.Equals(right); - public static bool operator !=(InterfaceSymbolCollection left, InterfaceSymbolCollection right) => - !(left == right); + public static bool operator == + ( + InterfaceOrMethodSymbolCollection left, + InterfaceOrMethodSymbolCollection right + ) =>left.Equals(right); + + public static bool operator != + ( + InterfaceOrMethodSymbolCollection left, + InterfaceOrMethodSymbolCollection right + ) => !(left == right); - public static IncrementalValueProvider GetSymbols + public static IncrementalValueProvider GetSymbols ( IncrementalValueProvider compilationProvider ) => compilationProvider.Select ( - static (compilation, cancellationToken) => new InterfaceSymbolCollection(compilation, cancellationToken) + static (compilation, cancellationToken) => + new InterfaceOrMethodSymbolCollection(compilation, cancellationToken) ); - public InterfaceSymbolCollection(Compilation compilation, CancellationToken cancellationToken) + public InterfaceOrMethodSymbolCollection(Compilation compilation, CancellationToken cancellationToken) { var builder = ImmutableList.CreateBuilder(); for (int i = 0; i < Constants.InterfaceSymbolCountPerCategory; ++i) { cancellationToken.ThrowIfCancellationRequested(); var interfaceSymbol = compilation.GetTypeByMetadataName(Constants.Actions.MetadataNames[i])!; + // GetMembers *seems* to be ordered, but this is not a documented part of the API + // explicitly order the members for Equals comparisons + var genericMethods = interfaceSymbol.GetMembers() + .Where(static x => x is IMethodSymbol methodSymbol && methodSymbol.IsGenericMethod) + .OrderBy(static x => x.Name); builder.Add(interfaceSymbol); + builder.AddRange(genericMethods); interfaceSymbol = compilation.GetTypeByMetadataName(Constants.Funcs.MetadataNames[i])!; + genericMethods = interfaceSymbol.GetMembers() + .Where(static x => x is IMethodSymbol methodSymbol && methodSymbol.IsGenericMethod) + .OrderBy(static x => x.Name); builder.Add(interfaceSymbol); + builder.AddRange(genericMethods); } symbols = builder.ToImmutable(); hashCode = Hash.Combine(symbols); } - public override bool Equals(object? obj) => obj is InterfaceSymbolCollection other && Equals(other); + public override bool Equals(object? obj) => obj is InterfaceOrMethodSymbolCollection other && Equals(other); - public bool Equals(InterfaceSymbolCollection other) + public bool Equals(InterfaceOrMethodSymbolCollection other) { if (symbols.Count != other.symbols.Count) { diff --git a/InterfaceReferenceCollection.cs b/InterfaceReferenceCollection.cs deleted file mode 100644 index 78fc4a7..0000000 --- a/InterfaceReferenceCollection.cs +++ /dev/null @@ -1,53 +0,0 @@ -using Microsoft.CodeAnalysis; -using Monkeymoto.GeneratorUtils; -using System; -using System.Collections; -using System.Collections.Generic; -using System.Collections.Immutable; -using System.Linq; - -namespace Monkeymoto.NativeGenericDelegates -{ - internal readonly struct InterfaceReferenceCollection : - IEquatable, - IEnumerable - { - private readonly int hashCode; - private readonly ImmutableHashSet references; - - public static bool operator ==(InterfaceReferenceCollection left, InterfaceReferenceCollection right) => - left.Equals(right); - - public static bool operator !=(InterfaceReferenceCollection left, InterfaceReferenceCollection right) => - !(left == right); - - public static IncrementalValueProvider GetReferences - ( - IncrementalGeneratorInitializationContext context, - IncrementalValueProvider symbolsProvider - ) - { - var treeProvider = GenericSymbolReferenceTree.FromIncrementalGeneratorInitializationContext(context); - return symbolsProvider.Combine(treeProvider).Select(static (x, cancellationToken) => - { - var symbols = x.Left; - using var tree = x.Right; // Dispose tree after we extract the symbol references we need - var references = ImmutableHashSet.CreateBuilder(); - references.UnionWith(symbols.SelectMany(x => tree.GetBranchesBySymbol(x, cancellationToken))); - return new InterfaceReferenceCollection(references.ToImmutable()); - }); - } - - private InterfaceReferenceCollection(ImmutableHashSet references) - { - this.references = references; - hashCode = Hash.Combine(references); - } - - public override bool Equals(object? obj) => obj is InterfaceReferenceCollection other && Equals(other); - public bool Equals(InterfaceReferenceCollection other) => references.SetEquals(other.references); - public IEnumerator GetEnumerator() => references.GetEnumerator(); - IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); - public override int GetHashCode() => hashCode; - } -} diff --git a/MarshalInfo.Parser.cs b/MarshalInfo.Parser.cs new file mode 100644 index 0000000..86ff686 --- /dev/null +++ b/MarshalInfo.Parser.cs @@ -0,0 +1,118 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.Operations; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace Monkeymoto.NativeGenericDelegates +{ + internal sealed partial class MarshalInfo + { + internal static class Parser + { + private static IReadOnlyList? GetMarshalAsCollectionFromElements + ( + IEnumerable? elements, + int collectionLength + ) + { + if (elements is null) + { + return null; + } + var results = new List(collectionLength); + foreach (var elementValue in elements) + { + var result = GetMarshalAsFromOperation(elementValue); + results.Add(result); + if (results.Count == collectionLength) + { + break; + } + } + return results.AsReadOnly(); + } + + private static IReadOnlyList? GetMarshalAsCollectionFromCollectionExpression + ( + ICollectionExpressionOperation? collectionExpression, + int collectionLength + ) + { + if ((collectionExpression is null) || collectionExpression.Elements.Any(x => x is ISpreadOperation)) + { + return null; + } + var elements = collectionExpression.Elements.Select(static x => x switch + { + ILiteralOperation => x, + IObjectCreationOperation => x, + IConversionOperation conversion => + conversion.Operand as IObjectCreationOperation, + _ => null + }).Where(static x => x is not null); + return GetMarshalAsCollectionFromElements(elements, collectionLength); + } + + private static IReadOnlyList? GetMarshalAsCollectionFromOperation + ( + IOperation? collection, + int collectionLength + ) + { + if ((collectionLength <= 0) || (collection is null) || collection.ConstantValue.HasValue) + { + // empty collection or collection operation is `null` literal in source + return null; + } + return collection switch + { + IArrayCreationOperation arrayCreation => + GetMarshalAsCollectionFromElements(arrayCreation.Initializer?.ElementValues, collectionLength), + IConversionOperation conversion => + GetMarshalAsCollectionFromCollectionExpression + ( + conversion.Operand as ICollectionExpressionOperation, + collectionLength + ), + ICollectionExpressionOperation collectionExpression => + GetMarshalAsCollectionFromCollectionExpression(collectionExpression, collectionLength), + _ => null + }; + } + + public static string? GetMarshalAsFromOperation(IOperation? value) + { + var objectCreation = value switch + { + IConversionOperation conversion => conversion.Operand as IObjectCreationOperation, + _ => value as IObjectCreationOperation + }; + if (objectCreation is null) + { + return null; + } + var sb = new StringBuilder(objectCreation.Arguments[0].Syntax.ToString()); + if (objectCreation.Initializer is not null) + { + _ = sb.Append(objectCreation.Initializer.Syntax.ToString()) + .Replace('{', ',') + .Replace("}", string.Empty); + int i = sb.Length - 1; + for ( ; (i >= 0) && char.IsWhiteSpace(sb[i]); --i) { } + sb.Length = i + 1; + } + return sb.ToString(); + } + + public static IReadOnlyList? GetMarshalParamsAs + ( + IObjectCreationOperation? marshalParamsAs, + int invokeParamCount + ) => GetMarshalAsCollectionFromOperation(marshalParamsAs, invokeParamCount); + + public static string? GetMarshalReturnAs(IObjectCreationOperation? marshalReturnAs) => + GetMarshalAsFromOperation(marshalReturnAs); + } + } +} diff --git a/MarshalInfo.cs b/MarshalInfo.cs new file mode 100644 index 0000000..4881f3d --- /dev/null +++ b/MarshalInfo.cs @@ -0,0 +1,243 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Operations; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Runtime.InteropServices; +using System.Threading; + +namespace Monkeymoto.NativeGenericDelegates +{ + internal sealed partial class MarshalInfo : IEquatable + { + private readonly int hashCode; + + public IReadOnlyList? MarshalParamsAs { get; } + public string? MarshalReturnAs { get; } + public CallingConvention? StaticCallingConvention { get; } + + public static bool operator ==(MarshalInfo? left, MarshalInfo? right) => + left?.Equals(right) ?? right is null; + public static bool operator !=(MarshalInfo? left, MarshalInfo? right) => !(left == right); + + private static IFieldReferenceOperation? GetCallingConventionOperation + ( + InvocationExpressionSyntax invocationExpression, + SemanticModel semanticModel, + CancellationToken cancellationToken, + IPropertySymbol? property = null, + Compilation? compilation = null + ) + { + var callingConventionArg = invocationExpression.ArgumentList.Arguments + .Select(x => semanticModel.GetOperation(x, cancellationToken) as IArgumentOperation) + .Where(static x => (x is not null) && (x.Parameter!.Name == "callingConvention")) + .FirstOrDefault(); + return callingConventionArg is not null ? + callingConventionArg.Value as IFieldReferenceOperation : + GetFieldReferenceOperation(property, compilation, cancellationToken); + } + + private static IFieldReferenceOperation? GetFieldReferenceOperation + ( + IPropertySymbol? property, + Compilation? compilation, + CancellationToken cancellationToken + ) => GetOperation(property, compilation, cancellationToken); + + public static MarshalInfo GetMarshalInfo + ( + IMethodSymbol methodSymbol, + InterfaceDescriptor interfaceDescriptor, + InvocationExpressionSyntax invocationExpression, + SemanticModel semanticModel, + CancellationToken cancellationToken + ) + { + if (methodSymbol.TypeArguments.FirstOrDefault() is not INamedTypeSymbol marshaller) + { + return new(invocationExpression, semanticModel, cancellationToken); + } + var compilation = semanticModel.Compilation; + var marshallerInterface = compilation.GetTypeByMetadataName(Constants.IMarshallerMetadataName)!; + var properties = marshaller.GetMembers() + .OfType() + .Where + ( + x => x.ExplicitInterfaceImplementations.FirstOrDefault() is IPropertySymbol prop && + SymbolEqualityComparer.Default.Equals + ( + prop.ContainingType.OriginalDefinition, + marshallerInterface + ) + ); + IPropertySymbol? callingConventionProperty = null; + IPropertySymbol? marshalMapProperty = null; + IPropertySymbol? marshalParamsAsProperty = null; + IPropertySymbol? marshalReturnAsProperty = null; + foreach (var property in properties) + { + var name = property.Name.Substring(property.Name.LastIndexOf('.') + 1); + switch (name) + { + case "CallingConvention": + callingConventionProperty = property; + break; + case "MarshalMap": + marshalMapProperty = property; + break; + case "MarshalParamsAs": + marshalParamsAsProperty = property; + break; + case "MarshalReturnAs": + marshalReturnAsProperty = property; + break; + default: + throw new UnreachableException(); + } + } + var callingConventionOp = GetCallingConventionOperation + ( + invocationExpression, + semanticModel, + cancellationToken, + callingConventionProperty, + compilation + ); + var marshalMapOp = GetObjectCreationOperation(marshalMapProperty, compilation, cancellationToken); + var marshalParamsAsOp = + GetObjectCreationOperation(marshalParamsAsProperty, compilation, cancellationToken); + var marshalReturnAsOp = + GetObjectCreationOperation(marshalReturnAsProperty, compilation, cancellationToken); + return new + ( + interfaceDescriptor, + callingConventionOp, + marshalMapOp, + marshalParamsAsOp, + marshalReturnAsOp + ); + } + + private static IObjectCreationOperation? GetObjectCreationOperation + ( + IPropertySymbol? property, + Compilation compilation, + CancellationToken cancellationToken + ) => GetOperation(property, compilation, cancellationToken); + + private static T? GetOperation + ( + IPropertySymbol? property, + Compilation? compilation, + CancellationToken cancellationToken + ) + where T : class, IOperation + { + if ((property is null) || (compilation is null)) + { + return null; + } + var node = property.DeclaringSyntaxReferences[0].GetSyntax(cancellationToken); + var semanticModel = compilation.GetSemanticModel(node.SyntaxTree); + return node.DescendantNodesAndSelf() + .Select(x => semanticModel.GetOperation(x, cancellationToken)) + .Where(static x => x is not null) + .Select(static x => x.DescendantsAndSelf().OfType().FirstOrDefault()) + .FirstOrDefault(); + } + + private static CallingConvention? GetStaticCallingConvention(IFieldReferenceOperation? callingConventionOp) + { + var field = callingConventionOp?.Field; + if (field is not null && SymbolEqualityComparer.Default.Equals(field.ContainingType, field.Type) && + Enum.TryParse(field?.Name, false, out CallingConvention callingConvention)) + { + return callingConvention; + } + return null; + } + + private MarshalInfo() { } + + private MarshalInfo + ( + InvocationExpressionSyntax invocationExpression, + SemanticModel semanticModel, + CancellationToken cancellationToken + ) + { + var operation = + GetCallingConventionOperation(invocationExpression, semanticModel, cancellationToken); + StaticCallingConvention = GetStaticCallingConvention(operation); + hashCode = Hash.Combine + ( + MarshalParamsAs, + MarshalReturnAs, + StaticCallingConvention + ); + } + + private MarshalInfo + ( + InterfaceDescriptor interfaceDescriptor, + IFieldReferenceOperation? callingConventionOp, + IObjectCreationOperation? marshalMapCreation, + IObjectCreationOperation? marshalParamsAsCreation, + IObjectCreationOperation? marshalReturnAsCreation + ) + { + StaticCallingConvention = GetStaticCallingConvention(callingConventionOp); + MarshalParamsAs = + Parser.GetMarshalParamsAs(marshalParamsAsCreation, interfaceDescriptor.InvokeParameterCount); + MarshalReturnAs = Parser.GetMarshalReturnAs(marshalReturnAsCreation); + var marshalMap = MarshalMap.Parse(marshalMapCreation); + if (marshalMap is not null) + { + var marshalParamsList = new List(interfaceDescriptor.InvokeParameterCount); + if (MarshalParamsAs is not null) + { + marshalParamsList.AddRange(MarshalParamsAs); + } + while (marshalParamsList.Count < interfaceDescriptor.InvokeParameterCount) + { + marshalParamsList.Add(null); + } + bool dirty = false; + for (int i = 0; i < interfaceDescriptor.InvokeParameterCount; ++i) + { + if ((marshalParamsList[i] is null) && + marshalMap.TryGetValue(interfaceDescriptor.TypeArguments[i], out var marshalParamAs)) + { + marshalParamsList[i] = marshalParamAs; + dirty = true; + } + } + if (dirty) + { + MarshalParamsAs = marshalParamsList.AsReadOnly(); + } + if ((MarshalReturnAs is null) && !interfaceDescriptor.IsAction && + marshalMap.TryGetValue(interfaceDescriptor.ReturnType, out var marshalReturnAs)) + { + MarshalReturnAs = marshalReturnAs; + } + } + hashCode = Hash.Combine + ( + MarshalParamsAs, + MarshalReturnAs, + StaticCallingConvention + ); + } + + public override bool Equals(object? obj) => obj is MarshalInfo other && Equals(other); + public bool Equals(MarshalInfo? other) => + (other is not null) && + (MarshalParamsAs?.SequenceEqual(other.MarshalParamsAs) ?? other.MarshalParamsAs is null) && + (MarshalReturnAs == other.MarshalReturnAs) && (StaticCallingConvention == other.StaticCallingConvention); + public override int GetHashCode() => hashCode; + } +} diff --git a/MarshalMap.cs b/MarshalMap.cs index 5a9c550..ca79c81 100644 --- a/MarshalMap.cs +++ b/MarshalMap.cs @@ -1,11 +1,8 @@ using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Operations; using System.Collections; using System.Collections.Generic; using System.Collections.Immutable; -using System.Linq; -using System.Threading; namespace Monkeymoto.NativeGenericDelegates { @@ -18,43 +15,14 @@ internal sealed class MarshalMap : IReadOnlyDictionary public IEnumerable Keys => map.Keys; public IEnumerable Values => map.Values; - public static MarshalMap? Parse - ( - IArgumentOperation? marshalMapArgument, - IList diagnostics, - CancellationToken cancellationToken - ) + public static MarshalMap? Parse(IOperation? value) { - _ = diagnostics; - if (marshalMapArgument is null) + if (value is null) { return null; } var builder = ImmutableDictionary.CreateBuilder(); - var value = marshalMapArgument.Value; - var invalidArgumentDiagnostic = Diagnostic.Create - ( - Diagnostics.NGD1004_InvalidMarshalMapArgument, - marshalMapArgument.Syntax.GetLocation(), - marshalMapArgument.Parameter!.Name - ); - if (value is IFieldReferenceOperation fieldReference && fieldReference.Field.IsReadOnly) - { - var fieldDeclaration = fieldReference.Field.DeclaringSyntaxReferences[0].GetSyntax(cancellationToken); - var equalsValueClause = fieldDeclaration.ChildNodes().OfType() - .FirstOrDefault(); - SemanticModel? semanticModel = equalsValueClause is not null ? - fieldReference.SemanticModel!.Compilation.GetSemanticModel(equalsValueClause.SyntaxTree) : - null; - if (semanticModel?.GetOperation(equalsValueClause!, cancellationToken) is not - IFieldInitializerOperation fieldInitializer) - { - diagnostics.Add(invalidArgumentDiagnostic); - return null; - } - value = fieldInitializer.Value; - } - IObjectCreationOperation? mapCreation = value switch + var mapCreation = value switch { IConversionOperation conversion => conversion.Operand as IObjectCreationOperation, _ => value as IObjectCreationOperation @@ -66,43 +34,29 @@ CancellationToken cancellationToken var initializers = mapCreation.Initializer.Initializers; foreach (var op in initializers) { - cancellationToken.ThrowIfCancellationRequested(); ITypeOfOperation? typeOf; IOperation? marshalAs; - if (op is IInvocationOperation invocation) + if (op is not IInvocationOperation invocation) { - if (invocation.Arguments.Length != 2) - { - diagnostics.Add(invalidArgumentDiagnostic); - return null; - } - typeOf = invocation.Arguments[0].Value as ITypeOfOperation; - value = invocation.Arguments[1].Value; - marshalAs = value switch - { - IConversionOperation conversion => conversion.Operand as IObjectCreationOperation, - _ => value as IObjectCreationOperation - }; + return null; } - else + if (invocation.Arguments.Length < 2) { - diagnostics.Add(invalidArgumentDiagnostic); return null; } + typeOf = invocation.Arguments[0].Value as ITypeOfOperation; + value = invocation.Arguments[1].Value; + marshalAs = value switch + { + IConversionOperation conversion => conversion.Operand as IObjectCreationOperation, + _ => value as IObjectCreationOperation + }; if ((typeOf is null) || (marshalAs is null)) { - diagnostics.Add(invalidArgumentDiagnostic); return null; } var key = typeOf.TypeOperand.ToDisplayString(); - var marshalAsValue = DelegateMarshalling.Parser.GetMarshalAsFromOperation - ( - marshalAs, - marshalMapArgument.Parameter!.Name, - diagnostics, - diagnosticTypeSuffix: "", - cancellationToken - ); + var marshalAsValue = MarshalInfo.Parser.GetMarshalAsFromOperation(marshalAs); builder[key] = marshalAsValue; } return new(builder.ToImmutable()); diff --git a/MethodDescriptor.cs b/MethodDescriptor.cs index ca40952..bdfd7e5 100644 --- a/MethodDescriptor.cs +++ b/MethodDescriptor.cs @@ -58,38 +58,26 @@ IMethodSymbol methodSymbol InterceptorParameters = GetParameters(getInterceptorParameters: true); Name = methodSymbol.Name; Parameters = GetParameters(getInterceptorParameters: false); - hashCode = Hash.Combine(Arity, ContainingInterface, Name); + hashCode = Hash.Combine(Arity, ContainingInterface, FullName); } public override bool Equals(object? obj) => obj is MethodDescriptor other && Equals(other); public bool Equals(MethodDescriptor? other) => (other is not null) && (Arity == other.Arity) && (ContainingInterface == other.ContainingInterface) && - (Name == other.Name); + (FullName == other.FullName); public override int GetHashCode() => hashCode; private string GetParameters(bool getInterceptorParameters) { - var isAction = ContainingInterface.IsAction; - var invokeParamCount = ContainingInterface.InvokeParameterCount; - var hasMarshalParams = !isAction || (invokeParamCount != 0); - var callingConvention = $"CallingConvention callingConvention" + - (hasMarshalParams ? $", {Constants.NewLineIndent3}" : string.Empty); - var marshalMap = hasMarshalParams ? $"MarshalMap marshalMap,{Constants.NewLineIndent3}" : string.Empty; - var marshalReturnAsParam = !ContainingInterface.IsAction ? - $"MarshalAsAttribute marshalReturnAs" + - (invokeParamCount != 0 ? $",{Constants.NewLineIndent3}" : string.Empty) : - string.Empty; - var marshalParamsAsParam = invokeParamCount != 0 ? - $"MarshalAsAttribute[] marshalParamsAs" : - string.Empty; + var callingConvention = $"CallingConvention callingConvention"; var firstParamType = FirstParameterType; if (getInterceptorParameters && !IsFromFunctionPointer && (ContainingInterface.Arity > 0)) { var typeParameters = Constants.InterceptorTypeParameters[ContainingInterface.Arity]; firstParamType = $"{ContainingInterface.Category}<{typeParameters}>"; } - var firstParam = $"{firstParamType} {FirstParameterName},{Constants.NewLineIndent3}"; - return $"{firstParam}{callingConvention}{marshalMap}{marshalReturnAsParam}{marshalParamsAsParam}"; + var firstParam = $"{firstParamType} {FirstParameterName}, "; + return $"{firstParam}{callingConvention}"; } } } diff --git a/MethodReference.cs b/MethodReference.cs index cb87973..d952a14 100644 --- a/MethodReference.cs +++ b/MethodReference.cs @@ -1,8 +1,8 @@ using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Operations; using Monkeymoto.GeneratorUtils; using System; -using System.Collections.Generic; using System.Linq; using System.Threading; @@ -16,9 +16,10 @@ internal sealed class MethodReference : IEquatable public string FilePath { get; } public string InterceptorAttributeSourceText { get; } public InterfaceDescriptor Interface { get; } + public int InvocationArgumentCount { get; } public bool IsInterfaceOrMethodOpenGeneric { get; } public int Line { get; } - public DelegateMarshalling Marshalling { get; } + public MarshalInfo MarshalInfo { get; } public MethodDescriptor Method { get; } public static bool operator ==(MethodReference? left, MethodReference? right) => @@ -28,7 +29,6 @@ internal sealed class MethodReference : IEquatable public static MethodReference? GetReference ( GenericSymbolReference interfaceOrMethodReference, - IList diagnostics, CancellationToken cancellationToken ) { @@ -36,9 +36,14 @@ CancellationToken cancellationToken IMethodSymbol? methodSymbol; InvocationExpressionSyntax? invocationExpression; var node = interfaceOrMethodReference.Node; - if (node.Parent?.Parent is InvocationExpressionSyntax methodInvocationExpression) + if (node is InvocationExpressionSyntax genericMethodInvocationExpression) + { + methodSymbol = (IMethodSymbol)interfaceOrMethodReference.Symbol; + interfaceSymbol = methodSymbol.ContainingType; + invocationExpression = genericMethodInvocationExpression; + } + else if (node.Parent?.Parent is InvocationExpressionSyntax methodInvocationExpression) { - // non-generic methods (FromAction, FromFunc, FromFunctionPointer) var methodNameSyntax = ((MemberAccessExpressionSyntax)node.Parent).Name; if (methodNameSyntax.Arity != 0) { @@ -53,18 +58,27 @@ CancellationToken cancellationToken { return null; } + var semanticModel = interfaceOrMethodReference.SemanticModel!; + int invocationArgumentCount = 0; + var invocation = + semanticModel.GetOperation(invocationExpression, cancellationToken) as IInvocationOperation; + if (invocation is not null) + { + invocationArgumentCount = invocation.Arguments.Length - + invocation.Arguments.Where(static x => x.ArgumentKind != ArgumentKind.Explicit).Count(); + } var interfaceDescriptor = new InterfaceDescriptor(interfaceSymbol); var methodDescriptor = new MethodDescriptor ( interfaceDescriptor, - methodSymbol + methodSymbol! ); - var methodMarshalling = new DelegateMarshalling + var marshalInfo = MarshalInfo.GetMarshalInfo ( + methodSymbol!, + interfaceDescriptor, invocationExpression, interfaceOrMethodReference.SemanticModel!, - interfaceDescriptor, - diagnostics, cancellationToken ); return new MethodReference @@ -72,8 +86,9 @@ CancellationToken cancellationToken interfaceDescriptor, methodDescriptor, invocationExpression, - methodMarshalling, - !interfaceOrMethodReference.IsSyntaxReferenceClosedTypeOrMethod + marshalInfo, + !interfaceOrMethodReference.IsSyntaxReferenceClosedTypeOrMethod, + invocationArgumentCount ); } @@ -82,8 +97,9 @@ private MethodReference InterfaceDescriptor interfaceDescriptor, MethodDescriptor methodDescriptor, InvocationExpressionSyntax invocationExpression, - DelegateMarshalling marshalling, - bool isInterfaceOrMethodOpenGeneric + MarshalInfo marshalInfo, + bool isInterfaceOrMethodOpenGeneric, + int invocationArgumentCount ) { var methodNode = ((MemberAccessExpressionSyntax)invocationExpression.Expression).Name; @@ -93,15 +109,16 @@ bool isInterfaceOrMethodOpenGeneric Interface = interfaceDescriptor; IsInterfaceOrMethodOpenGeneric = isInterfaceOrMethodOpenGeneric; Line = linePosition.Line + 1; - Marshalling = marshalling; + MarshalInfo = marshalInfo; Method = methodDescriptor; InterceptorAttributeSourceText = $"[InterceptsLocation(@\"{FilePath}\", {Line}, {Character})]"; - hashCode = Hash.Combine(Character, FilePath, Interface, Line); + InvocationArgumentCount = invocationArgumentCount; + hashCode = Hash.Combine(Character, FilePath, Line); } public override bool Equals(object? obj) => obj is MethodReference other && Equals(other); public bool Equals(MethodReference? other) => (other is not null) && (Character == other.Character) && - (Line == other.Line) && (FilePath == other.FilePath); + (FilePath == other.FilePath) && (Line == other.Line); public override int GetHashCode() => hashCode; } } diff --git a/MethodReferenceCollection.cs b/MethodReferenceCollection.cs index 179e0d1..76d1513 100644 --- a/MethodReferenceCollection.cs +++ b/MethodReferenceCollection.cs @@ -13,8 +13,6 @@ namespace Monkeymoto.NativeGenericDelegates private readonly int hashCode; private readonly ImmutableHashSet references; - public IReadOnlyList Diagnostics { get; } - public static bool operator ==(MethodReferenceCollection left, MethodReferenceCollection right) => left.Equals(right); public static bool operator !=(MethodReferenceCollection left, MethodReferenceCollection right) => @@ -22,35 +20,24 @@ namespace Monkeymoto.NativeGenericDelegates public static IncrementalValueProvider GetReferences ( - IncrementalValueProvider interfaceOrMethodReferencesProvider + IncrementalValueProvider interfaceOrMethodReferencesProvider ) => interfaceOrMethodReferencesProvider.Select(static (interfaceOrMethodReferences, cancellationToken) => { var builder = ImmutableHashSet.CreateBuilder(); - var diagnostics = new List(); foreach (var interfaceOrMethodReference in interfaceOrMethodReferences) { - var methodReference = - MethodReference.GetReference(interfaceOrMethodReference, diagnostics, cancellationToken); + var methodReference = MethodReference.GetReference(interfaceOrMethodReference, cancellationToken); if (methodReference is not null) { _ = builder.Add(methodReference); } } - return new MethodReferenceCollection - ( - builder.ToImmutable(), - diagnostics.AsReadOnly() - ); + return new MethodReferenceCollection(builder.ToImmutable()); }); - private MethodReferenceCollection - ( - ImmutableHashSet references, - IReadOnlyList diagnostics - ) + private MethodReferenceCollection(ImmutableHashSet references) { this.references = references; - Diagnostics = diagnostics; hashCode = Hash.Combine(references); } diff --git a/OpenGenericInterceptors.cs b/OpenGenericInterceptors.cs index e81bb0f..409fd1b 100644 --- a/OpenGenericInterceptors.cs +++ b/OpenGenericInterceptors.cs @@ -45,10 +45,12 @@ private string GetSourceText() Debug.Assert(attributes.ContainsKey(kv.Key)); var first = kv.Value.First(); var typeParameters = Constants.InterceptorTypeParameters[first.Method.ContainingInterface.Arity]; - typeParameters = typeParameters.Length != 0 ? - $"<{typeParameters}>" : - typeParameters; - var interfaceName = $"{first.Method.ContainingInterface.Name}{typeParameters}"; + var interfaceName = $"{first.Method.ContainingInterface.Name}<{typeParameters}>"; + typeParameters = first.Method.Arity != 0 ? + $"<{typeParameters}, XMarshaller>" : + typeParameters.Length != 0 ? + $"<{typeParameters}>" : + typeParameters; var methodHash = kv.Key.GetHashCode(); var methodName = first.Method.Name; var parameters = first.Method.InterceptorParameters; @@ -97,7 +99,7 @@ public static {interfaceName} {methodName}{typeParameters} { firstParam = $"({method.ContainingInterface.Category}{method.ContainingInterface.TypeArgumentList})(object){firstParam}"; } - if (kv.Value[i].Marshalling.StaticCallingConvention is not null) + if (kv.Value[i].MarshalInfo.StaticCallingConvention is not null) { _ = sb.AppendLine ( diff --git a/PostInitialization.cs b/PostInitialization.cs index 45f63f3..ad5b265 100644 --- a/PostInitialization.cs +++ b/PostInitialization.cs @@ -6,48 +6,67 @@ internal static class PostInitialization { private static void BuildInterfaceDefinition(StringBuilder sb, bool isAction, int argumentCount) { - string? marshalReturnAsParameter; string? qualifiedTypeParameters; string? returnType; string? type; string? typeParameters; string? antiConstraints; + bool hasGenericMethods = true; if (isAction) { - marshalReturnAsParameter = string.Empty; returnType = "void"; type = Constants.CategoryAction; antiConstraints = Constants.Actions.AntiConstraints[argumentCount]; if (argumentCount != 0) { qualifiedTypeParameters = $"<{Constants.Actions.QualifiedTypeParameters[argumentCount]}>"; - typeParameters = $"<{Constants.Actions.TypeParameters[argumentCount]}>"; + typeParameters = $"<{Constants.Actions.TypeParameters[argumentCount]}"; } else { qualifiedTypeParameters = string.Empty; typeParameters = string.Empty; + hasGenericMethods = false; } } else { - marshalReturnAsParameter = $",{Constants.NewLineIndent3}MarshalAsAttribute? marshalReturnAs = null"; qualifiedTypeParameters = $"<{Constants.Funcs.QualifiedTypeParameters[argumentCount]}>"; returnType = "TResult"; type = Constants.CategoryFunc; - typeParameters = $"<{Constants.Funcs.TypeParameters[argumentCount]}>"; + typeParameters = $"<{Constants.Funcs.TypeParameters[argumentCount]}"; antiConstraints = Constants.Funcs.AntiConstraints[argumentCount]; } string genericType = $"{type}{typeParameters}"; + if (typeParameters.Length != 0) + { + genericType = $"{genericType}>"; + typeParameters = $"{typeParameters}>"; + } + string fullType = $"INative{type}{typeParameters}"; string parameters = Constants.Parameters[argumentCount]; string typeAsArgument = type.ToLower(); - string callingConvention = - $",{Constants.NewLineIndent3}CallingConvention callingConvention = CallingConvention.Winapi"; - string marshalParamsAsParameter = argumentCount != 0 ? - $",{Constants.NewLineIndent3}MarshalAsAttribute?[]? marshalParamsAs = null" : - string.Empty; - string marshalMap = argumentCount != 0 ? - $",{Constants.NewLineIndent3}MarshalMap? marshalMap = null" : + string callingConvention = $",{Constants.NewLineIndent3}CallingConvention callingConvention = CallingConvention.Winapi"; + string genericMethods = hasGenericMethods ? + $@" + + public static {fullType} From{type} + ( + {genericType} {typeAsArgument}{callingConvention} + ) + where TMarshaller : IMarshaller, new() + {{ + throw new NotImplementedException(); + }} + + public static {fullType} FromFunctionPointer + ( + nint functionPtr{callingConvention} + ) + where TMarshaller : IMarshaller, new() + {{ + throw new NotImplementedException(); + }}" : string.Empty; _ = sb.Append ( @@ -55,22 +74,22 @@ private static void BuildInterfaceDefinition(StringBuilder sb, bool isAction, in {{ protected object? Target {{ get; }} protected MethodInfo Method {{ get; }} - - public static INative{genericType} From{type} + + public static {fullType} From{type} ( - {genericType} {typeAsArgument}{callingConvention}{marshalMap}{marshalReturnAsParameter}{marshalParamsAsParameter} + {genericType} {typeAsArgument}{callingConvention} ) {{ throw new NotImplementedException(); }} - - public static INative{genericType} FromFunctionPointer + + public static {fullType} FromFunctionPointer ( - nint functionPtr{callingConvention}{marshalMap}{marshalReturnAsParameter}{marshalParamsAsParameter} + nint functionPtr{callingConvention} ) {{ throw new NotImplementedException(); - }} + }}{genericMethods} public nint GetFunctionPointer(); public {returnType} Invoke({parameters}); @@ -101,7 +120,15 @@ public static string GetSourceText() #nullable enable namespace {Constants.RootNamespace} -{{" +{{ + internal interface IMarshaller where TSelf : IMarshaller + {{ + protected static virtual CallingConvention? CallingConvention => null; + protected static virtual MarshalMap? MarshalMap => null; + protected static virtual MarshalAsAttribute?[]? MarshalParamsAs => null; + protected static virtual MarshalAsAttribute? MarshalReturnAs => null; + }} + " ); for (int i = 0; i < 17; ++i) { @@ -110,7 +137,7 @@ namespace {Constants.RootNamespace} } _ = source.AppendLine ( -$@" + $@" internal sealed class MarshalMap : IEnumerable> {{ public MarshalMap() {{ }}