diff --git a/AnalyzerReleases.Unshipped.md b/AnalyzerReleases.Unshipped.md index 4b5931d..7833da0 100644 --- a/AnalyzerReleases.Unshipped.md +++ b/AnalyzerReleases.Unshipped.md @@ -6,5 +6,4 @@ Rule ID | Category | Severity | Notes --------|----------|----------|------- NGD1001 | Usage | Error | Diagnostics -NGD1002 | Usage | Error | Diagnostics NGD1003 | Usage | Error | Diagnostics \ No newline at end of file diff --git a/ClosedGenericInterceptor.cs b/ClosedGenericInterceptor.cs index 377477e..f82c387 100644 --- a/ClosedGenericInterceptor.cs +++ b/ClosedGenericInterceptor.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Collections.Immutable; using System.Linq; +using System.Runtime.InteropServices; using System.Text; namespace Monkeymoto.NativeGenericDelegates @@ -43,22 +44,44 @@ private string GetSourceText() var sb = new StringBuilder($"{Constants.NewLineIndent2}"); foreach (var reference in InterceptedMethodReferences.Select(x => x.MethodReference)) { - _ = sb.Append($"{Constants.NewLineIndent2}").Append(reference.InterceptorAttributeSourceText); + _ = sb.Append(reference.InterceptorAttributeSourceText).Append($"{Constants.NewLineIndent2}"); } var method = InterceptsMethod; var typeParameters = GetTypeParameters(method.ContainingInterface.Arity, method.Arity); _ = sb.Append ( - $@" - [MethodImpl(MethodImplOptions.AggressiveInlining)] + $@"[MethodImpl(MethodImplOptions.AggressiveInlining)] public static {method.ContainingInterface.FullName} {method.Name}{typeParameters} ( {method.Parameters} ) - {{ + {{" + ); + if (ImplementationClass.Marshalling.StaticCallingConvention is not null) + { + _ = sb.Append + ( + $@" return new {ImplementationClass.ClassName}({method.FirstParameterName}); }}" - ); + ); + } + else + { + _ = sb.Append + ( + $@" + return callingConvention switch + {{ + CallingConvention.Cdecl => ({method.ContainingInterface.FullName})new {ImplementationClass.ClassName}_{nameof(CallingConvention.Cdecl)}({method.FirstParameterName}), + CallingConvention.StdCall => new {ImplementationClass.ClassName}_{nameof(CallingConvention.StdCall)}({method.FirstParameterName}), + CallingConvention.ThisCall => new {ImplementationClass.ClassName}_{nameof(CallingConvention.ThisCall)}({method.FirstParameterName}), + CallingConvention.Winapi => new {ImplementationClass.ClassName}_{nameof(CallingConvention.Winapi)}({method.FirstParameterName}), + _ => throw new NotImplementedException() + }}; + }}" + ); + } return sb.ToString(); } internal static string GetTypeParameters(int interfaceArity, int methodArity) diff --git a/DelegateMarshalling.cs b/DelegateMarshalling.cs index dd8c9cf..0274f2f 100644 --- a/DelegateMarshalling.cs +++ b/DelegateMarshalling.cs @@ -13,9 +13,10 @@ internal sealed partial class DelegateMarshalling : IEquatable? MarshalParamsAs { get; } public string? MarshalReturnAs { get; } + public string? RuntimeCallingConvention { get; } + public CallingConvention? StaticCallingConvention { get; } public static bool operator ==(DelegateMarshalling? left, DelegateMarshalling? right) => left?.Equals(right) ?? right is null; @@ -51,45 +52,45 @@ CancellationToken cancellationToken break; } } - CallingConvention callingConvention = CallingConvention.Winapi; + MarshalParamsAs = Parser.GetMarshalParamsAs + ( + marshalParamsAsArgument, + interfaceDescriptor.InvokeParameterCount, + diagnostics, + cancellationToken + ); + MarshalReturnAs = Parser.GetMarshalReturnAs(marshalReturnAsArgument, diagnostics, cancellationToken); if (callingConventionArgument is not null) { - bool isValid = false; - var field = (callingConventionArgument.Value as IFieldReferenceOperation)?.Field; + var value = callingConventionArgument.Value; + var field = (value as IFieldReferenceOperation)?.Field; if (field is not null && SymbolEqualityComparer.Default.Equals(field.ContainingType, field.Type) && - Enum.TryParse(field?.Name, false, out callingConvention)) + Enum.TryParse(field?.Name, false, out CallingConvention callingConvention)) { - isValid = true; + RuntimeCallingConvention = null; + StaticCallingConvention = callingConvention; } - if (!isValid) + else { - callingConvention = CallingConvention.Winapi; - diagnostics.Add - ( - Diagnostic.Create - ( - Diagnostics.NGD1002_InvalidCallingConventionArgument, - callingConventionArgument.Syntax.GetLocation() - ) - ); + RuntimeCallingConvention = value.ToString(); + StaticCallingConvention = null; } } - CallingConvention = callingConvention; - MarshalParamsAs = Parser.GetMarshalParamsAs + hashCode = Hash.Combine ( - marshalParamsAsArgument, - interfaceDescriptor.InvokeParameterCount, - diagnostics, - cancellationToken + MarshalParamsAs, + MarshalReturnAs, + RuntimeCallingConvention, + StaticCallingConvention ); - MarshalReturnAs = Parser.GetMarshalReturnAs(marshalReturnAsArgument, diagnostics, cancellationToken); - hashCode = Hash.Combine(CallingConvention, MarshalParamsAs, MarshalReturnAs); } public override bool Equals(object? obj) => obj is DelegateMarshalling other && Equals(other); public bool Equals(DelegateMarshalling? other) => - (other is not null) && (CallingConvention == other.CallingConvention) && - MarshalParamsAs.SequenceEqual(other.MarshalParamsAs) && (MarshalReturnAs == other.MarshalReturnAs); + (other is not null) && MarshalParamsAs.SequenceEqual(other.MarshalParamsAs) && + (MarshalReturnAs == other.MarshalReturnAs) && + (RuntimeCallingConvention == other.RuntimeCallingConvention) && + (StaticCallingConvention == other.StaticCallingConvention); public override int GetHashCode() => hashCode; } } diff --git a/Diagnostics.cs b/Diagnostics.cs index d95d912..8fd54de 100644 --- a/Diagnostics.cs +++ b/Diagnostics.cs @@ -15,16 +15,6 @@ internal static class Diagnostics true ); - public static readonly DiagnosticDescriptor NGD1002_InvalidCallingConventionArgument = new - ( - "NGD1002", - "NGD1002: Invalid CallingConvention argument", - "CallingConvention argument must be literal or static readonly field", - "Usage", - DiagnosticSeverity.Error, - true - ); - public static readonly DiagnosticDescriptor NGD1003_MarshalAsArgumentSpreadElementNotSupported = new ( "NGD1003", diff --git a/ImplementationClass.cs b/ImplementationClass.cs index 2b4724c..0c437db 100644 --- a/ImplementationClass.cs +++ b/ImplementationClass.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Runtime.InteropServices; using System.Text; namespace Monkeymoto.NativeGenericDelegates @@ -46,11 +47,11 @@ IReadOnlyList methodReferences public bool Equals(ImplementationClass? other) => (other is not null) && (SourceText == other.SourceText); public override int GetHashCode() => hashCode; - private string GetFromDelegateConstructor() + private string GetFromDelegateConstructor(string classSuffix) { var firstParam = Method.FirstParameterName; return - $@"internal {ClassName}({Method.FirstParameterType} {firstParam}) + $@"internal {ClassName}{classSuffix}({Method.FirstParameterType} {firstParam}) {{ ArgumentNullException.ThrowIfNull({firstParam}); handler = (Handler)Delegate.CreateDelegate(typeof(Handler), {firstParam}.Target, {firstParam}.Method); @@ -58,10 +59,10 @@ private string GetFromDelegateConstructor() }}"; } - private string GetFromFunctionPointerConstructor() + private string GetFromFunctionPointerConstructor(string classSuffix) { return - $@"internal {ClassName}(nint functionPtr) + $@"internal {ClassName}{classSuffix}(nint functionPtr) {{ if (functionPtr == nint.Zero) {{ @@ -104,10 +105,33 @@ private string GetInvokeParameters() private string GetSourceText() { - var callingConvention = Marshalling.CallingConvention; + if (Marshalling.StaticCallingConvention is not null) + { + return + GetSourceText(Marshalling.StaticCallingConvention.Value); + } + var interceptor = Interceptor?.SourceText; + if (interceptor is not null) + { + interceptor = + $@" + file static class {ClassName} + {{{interceptor} + }}"; + } + return + $@"{GetSourceText(CallingConvention.Cdecl, $"_{nameof(CallingConvention.Cdecl)}")} + {GetSourceText(CallingConvention.StdCall, $"_{nameof(CallingConvention.StdCall)}")} + {GetSourceText(CallingConvention.ThisCall, $"_{nameof(CallingConvention.ThisCall)}")} + {GetSourceText(CallingConvention.Winapi, $"_{nameof(CallingConvention.Winapi)}")}{interceptor ?? string.Empty}"; + } + + private string GetSourceText(CallingConvention callingConvention, string? classSuffix = null) + { + classSuffix ??= string.Empty; var constructor = Method.IsFromFunctionPointer ? - GetFromFunctionPointerConstructor() : - GetFromDelegateConstructor(); + GetFromFunctionPointerConstructor(classSuffix) : + GetFromDelegateConstructor(classSuffix); var interfaceFullName = Method.ContainingInterface.FullName; var invokeParameterCount = Method.ContainingInterface.InvokeParameterCount; var invokeParameters = GetInvokeParameters(); @@ -127,8 +151,15 @@ private string GetSourceText() returnType = Method.ContainingInterface.TypeArguments.Last(); break; } + var interceptor = Marshalling.StaticCallingConvention is not null ? + Interceptor?.SourceText ?? string.Empty : + string.Empty; + if (interceptor != string.Empty) + { + interceptor = $"{Constants.NewLineIndent2}{interceptor}"; + } return - $@"file sealed class {ClassName} : {interfaceFullName} + $@"file sealed class {ClassName}{classSuffix} : {interfaceFullName} {{ private readonly Handler handler; private readonly nint functionPtr; @@ -145,7 +176,7 @@ private string GetSourceText() {returnMarshalAsAttribute}public {returnType} Invoke{invokeParameters} {{ {returnKeyword}handler({Constants.Arguments[invokeParameterCount]}); - }}{Interceptor?.SourceText ?? string.Empty} + }}{interceptor} }} "; }