Skip to content

Commit 6450f6b

Browse files
authored
Add an option to FriendlyOverloads to request previous pointer overloads (#1524)
1 parent 72e6dee commit 6450f6b

File tree

7 files changed

+96
-25
lines changed

7 files changed

+96
-25
lines changed

src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -81,14 +81,33 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
8181
}
8282
}
8383

84-
bool useSpansForPointers = this.canUseSpan;
85-
foreach (MethodDeclarationSyntax method in this.DeclareFriendlyOverload(methodDefinition, externMethodDeclaration, declaringTypeName, overloadOf, helperMethodsAdded, avoidWinmdRootAlias, useSpansForPointers, omitOptionalParams: false))
84+
bool improvePointersToSpansAndRefs = this.canUseSpan;
85+
FriendlyMethodBookkeeping bookkeeping = new();
86+
foreach (MethodDeclarationSyntax method in this.DeclareFriendlyOverload(methodDefinition, externMethodDeclaration, declaringTypeName, overloadOf, helperMethodsAdded, avoidWinmdRootAlias, improvePointersToSpansAndRefs, omitOptionalParams: false, bookkeeping))
8687
{
8788
yield return method;
8889
}
90+
91+
if (this.Options.FriendlyOverloads.IncludePointerOverloads && improvePointersToSpansAndRefs && bookkeeping.NumSpanByteParameters > 0)
92+
{
93+
// If we could use Span and _did_ use span Span and the pointer overloads were requested, then Generate overloads that use pointer types instead of Span<byte>/ReadOnlySpan<byte>.
94+
foreach (MethodDeclarationSyntax method in this.DeclareFriendlyOverload(methodDefinition, externMethodDeclaration, declaringTypeName, overloadOf, helperMethodsAdded, avoidWinmdRootAlias, improvePointersToSpansAndRefs: false, omitOptionalParams: false))
95+
{
96+
yield return method;
97+
}
98+
}
8999
}
90100

