Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion AnalyzerReleases.Unshipped.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,4 @@
Rule ID | Category | Severity | Notes
--------|----------|----------|-------
NGD1001 | Usage | Error | Diagnostics
NGD1002 | Usage | Error | Diagnostics
NGD1003 | Usage | Error | Diagnostics
33 changes: 28 additions & 5 deletions ClosedGenericInterceptor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;

namespace Monkeymoto.NativeGenericDelegates
Expand Down Expand Up @@ -43,22 +44,44 @@ private string GetSourceText()
var sb = new StringBuilder($"{Constants.NewLineIndent2}");
foreach (var reference in InterceptedMethodReferences.Select(x => x.MethodReference))
{
_ = sb.Append($"{Constants.NewLineIndent2}").Append(reference.InterceptorAttributeSourceText);
_ = sb.Append(reference.InterceptorAttributeSourceText).Append($"{Constants.NewLineIndent2}");
}
var method = InterceptsMethod;
var typeParameters = GetTypeParameters(method.ContainingInterface.Arity, method.Arity);
_ = sb.Append
(
$@"
[MethodImpl(MethodImplOptions.AggressiveInlining)]
$@"[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static {method.ContainingInterface.FullName} {method.Name}{typeParameters}
(
{method.Parameters}
)
{{
{{"
);
if (ImplementationClass.Marshalling.StaticCallingConvention is not null)
{
_ = sb.Append
(
$@"
return new {ImplementationClass.ClassName}({method.FirstParameterName});
}}"
);
);
}
else
{
_ = sb.Append
(
$@"
return callingConvention switch
{{
CallingConvention.Cdecl => ({method.ContainingInterface.FullName})new {ImplementationClass.ClassName}_{nameof(CallingConvention.Cdecl)}({method.FirstParameterName}),
CallingConvention.StdCall => new {ImplementationClass.ClassName}_{nameof(CallingConvention.StdCall)}({method.FirstParameterName}),
CallingConvention.ThisCall => new {ImplementationClass.ClassName}_{nameof(CallingConvention.ThisCall)}({method.FirstParameterName}),
CallingConvention.Winapi => new {ImplementationClass.ClassName}_{nameof(CallingConvention.Winapi)}({method.FirstParameterName}),
_ => throw new NotImplementedException()
}};
}}"
);
}
return sb.ToString();
}
internal static string GetTypeParameters(int interfaceArity, int methodArity)
Expand Down
53 changes: 27 additions & 26 deletions DelegateMarshalling.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@ internal sealed partial class DelegateMarshalling : IEquatable<DelegateMarshalli
{
private readonly int hashCode;

public CallingConvention CallingConvention { get; }
public IReadOnlyList<string?>? MarshalParamsAs { get; }
public string? MarshalReturnAs { get; }
public string? RuntimeCallingConvention { get; }
public CallingConvention? StaticCallingConvention { get; }

public static bool operator ==(DelegateMarshalling? left, DelegateMarshalling? right) =>
left?.Equals(right) ?? right is null;
Expand Down Expand Up @@ -51,45 +52,45 @@ CancellationToken cancellationToken
break;
}
}
CallingConvention callingConvention = CallingConvention.Winapi;
MarshalParamsAs = Parser.GetMarshalParamsAs
(
marshalParamsAsArgument,
interfaceDescriptor.InvokeParameterCount,
diagnostics,
cancellationToken
);
MarshalReturnAs = Parser.GetMarshalReturnAs(marshalReturnAsArgument, diagnostics, cancellationToken);
if (callingConventionArgument is not null)
{
bool isValid = false;
var field = (callingConventionArgument.Value as IFieldReferenceOperation)?.Field;
var value = callingConventionArgument.Value;
var field = (value as IFieldReferenceOperation)?.Field;
if (field is not null && SymbolEqualityComparer.Default.Equals(field.ContainingType, field.Type) &&
Enum.TryParse(field?.Name, false, out callingConvention))
Enum.TryParse(field?.Name, false, out CallingConvention callingConvention))
{
isValid = true;
RuntimeCallingConvention = null;
StaticCallingConvention = callingConvention;
}
if (!isValid)
else
{
callingConvention = CallingConvention.Winapi;
diagnostics.Add
(
Diagnostic.Create
(
Diagnostics.NGD1002_InvalidCallingConventionArgument,
callingConventionArgument.Syntax.GetLocation()
)
);
RuntimeCallingConvention = value.ToString();
StaticCallingConvention = null;
}
}
CallingConvention = callingConvention;
MarshalParamsAs = Parser.GetMarshalParamsAs
hashCode = Hash.Combine
(
marshalParamsAsArgument,
interfaceDescriptor.InvokeParameterCount,
diagnostics,
cancellationToken
MarshalParamsAs,
MarshalReturnAs,
RuntimeCallingConvention,
StaticCallingConvention
);
MarshalReturnAs = Parser.GetMarshalReturnAs(marshalReturnAsArgument, diagnostics, cancellationToken);
hashCode = Hash.Combine(CallingConvention, MarshalParamsAs, MarshalReturnAs);
}

