diff --git a/AnalyzerReleases.Unshipped.md b/AnalyzerReleases.Unshipped.md index 7833da0..67febfe 100644 --- a/AnalyzerReleases.Unshipped.md +++ b/AnalyzerReleases.Unshipped.md @@ -6,4 +6,5 @@ Rule ID | Category | Severity | Notes --------|----------|----------|------- NGD1001 | Usage | Error | Diagnostics -NGD1003 | Usage | Error | Diagnostics \ No newline at end of file +NGD1003 | Usage | Error | Diagnostics +NGD1004 | Usage | Error | Diagnostics \ No newline at end of file diff --git a/DelegateMarshalling.Parser.cs b/DelegateMarshalling.Parser.cs index 23d30e2..d9f06e8 100644 --- a/DelegateMarshalling.Parser.cs +++ b/DelegateMarshalling.Parser.cs @@ -10,7 +10,7 @@ namespace Monkeymoto.NativeGenericDelegates { internal sealed partial class DelegateMarshalling { - private static class Parser + internal static class Parser { private static IReadOnlyList GetMarshalAsCollectionFromArrayInitializer ( @@ -185,7 +185,7 @@ CancellationToken cancellationToken return null; } - private static string? GetMarshalAsFromOperation + public static string? GetMarshalAsFromOperation ( IOperation value, string parameterName, diff --git a/DelegateMarshalling.cs b/DelegateMarshalling.cs index 0274f2f..c0d42a1 100644 --- a/DelegateMarshalling.cs +++ b/DelegateMarshalling.cs @@ -32,6 +32,7 @@ CancellationToken cancellationToken ) { IArgumentOperation? callingConventionArgument = null; + IArgumentOperation? marshalMapArgument = null; IArgumentOperation? marshalReturnAsArgument = null; IArgumentOperation? marshalParamsAsArgument = null; foreach (var argumentNode in invocationExpression.ArgumentList.Arguments) @@ -42,6 +43,16 @@ CancellationToken cancellationToken case "callingConvention": callingConventionArgument = argumentOp; break; + case "marshalMap": + IConversionOperation? conversion = argumentOp.Value as IConversionOperation; + ILiteralOperation? literal = conversion?.Operand as ILiteralOperation; + if (!argumentOp.ConstantValue.HasValue && + !(conversion?.ConstantValue.HasValue ?? false) && + !(literal?.ConstantValue.HasValue ?? false)) + { + marshalMapArgument = argumentOp; + } + break; case "marshalReturnAs": marshalReturnAsArgument = argumentOp; break; @@ -60,6 +71,38 @@ CancellationToken cancellationToken cancellationToken ); MarshalReturnAs = Parser.GetMarshalReturnAs(marshalReturnAsArgument, diagnostics, cancellationToken); + var marshalMap = MarshalMap.Parse(marshalMapArgument, diagnostics, cancellationToken); + if (marshalMap is not null) + { + var marshalParamsList = new List(interfaceDescriptor.InvokeParameterCount); + if (MarshalParamsAs is not null) + { + marshalParamsList.AddRange(MarshalParamsAs); + } + while (marshalParamsList.Count < interfaceDescriptor.InvokeParameterCount) + { + marshalParamsList.Add(null); + } + bool dirty = false; + for (int i = 0; i < interfaceDescriptor.InvokeParameterCount; ++i) + { + if ((marshalParamsList[i] is null) && + marshalMap.TryGetValue(interfaceDescriptor.TypeArguments[i], out var marshalParamAs)) + { + marshalParamsList[i] = marshalParamAs; + dirty = true; + } + } + if (dirty) + { + MarshalParamsAs = marshalParamsList.AsReadOnly(); + } + if ((MarshalReturnAs is null) && !interfaceDescriptor.IsAction && + marshalMap.TryGetValue(interfaceDescriptor.TypeArguments.Last(), out var marshalReturnAs)) + { + MarshalReturnAs = marshalReturnAs; + } + } if (callingConventionArgument is not null) { var value = callingConventionArgument.Value; diff --git a/Diagnostics.cs b/Diagnostics.cs index 8fd54de..5caad6a 100644 --- a/Diagnostics.cs +++ b/Diagnostics.cs @@ -24,5 +24,15 @@ internal static class Diagnostics DiagnosticSeverity.Error, true ); + + public static readonly DiagnosticDescriptor NGD1004_InvalidMarshalMapArgument = new + ( + "NGD1004", + "NGD1004: Invalid MarshalMap argument", + "MarshalMap argument {0} must be null or use collection initializer syntax", + "Usage", + DiagnosticSeverity.Error, + true + ); } } diff --git a/ImplementationClassCollection.Key.cs b/ImplementationClassCollection.Key.cs index ed6b4c3..b6a3973 100644 --- a/ImplementationClassCollection.Key.cs +++ b/ImplementationClassCollection.Key.cs @@ -1,4 +1,5 @@ using System; +using System.Diagnostics; namespace Monkeymoto.NativeGenericDelegates { @@ -20,11 +21,26 @@ public Key(MethodReference methodReference) { MethodReference = methodReference; hashCode = Hash.Combine(MethodReference.Method, MethodReference.Marshalling); + Debug.WriteLine($"hashCode: {hashCode}"); } public override bool Equals(object? obj) => obj is Key other && Equals(other); public bool Equals(Key other) => (MethodReference.Method == other.MethodReference.Method) && (MethodReference.Marshalling == other.MethodReference.Marshalling); + //public bool Equals(Key other) + //{ + // Debug.WriteLine($"interface: {MethodReference.Method.ContainingInterface.FullName} / {other.MethodReference.Method.ContainingInterface.FullName} / equal? {MethodReference.Method.ContainingInterface == other.MethodReference.Method.ContainingInterface}"); + // Debug.WriteLine($"method: {MethodReference.Method.Name} / {other.MethodReference.Method.Name} / equal? {MethodReference.Method == other.MethodReference.Method}"); + // Debug.WriteLine($"arity: {MethodReference.Method.Arity} / {other.MethodReference.Method.Arity}"); + // Debug.WriteLine($"aritys equal? {MethodReference.Method.Arity == other.MethodReference.Method.Arity}"); + // Debug.WriteLine($"interfaces equal? {MethodReference.Method.ContainingInterface == other.MethodReference.Method.ContainingInterface}"); + // Debug.WriteLine($"names equal? {MethodReference.Method.Name == other.MethodReference.Method.Name}"); + // Debug.WriteLine($"methods equal? {MethodReference.Method == other.MethodReference.Method}"); + // //public bool Equals(MethodDescriptor? other) => + // // (other is not null) && (Arity == other.Arity) && (ContainingInterface == other.ContainingInterface) && + // // (Name == other.Name); + // return (MethodReference.Method == other.MethodReference.Method) && (MethodReference.Marshalling == other.MethodReference.Marshalling); + //} public override int GetHashCode() => hashCode; } } diff --git a/MarshalMap.cs b/MarshalMap.cs new file mode 100644 index 0000000..5a9c550 --- /dev/null +++ b/MarshalMap.cs @@ -0,0 +1,121 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Operations; +using System.Collections; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Linq; +using System.Threading; + +namespace Monkeymoto.NativeGenericDelegates +{ + internal sealed class MarshalMap : IReadOnlyDictionary + { + private readonly ImmutableDictionary map; + + public string? this[string key] => map[key]; + public int Count => map.Count; + public IEnumerable Keys => map.Keys; + public IEnumerable Values => map.Values; + + public static MarshalMap? Parse + ( + IArgumentOperation? marshalMapArgument, + IList diagnostics, + CancellationToken cancellationToken + ) + { + _ = diagnostics; + if (marshalMapArgument is null) + { + return null; + } + var builder = ImmutableDictionary.CreateBuilder(); + var value = marshalMapArgument.Value; + var invalidArgumentDiagnostic = Diagnostic.Create + ( + Diagnostics.NGD1004_InvalidMarshalMapArgument, + marshalMapArgument.Syntax.GetLocation(), + marshalMapArgument.Parameter!.Name + ); + if (value is IFieldReferenceOperation fieldReference && fieldReference.Field.IsReadOnly) + { + var fieldDeclaration = fieldReference.Field.DeclaringSyntaxReferences[0].GetSyntax(cancellationToken); + var equalsValueClause = fieldDeclaration.ChildNodes().OfType() + .FirstOrDefault(); + SemanticModel? semanticModel = equalsValueClause is not null ? + fieldReference.SemanticModel!.Compilation.GetSemanticModel(equalsValueClause.SyntaxTree) : + null; + if (semanticModel?.GetOperation(equalsValueClause!, cancellationToken) is not + IFieldInitializerOperation fieldInitializer) + { + diagnostics.Add(invalidArgumentDiagnostic); + return null; + } + value = fieldInitializer.Value; + } + IObjectCreationOperation? mapCreation = value switch + { + IConversionOperation conversion => conversion.Operand as IObjectCreationOperation, + _ => value as IObjectCreationOperation + }; + if ((mapCreation?.Initializer is null) || (mapCreation.Initializer.Initializers.Length == 0)) + { + return new(builder.ToImmutable()); + } + var initializers = mapCreation.Initializer.Initializers; + foreach (var op in initializers) + { + cancellationToken.ThrowIfCancellationRequested(); + ITypeOfOperation? typeOf; + IOperation? marshalAs; + if (op is IInvocationOperation invocation) + { + if (invocation.Arguments.Length != 2) + { + diagnostics.Add(invalidArgumentDiagnostic); + return null; + } + typeOf = invocation.Arguments[0].Value as ITypeOfOperation; + value = invocation.Arguments[1].Value; + marshalAs = value switch + { + IConversionOperation conversion => conversion.Operand as IObjectCreationOperation, + _ => value as IObjectCreationOperation + }; + } + else + { + diagnostics.Add(invalidArgumentDiagnostic); + return null; + } + if ((typeOf is null) || (marshalAs is null)) + { + diagnostics.Add(invalidArgumentDiagnostic); + return null; + } + var key = typeOf.TypeOperand.ToDisplayString(); + var marshalAsValue = DelegateMarshalling.Parser.GetMarshalAsFromOperation + ( + marshalAs, + marshalMapArgument.Parameter!.Name, + diagnostics, + diagnosticTypeSuffix: "", + cancellationToken + ); + builder[key] = marshalAsValue; + } + return new(builder.ToImmutable()); + } + + private MarshalMap(ImmutableDictionary dictionary) + { + map = dictionary; + } + + public bool ContainsKey(string key) => map.ContainsKey(key); + public IEnumerator> GetEnumerator() => map.GetEnumerator(); + IEnumerator IEnumerable.GetEnumerator() => map.GetEnumerator(); + public bool TryGetValue(string key, out string? value) => map.TryGetValue(key, out value); + } +} diff --git a/MethodDescriptor.cs b/MethodDescriptor.cs index b6c5461..b31d7b7 100644 --- a/MethodDescriptor.cs +++ b/MethodDescriptor.cs @@ -69,6 +69,7 @@ public bool Equals(MethodDescriptor? other) => private string GetParameters(bool getInterceptorParameters) { + var marshalMap = $"MarshalMap marshalMap,{Constants.NewLineIndent3}"; var marshalReturnAsParam = !ContainingInterface.IsAction ? $"MarshalAsAttribute marshalReturnAs,{Constants.NewLineIndent3}" : string.Empty; @@ -82,8 +83,8 @@ private string GetParameters(bool getInterceptorParameters) firstParameterType = $"{ContainingInterface.Category}<{typeParameters}>"; } return - $"{firstParameterType} {FirstParameterName},{Constants.NewLineIndent3}{marshalReturnAsParam}" + - $"{marshalParamsAsParam}CallingConvention callingConvention"; + $"{firstParameterType} {FirstParameterName},{Constants.NewLineIndent3}{marshalMap}" + + $"{marshalReturnAsParam}{marshalParamsAsParam}CallingConvention callingConvention"; } } } diff --git a/PostInitialization.cs b/PostInitialization.cs index 51451bd..1740a93 100644 --- a/PostInitialization.cs +++ b/PostInitialization.cs @@ -43,13 +43,16 @@ private static void BuildInterfaceDefinition(StringBuilder sb, bool isAction, in string marshalParamsAsParameter = argumentCount != 0 ? $",{Constants.NewLineIndent3}MarshalAsAttribute?[]? marshalParamsAs = null" : string.Empty; + string marshalMap = argumentCount != 0 ? + $",{Constants.NewLineIndent3}MarshalMap? marshalMap = null" : + string.Empty; _ = sb.Append ( $@" internal interface INative{type}{qualifiedTypeParameters} {{ public static INative{genericType} From{type} ( - {genericType} {typeAsArgument}{marshalReturnAsParameter}{marshalParamsAsParameter}{callingConvention} + {genericType} {typeAsArgument}{marshalMap}{marshalReturnAsParameter}{marshalParamsAsParameter}{callingConvention} ) {{ throw new NotImplementedException(); @@ -57,7 +60,7 @@ public static INative{genericType} From{type} public static INative{genericType} FromFunctionPointer ( - nint functionPtr{marshalReturnAsParameter}{marshalParamsAsParameter}{callingConvention} + nint functionPtr{marshalMap}{marshalReturnAsParameter}{marshalParamsAsParameter}{callingConvention} ) {{ throw new NotImplementedException(); @@ -77,6 +80,8 @@ public static string GetSourceText() ( $@"// using System; +using System.Collections; +using System.Collections.Generic; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; @@ -93,6 +98,15 @@ namespace {Constants.RootNamespace} _ = source.AppendLine ( $@" + internal sealed class MarshalMap : IEnumerable> + {{ + public MarshalMap() {{ }} + public void Add(Type key, MarshalAsAttribute value) {{ }} + IEnumerator> + IEnumerable>.GetEnumerator() => + throw new NotImplementedException(); + IEnumerator IEnumerable.GetEnumerator() => throw new NotImplementedException(); + }} }} namespace System.Runtime.CompilerServices