91-
private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverload(MethodDefinition methodDefinition, MethodDeclarationSyntax externMethodDeclaration, NameSyntax declaringTypeName, FriendlyOverloadOf overloadOf, HashSet<string> helperMethodsAdded, bool avoidWinmdRootAlias, bool useSpansForPointers, bool omitOptionalParams)
101+
private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverload(
102+
MethodDefinition methodDefinition,
103+
MethodDeclarationSyntax externMethodDeclaration,
104+
NameSyntax declaringTypeName,
105+
FriendlyOverloadOf overloadOf,
106+
HashSet<string> helperMethodsAdded,
107+
bool avoidWinmdRootAlias,
108+
bool improvePointersToSpansAndRefs,
109+
bool omitOptionalParams,
110+
FriendlyMethodBookkeeping? bookkeeping = null)
92111
{
93112
#pragma warning disable SA1114 // Parameter list should follow declaration
94113
bool isReleaseMethod = this.MetadataIndex.ReleaseMethods.Contains(externMethodDeclaration.Identifier.ValueText);
@@ -123,6 +142,7 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverload(MethodDefin
123142
bool minorSignatureChange = false; // Did the signature change but not enough that overload resolution would be confused?
124143
List<Parameter>? countOfBytesStructParameters = null;
125144
int numOptionalParams = 0;
145+
int numSpanByteParameters = 0;
126146

127147
foreach (ParameterHandle paramHandle in methodDefinition.GetParameters())
128148
{
@@ -177,12 +197,18 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverload(MethodDefin
177197
// If there's no MemorySize attribute, we may still need to keep this parameter as a pointer if it's a struct with a flexible array.
178198
mustRemainAsPointer = parameterTypeInfo is PointerTypeHandleInfo { ElementType: HandleTypeHandleInfo pointedElement } && pointedElement.Generator.IsStructWithFlexibleArray(pointedElement);
179199
}
180-
else if (!useSpansForPointers)
200+
else if (!improvePointersToSpansAndRefs)
181201
{
182202
// If we are generating the overload with pointers for memory sized params then also force them to pointers.
183203
mustRemainAsPointer = true;
184204
}
185205

206+
// For compat with how out/ref parameters used to be generated, leave out/ref parameters as pointers if we're not tring to improve them.
207+
if (isOptional && isOut && !isComOutPtr && !improvePointersToSpansAndRefs)
208+
{
209+
mustRemainAsPointer = true;
210+
}
211+
186212
IdentifierNameSyntax origName = IdentifierName(externParam.Identifier.ValueText);
187213

188214
bool isArray = false;
@@ -220,7 +246,7 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverload(MethodDefin
220246
}
221247

222248
bool projectAsSpanBytes = false;
223-
if (useSpansForPointers && IsVoidPtrOrPtrPtr(externParam.Type))
249+
if (improvePointersToSpansAndRefs && IsVoidPtrOrPtrPtr(externParam.Type))
224250
{
225251
// if it's memory-sized project as Span<byte>
226252
if (memorySize is not null)
@@ -401,7 +427,7 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverload(MethodDefin
401427
arguments[param.SequenceNumber - 1] = Argument(typeDefHandleName);
402428
}
403429
else if ((externParam.Type is PointerTypeSyntax { ElementType: TypeSyntax ptrElementType }
404-
&& (!IsVoid(ptrElementType) || (useSpansForPointers && isArray))
430+
&& (!IsVoid(ptrElementType) || (improvePointersToSpansAndRefs && isArray))
405431
&& !this.IsInterface(parameterTypeInfo)) ||
406432
externParam.Type is ArrayTypeSyntax)
407433
{
@@ -476,6 +502,11 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverload(MethodDefin
476502
fixedBlocks.Add(VariableDeclaration(PointerType(elementType)).AddVariables(
477503
VariableDeclarator(localName.Identifier).WithInitializer(EqualsValueClause(origName))));
478504
arguments[param.SequenceNumber - 1] = projectAsSpanBytes ? Argument(CastExpression(externParam.Type, localName)) : Argument(localName);
505+
506+
if (projectAsSpanBytes)
507+
{
508+
numSpanByteParameters++;
509+
}
479510
}
480511
else if (isNullTerminated && isConst && parameters[param.SequenceNumber - 1].Type is PointerTypeSyntax { ElementType: PredefinedTypeSyntax { Keyword: { RawKind: (int)SyntaxKind.CharKeyword } } })
481512
{
@@ -826,6 +857,7 @@ bool TryHandleCountParam(TypeSyntax elementType, bool nullableSource)
826857
fixedBlocks.Add(VariableDeclaration(PointerType(byteSyntax)).AddVariables(
827858
VariableDeclarator(localName.Identifier).WithInitializer(EqualsValueClause(origName))));
828859
arguments[param.SequenceNumber - 1] = Argument(CastExpression(externParam.Type, localName));
860+
numSpanByteParameters++;
829861
}
830862
else
831863
{
@@ -898,7 +930,7 @@ bool TryHandleCountParam(TypeSyntax elementType, bool nullableSource)
898930
: IdentifierName(externMethodDeclaration.Identifier);
899931
SyntaxTrivia leadingTrivia = Trivia(
900932
DocumentationCommentTrivia(SyntaxKind.SingleLineDocumentationCommentTrivia).AddContent(
901-
XmlText("/// "),
933+
XmlText($"/// "),
902934
XmlEmptyElement("inheritdoc").AddAttributes(XmlCrefAttribute(NameMemberCref(docRefExternName, ToCref(externMethodDeclaration.ParameterList)))),
903935
XmlText().AddTextTokens(XmlTextNewLine("\n", continueXmlDocumentationComment: false))));
904936
InvocationExpressionSyntax externInvocation = InvocationExpression(
@@ -1003,12 +1035,17 @@ bool TryHandleCountParam(TypeSyntax elementType, bool nullableSource)
10031035
friendlyDeclaration = friendlyDeclaration
10041036
.WithLeadingTrivia(leadingTrivia);
10051037

1038+
if (bookkeeping is not null)
1039+
{
1040+
bookkeeping.NumSpanByteParameters = numSpanByteParameters;
1041+
}
1042+
10061043
yield return friendlyDeclaration;
10071044

10081045
// We generated the main overload, but now see if we should generate another helper for things like SHGetFileInfo where
10091046
// there is a parameter that's sized in bytes and for convenience you want to just use the struct and not cast between Span<byte>.
10101047
// To avoid an explosion of overloads, just do this if there's one parameter of this kind.
1011-
if (useSpansForPointers && countOfBytesStructParameters?.Count == 1)
1048+
if (improvePointersToSpansAndRefs && countOfBytesStructParameters?.Count == 1)
10121049
{
10131050
MethodDeclarationSyntax? structOverload = this.DeclareStructCountOfBytesFriendlyOverload(externMethodDeclaration, countOfBytesStructParameters, friendlyDeclaration);
10141051
if (structOverload is not null)
@@ -1018,10 +1055,10 @@ bool TryHandleCountParam(TypeSyntax elementType, bool nullableSource)
10181055
}
10191056
}
10201057

1021-
if (numOptionalParams > 0 && !omitOptionalParams)
1058+
if (numOptionalParams > 0 && !omitOptionalParams && improvePointersToSpansAndRefs)
10221059
{
10231060
// Generate overloads for optional parameters.
1024-
foreach (MethodDeclarationSyntax method in this.DeclareFriendlyOverload(methodDefinition, externMethodDeclaration, declaringTypeName, overloadOf, helperMethodsAdded, avoidWinmdRootAlias, useSpansForPointers, omitOptionalParams: true))
1061+
foreach (MethodDeclarationSyntax method in this.DeclareFriendlyOverload(methodDefinition, externMethodDeclaration, declaringTypeName, overloadOf, helperMethodsAdded, avoidWinmdRootAlias, improvePointersToSpansAndRefs, omitOptionalParams: true))
10251062
{
10261063
yield return method;
10271064
}
@@ -1163,4 +1200,9 @@ bool TryHandleCountParam(TypeSyntax elementType, bool nullableSource)
11631200

11641201
return helper;
11651202
}
1203+
1204+
private class FriendlyMethodBookkeeping
1205+
{
1206+
public int NumSpanByteParameters { get; set; } = 0;
1207+
}
11661208
}

src/Microsoft.Windows.CsWin32/GeneratorOptions.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,5 +106,11 @@ public record FriendlyOverloadOptions
106106
/// </summary>
107107
/// <value>The default value is <see langword="true" />.</value>
108108
public bool Enabled { get; set; } = true;
109+
110+
/// <summary>
111+
/// Gets or sets a value indicating whether to also generate overloads that use pointer types for parameters that are [MemorySize] annotated buffers
112+
/// which normally appear as spans.
113+
/// </summary>
114+
public bool IncludePointerOverloads { get; set; } = false;
109115
}
110116
}

src/Microsoft.Windows.CsWin32/settings.schema.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,11 @@
5050
"description": "A value indicating whether to generate method overloads that may be easier to consume or be more idiomatic C#. These may use fewer pointers, accept or return SafeHandles, etc.",
5151
"type": "boolean",
5252
"default": true
53+
},
54+
"includePointerOverloads": {
55+
"description": "A value indicating whether to also generate overloads that use pointer types for parameters that are [MemorySize] annotated buffers which normally appear as spans.",
56+
"type": "boolean",
57+
"default": false
5358
}
5459
}
5560
},