public override bool Equals(object? obj) => obj is DelegateMarshalling other && Equals(other);
public bool Equals(DelegateMarshalling? other) =>
(other is not null) && (CallingConvention == other.CallingConvention) &&
MarshalParamsAs.SequenceEqual(other.MarshalParamsAs) && (MarshalReturnAs == other.MarshalReturnAs);
(other is not null) && MarshalParamsAs.SequenceEqual(other.MarshalParamsAs) &&
(MarshalReturnAs == other.MarshalReturnAs) &&
(RuntimeCallingConvention == other.RuntimeCallingConvention) &&
(StaticCallingConvention == other.StaticCallingConvention);
public override int GetHashCode() => hashCode;
}
}
10 changes: 0 additions & 10 deletions Diagnostics.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,6 @@ internal static class Diagnostics
true
);

public static readonly DiagnosticDescriptor NGD1002_InvalidCallingConventionArgument = new
(
"NGD1002",
"NGD1002: Invalid CallingConvention argument",
"CallingConvention argument must be literal or static readonly field",
"Usage",
DiagnosticSeverity.Error,
true
);

public static readonly DiagnosticDescriptor NGD1003_MarshalAsArgumentSpreadElementNotSupported = new
(
"NGD1003",
Expand Down
49 changes: 40 additions & 9 deletions ImplementationClass.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;

namespace Monkeymoto.NativeGenericDelegates
Expand Down Expand Up @@ -46,22 +47,22 @@ IReadOnlyList<MethodReference> methodReferences
public bool Equals(ImplementationClass? other) => (other is not null) && (SourceText == other.SourceText);
public override int GetHashCode() => hashCode;

private string GetFromDelegateConstructor()
private string GetFromDelegateConstructor(string classSuffix)
{
var firstParam = Method.FirstParameterName;
return
$@"internal {ClassName}({Method.FirstParameterType} {firstParam})
$@"internal {ClassName}{classSuffix}({Method.FirstParameterType} {firstParam})
{{
ArgumentNullException.ThrowIfNull({firstParam});
handler = (Handler)Delegate.CreateDelegate(typeof(Handler), {firstParam}.Target, {firstParam}.Method);
functionPtr = Marshal.GetFunctionPointerForDelegate(handler);
}}";
}

private string GetFromFunctionPointerConstructor()
private string GetFromFunctionPointerConstructor(string classSuffix)
{
return
$@"internal {ClassName}(nint functionPtr)
$@"internal {ClassName}{classSuffix}(nint functionPtr)
{{
if (functionPtr == nint.Zero)
{{
Expand Down Expand Up @@ -104,10 +105,33 @@ private string GetInvokeParameters()

private string GetSourceText()
{
var callingConvention = Marshalling.CallingConvention;
if (Marshalling.StaticCallingConvention is not null)
{
return
GetSourceText(Marshalling.StaticCallingConvention.Value);
}
var interceptor = Interceptor?.SourceText;
if (interceptor is not null)
{
interceptor =
$@"
file static class {ClassName}
{{{interceptor}
}}";
}
return
$@"{GetSourceText(CallingConvention.Cdecl, $"_{nameof(CallingConvention.Cdecl)}")}
{GetSourceText(CallingConvention.StdCall, $"_{nameof(CallingConvention.StdCall)}")}
{GetSourceText(CallingConvention.ThisCall, $"_{nameof(CallingConvention.ThisCall)}")}
{GetSourceText(CallingConvention.Winapi, $"_{nameof(CallingConvention.Winapi)}")}{interceptor ?? string.Empty}";
}

private string GetSourceText(CallingConvention callingConvention, string? classSuffix = null)
{
classSuffix ??= string.Empty;
var constructor = Method.IsFromFunctionPointer ?
GetFromFunctionPointerConstructor() :
GetFromDelegateConstructor();
GetFromFunctionPointerConstructor(classSuffix) :
GetFromDelegateConstructor(classSuffix);
var interfaceFullName = Method.ContainingInterface.FullName;
var invokeParameterCount = Method.ContainingInterface.InvokeParameterCount;
var invokeParameters = GetInvokeParameters();
Expand All @@ -127,8 +151,15 @@ private string GetSourceText()
returnType = Method.ContainingInterface.TypeArguments.Last();
break;
}
var interceptor = Marshalling.StaticCallingConvention is not null ?
Interceptor?.SourceText ?? string.Empty :
string.Empty;
if (interceptor != string.Empty)
{
interceptor = $"{Constants.NewLineIndent2}{interceptor}";
}
return
$@"file sealed class {ClassName} : {interfaceFullName}
$@"file sealed class {ClassName}{classSuffix} : {interfaceFullName}
{{
private readonly Handler handler;
private readonly nint functionPtr;
Expand All @@ -145,7 +176,7 @@ private string GetSourceText()
{returnMarshalAsAttribute}public {returnType} Invoke{invokeParameters}
{{
{returnKeyword}handler({Constants.Arguments[invokeParameterCount]});
}}{Interceptor?.SourceText ?? string.Empty}
}}{interceptor}
}}
";
}
Expand Down