diff --git a/Constants.cs b/Constants.cs index 8753851..4b560c7 100644 --- a/Constants.cs +++ b/Constants.cs @@ -42,11 +42,8 @@ internal static class Constants public const string SourceFileName = RootNamespace + ".g.cs"; public const string MarshalAsArgumentMustUseObjectCreationSyntaxID = "NGD1001"; - public const string InvalidMarshalParamsAsArrayLengthID = "NGD1002"; public static readonly DiagnosticDescriptor MarshalAsArgumentMustUseObjectCreationSyntaxDescriptor = new(MarshalAsArgumentMustUseObjectCreationSyntaxID, "Invalid MarshalAs argument", "MarshalAs argument must be null or use object creation syntax", "Usage", DiagnosticSeverity.Error, true); - public static readonly DiagnosticDescriptor InvalidMarshalParamsAsArrayLengthDescriptor = - new(InvalidMarshalParamsAsArrayLengthID, $"Invalid marshalParamsAs argument", $"marshalParamsAs argument must be array of correct length", "Usage", DiagnosticSeverity.Error, true); public static readonly string[] GenericActionTypeParameters = new[] { diff --git a/MarshalInfo.cs b/MarshalInfo.cs index 83899db..8c85c48 100644 --- a/MarshalInfo.cs +++ b/MarshalInfo.cs @@ -25,20 +25,23 @@ internal static class MarshalInfo List marshalAsParamsStrings = new(); if (collection is IArrayCreationOperation arrayCreation) { - var arrayLength = arrayCreation.DimensionSizes[0].ConstantValue; - if (!arrayLength.HasValue || ((int)arrayLength.Value!) != argumentCount) - { - diagnostics.Add(Diagnostic.Create(Constants.InvalidMarshalParamsAsArrayLengthDescriptor, location)); - } - else if (arrayCreation.Initializer is not null) + if (arrayCreation.Initializer is not null && (argumentCount > 0)) { foreach (var elementValue in arrayCreation.Initializer.ElementValues) { cancellationToken.ThrowIfCancellationRequested(); GetMarshalAsFromOperation(elementValue, cancellationToken, argumentCount, diagnostics, location, marshalAsParamsStrings); + if (marshalAsParamsStrings.Count == argumentCount) + { + break; + } + } + for (int count = marshalAsParamsStrings.Count; count < argumentCount; ++count) + { + marshalAsParamsStrings.Add(null); } } - // else (no initializer), default to no marshaling + // else (no initializer or no arguments), default to no marshaling } else if (!collection.ConstantValue.HasValue) // argument is not null { @@ -54,7 +57,7 @@ internal static class MarshalInfo return marshalAsParamsStrings.Count > 0 ? marshalAsParamsStrings.ToImmutableArray() : null; } - private static void GetMarshalAsFromField(IFieldReferenceOperation fieldReference, CancellationToken cancellationToken, int argumentCount, List diagnostics, Location location, List marshalAsStrings) + private static void GetMarshalAsFromField(IFieldReferenceOperation fieldReference, CancellationToken cancellationToken, int argumentCount, List marshalAsStrings) { // `GetOperation` is only returning `null` for the relevant `SyntaxNode`s here, so we have to manually parse the field initializer // see @@ -64,22 +67,20 @@ private static void GetMarshalAsFromField(IFieldReferenceOperation fieldReferenc bool isInsideArrayInitializer = false; bool isInsideNewExpression = false; bool isInsideObjectInitializer = false; - bool addedArrayLengthDiagnostic = false; var addMarshalAsString = () => { if (sb.Length != 0) { marshalAsStrings.Add(sb.ToString()); - if (isArray && !addedArrayLengthDiagnostic && marshalAsStrings.Count > argumentCount) - { - addedArrayLengthDiagnostic = true; - diagnostics.Add(Diagnostic.Create(Constants.InvalidMarshalParamsAsArrayLengthDescriptor, location)); - } _ = sb.Clear(); } }; foreach (var syntaxToken in fieldDeclaration.DescendantTokens()) { + if (marshalAsStrings.Count == argumentCount) + { + return; + } var token = syntaxToken.ToString(); switch (token) { @@ -112,7 +113,15 @@ private static void GetMarshalAsFromField(IFieldReferenceOperation fieldReferenc addMarshalAsString(); continue; case "null": + if (isArray && !isInsideArrayInitializer) + { + return; + } marshalAsStrings.Add(null); + if (!isArray && !isInsideObjectInitializer) + { + return; + } continue; case ",": if (isInsideObjectInitializer) @@ -158,7 +167,14 @@ private static void GetMarshalAsFromOperation(IOperation value, CancellationToke } if (value is IFieldReferenceOperation fieldReference && fieldReference.Field.IsReadOnly) { - GetMarshalAsFromField(fieldReference, cancellationToken, argumentCount, diagnostics, location, marshalAsStrings); + GetMarshalAsFromField(fieldReference, cancellationToken, argumentCount, marshalAsStrings); + if (fieldReference.Field.Type is IArrayTypeSymbol && (marshalAsStrings.Count > 0)) + { + for (int count = marshalAsStrings.Count; count < argumentCount; ++count) + { + marshalAsStrings.Add(null); + } + } return; } IObjectCreationOperation? objectCreation = value as IObjectCreationOperation; @@ -183,7 +199,7 @@ private static void GetMarshalAsFromOperation(IOperation value, CancellationToke public static void GetMarshalAsFromOperation(IOperation value, CancellationToken cancellationToken, List diagnostics, Location location, out string? marshalAsString) { List marshalAsStrings = new(1); - GetMarshalAsFromOperation(value, cancellationToken, 0, diagnostics, location, marshalAsStrings); + GetMarshalAsFromOperation(value, cancellationToken, 1, diagnostics, location, marshalAsStrings); marshalAsString = marshalAsStrings.FirstOrDefault(); } } diff --git a/NativeGenericDelegateInfo.cs b/NativeGenericDelegateInfo.cs index 16d4e94..8cf7823 100644 --- a/NativeGenericDelegateInfo.cs +++ b/NativeGenericDelegateInfo.cs @@ -132,7 +132,8 @@ public NativeGenericDelegateInfo(MethodSymbolWithMarshalInfo methodSymbolWithMar FunctionPointerTypeArgumentsWithReturnType += ", void"; } cancellationToken.ThrowIfCancellationRequested(); - _ = sb.Append(andNewLine).Append($"MarshalInfo.Equals({(isAction ? "null" : "marshalReturnAs")}, marshalParamsAs, {marshalReturnAsAttribute}, {marshalParamsAsAttributes})"); + _ = sb.Append(andNewLine).Append($"MarshalInfo.Equals({(isAction ? "null" : "marshalReturnAs")}, {marshalReturnAsAttribute})") + .Append(andNewLine).Append($"MarshalInfo.PartiallyEquals(marshalParamsAs, {marshalParamsAsAttributes})"); TypeArgumentCheckWithMarshalInfoCondition = sb.ToString(); } } diff --git a/PartialImplementations.cs b/PartialImplementations.cs index 1164c39..d612fb0 100644 --- a/PartialImplementations.cs +++ b/PartialImplementations.cs @@ -188,35 +188,7 @@ namespace {Constants.RootNamespace} {{ file static class MarshalInfo {{ - internal static bool Equals(MarshalAsAttribute? marshalReturnAsLeft, MarshalAsAttribute?[]? marshalParamsAsLeft, MarshalAsAttribute? marshalReturnAsRight, MarshalAsAttribute?[]? marshalParamsAsRight) - {{ - if (!Equals(marshalReturnAsLeft, marshalReturnAsRight)) - {{ - return false; - }} - if (marshalParamsAsLeft is null) - {{ - return marshalParamsAsRight is null; - }} - else if (marshalParamsAsRight is null) - {{ - return false; - }} - if (marshalParamsAsLeft.Length != marshalParamsAsRight.Length) - {{ - return false; - }} - for (int i = 0; i < marshalParamsAsLeft.Length; ++i) - {{ - if (!Equals(marshalParamsAsLeft[i], marshalParamsAsRight[i])) - {{ - return false; - }} - }} - return true; - }} - - private static bool Equals(MarshalAsAttribute? left, MarshalAsAttribute? right) + internal static bool Equals(MarshalAsAttribute? left, MarshalAsAttribute? right) {{ if (left is null) {{ @@ -238,6 +210,30 @@ private static bool Equals(MarshalAsAttribute? left, MarshalAsAttribute? right) left.MarshalTypeRef == right.MarshalTypeRef && left.MarshalCookie == right.MarshalCookie; }} + + internal static bool PartiallyEquals(MarshalAsAttribute?[]? left, MarshalAsAttribute?[]? right) + {{ + if (left is null) + {{ + return right is null; + }} + int i = 0; + for (int len = Math.Min(left.Length, right?.Length ?? 0); i < len; ++i) + {{ + if (!Equals(left[i], right![i])) + {{ + return false; + }} + }} + for ( ; i < left.Length; ++i) + {{ + if (!Equals(left[i], null)) + {{ + return false; + }} + }} + return true; + }} }} {ConcreteClassDefinitions}{InterfaceImplementations}}}