test/CsWin32Generator.Tests/CsWin32GeneratorFullTests.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,13 @@ public CsWin32GeneratorFullTests(ITestOutputHelper logger)
1717
[Trait("TestCategory", "FailsInCloudTest")] // these take ~4GB of memory to run.
1818
[InlineData("net8.0", LanguageVersion.CSharp12)]
1919
[InlineData("net9.0", LanguageVersion.CSharp13)]
20-
public async Task FullGeneration(string tfm, LanguageVersion langVersion)
20+
[InlineData("net9.0", LanguageVersion.CSharp13, true)]
21+
public async Task FullGeneration(string tfm, LanguageVersion langVersion, bool includePointerOverloads = false)
2122
{
2223
this.fullGeneration = true;
2324
this.compilation = this.starterCompilations[tfm];
2425
this.parseOptions = this.parseOptions.WithLanguageVersion(langVersion);
25-
this.nativeMethodsJson = "NativeMethods.EmitSingleFile.json";
26+
this.nativeMethodsJson = includePointerOverloads ? "NativeMethods.IncludePointerOverloads.json" : "NativeMethods.EmitSingleFile.json";
2627
await this.InvokeGeneratorAndCompile($"FullGeneration_{tfm}_{langVersion}", TestOptions.None);
2728
}
2829
}

test/CsWin32Generator.Tests/CsWin32GeneratorTests.cs

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -184,27 +184,30 @@ public async Task TestGenerateCoCreateableClass()
184184
["IEnumString", "Next", "this winmdroot.System.Com.IEnumString @this, Span<winmdroot.Foundation.PWSTR> rgelt, out uint pceltFetched"],
185185
["PSCreateMemoryPropertyStore", "PSCreateMemoryPropertyStore", "in global::System.Guid riid, out void* ppv"],
186186
["DeviceIoControl", "DeviceIoControl", "SafeHandle hDevice, uint dwIoControlCode, ReadOnlySpan<byte> lpInBuffer, Span<byte> lpOutBuffer, out uint lpBytesReturned, global::System.Threading.NativeOverlapped* lpOverlapped"],
187+
["DeviceIoControl", "DeviceIoControl", "SafeHandle hDevice, uint dwIoControlCode, ReadOnlySpan<byte> lpInBuffer, Span<byte> lpOutBuffer, out uint lpBytesReturned, global::System.Threading.NativeOverlapped* lpOverlapped", true, "NativeMethods.IncludePointerOverloads.json"],
188+
["NtQueryObject", "NtQueryObject", "global::Windows.Win32.Foundation.HANDLE Handle, winmdroot.Foundation.OBJECT_INFORMATION_CLASS ObjectInformationClass, Span<byte> ObjectInformation, out uint ReturnLength"],
187189
];
188190

189191
[Theory]
190192
[MemberData(nameof(TestSignatureData))]
191-
public async Task VerifySignature(string api, string member, string signature, bool assertPresent = true)
193+
public async Task VerifySignature(string api, string member, string signature, bool assertPresent = true, string? nativeMethodsJson = null)
192194
{
193-
await this.VerifySignatureWorker(api, member, signature, assertPresent, "net9.0");
195+
await this.VerifySignatureWorker(api, member, signature, assertPresent, "net9.0", nativeMethodsJson);
194196
}
195197

196198
[Theory]
197199
[InlineData("InitializeAcl", "InitializeAcl", "out winmdroot.Security.ACL pAcl, winmdroot.Security.ACE_REVISION dwAclRevision", false)]
198-
public async Task VerifySignatureNet472(string api, string member, string signature, bool assertPresent = true)
200+
public async Task VerifySignatureNet472(string api, string member, string signature, bool assertPresent = true, string? nativeMethodsJson = null)
199201
{
200-
await this.VerifySignatureWorker(api, member, signature, assertPresent, "net472");
202+
await this.VerifySignatureWorker(api, member, signature, assertPresent, "net472", nativeMethodsJson);
201203
}
202204

203-
private async Task VerifySignatureWorker(string api, string member, string signature, bool assertPresent, string tfm)
205+
private async Task VerifySignatureWorker(string api, string member, string signature, bool assertPresent, string tfm, string? nativeMethodsJson)
204206
{
205207
this.tfm = tfm;
206208
this.compilation = this.starterCompilations[tfm];
207209
this.nativeMethods.Add(api);
210+
this.nativeMethodsJson = nativeMethodsJson;
208211

209212
// Make a unique name based on the signature
210213
await this.InvokeGeneratorAndCompile($"{api}_{member}_{tfm}_{signature.Select(x => (int)x).Aggregate((x, y) => x + y).ToString("X")}");
@@ -272,24 +275,25 @@ private async Task VerifySignatureWorker(string api, string member, string signa
272275
["GetModuleFileName", "Should have a friendly Span overload"],
273276
["PdhGetCounterInfo", "Optional out parameter omission conflicts with other overload"],
274277
["RtlUpcaseUnicodeChar", "char parameter should not get CharSet marshalling in AOT"],
275-
["CryptGetAsyncParam", "Has optional unmanaged delegate out param"]
278+
["CryptGetAsyncParam", "Has optional unmanaged delegate out param"],
279+
["NtQueryObject", "Verify pointer overloads and optional parameters", TestOptions.None, "NativeMethods.IncludePointerOverloads.json"],
276280
];
277281

278282
[Theory]
279283
[MemberData(nameof(TestApiData))]
280-
public async Task TestGenerateApi(string api, string purpose, TestOptions options = TestOptions.None)
284+
public async Task TestGenerateApi(string api, string purpose, TestOptions options = TestOptions.None, string? nativeMethodsJson = null)
281285
{
282-
await this.TestGenerateApiWorker(api, purpose, options, "net9.0");
286+
await this.TestGenerateApiWorker(api, purpose, options, "net9.0", nativeMethodsJson);
283287
}
284288

285289
[Theory]
286290
[MemberData(nameof(TestApiData))]
287-
public async Task TestGenerateApiNet8(string api, string purpose, TestOptions options = TestOptions.None)
291+
public async Task TestGenerateApiNet8(string api, string purpose, TestOptions options = TestOptions.None, string? nativeMethodsJson = null)
288292
{
289-
await this.TestGenerateApiWorker(api, purpose, options, "net8.0");
293+
await this.TestGenerateApiWorker(api, purpose, options, "net8.0", nativeMethodsJson);
290294
}
291295

292-
private async Task TestGenerateApiWorker(string api, string purpose, TestOptions options, string tfm)
296+
private async Task TestGenerateApiWorker(string api, string purpose, TestOptions options, string tfm, string? nativeMethodsJson)
293297
{
294298
LanguageVersion langVersion = (tfm == "net8.0") ? LanguageVersion.CSharp12 : LanguageVersion.CSharp13;
295299

@@ -298,6 +302,7 @@ private async Task TestGenerateApiWorker(string api, string purpose, TestOptions
298302
this.parseOptions = this.parseOptions.WithLanguageVersion(langVersion);
299303
this.Logger.WriteLine($"Testing {api} - {tfm} - {purpose}");
300304
this.nativeMethods.Add(api);
305+
this.nativeMethodsJson = nativeMethodsJson;
301306
await this.InvokeGeneratorAndCompile($"Test_{api}_{tfm}", options);
302307
}
303308

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
{
2+
"$schema": "..\\..\\..\\src\\Microsoft.Windows.CsWin32\\settings.schema.json",
3+
"emitSingleFile": true,
4+
"friendlyOverloads": {
5+
"includePointerOverloads": true
6+
}
7+
}

test/Microsoft.Windows.CsWin32.Tests/FullGenerationTests.cs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,11 @@ public void Constants()
5656
public void ExternMethods(
5757
MarshalingOptions marshaling,
5858
bool useIntPtrForComOutPtr,
59+
bool includePointerOverloads,
5960
[CombinatorialMemberData(nameof(SpecificCpuArchitectures))] Platform platform,
6061
[CombinatorialMemberData(nameof(TFMDataNoNetFx35))] string tfm)
6162
{
62-
this.TestHelper(OptionsForMarshaling(marshaling, useIntPtrForComOutPtr), platform, tfm, generator => generator.GenerateAllExternMethods(CancellationToken.None));
63+
this.TestHelper(OptionsForMarshaling(marshaling, useIntPtrForComOutPtr, includePointerOverloads), platform, tfm, generator => generator.GenerateAllExternMethods(CancellationToken.None));
6364
}
6465

6566
[Fact]
@@ -68,14 +69,18 @@ public void Macros()
6869
this.TestHelper(new GeneratorOptions(), Platform.X64, DefaultTFM, generator => generator.GenerateAllMacros(CancellationToken.None));
6970
}
7071

71-
private static GeneratorOptions OptionsForMarshaling(MarshalingOptions marshaling, bool useIntPtrForComOutPtr) => new()
72+
private static GeneratorOptions OptionsForMarshaling(MarshalingOptions marshaling, bool useIntPtrForComOutPtr, bool includePointerOverloads = false) => new()
7273
{
7374
AllowMarshaling = marshaling >= MarshalingOptions.MarshalingWithoutSafeHandles,
7475
UseSafeHandles = marshaling == MarshalingOptions.FullMarshaling,
7576
ComInterop = new()
7677
{
7778
UseIntPtrForComOutPointers = useIntPtrForComOutPtr,
7879
},
80+
FriendlyOverloads = new()
81+
{
82+
IncludePointerOverloads = includePointerOverloads,
83+
},
7984
};
8085

8186
private void TestHelper(GeneratorOptions generatorOptions, Platform platform, string targetFramework, Action<IGenerator> generationCommands)

0 commit comments

Comments
 (0)