Skip to content

Commit fe382a4

Browse files
authored
Improve optional out interface arguments (e.g. IWbemServices.GetObject) and other minor tweaks (#1544)
1 parent f8e0b22 commit fe382a4

File tree

11 files changed

+285
-15
lines changed

11 files changed

+285
-15
lines changed

src/CsWin32Generator/Program.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,7 @@ private unsafe bool ProcessNativeMethodsFile(SuperGenerator superGenerator, File
503503
{
504504
lineNumber++;
505505
string name = line.Trim();
506-
if (string.IsNullOrWhiteSpace(name) || name.StartsWith("//", StringComparison.Ordinal) || name.StartsWith("-", StringComparison.Ordinal))
506+
if (string.IsNullOrWhiteSpace(name) || name.StartsWith("#", StringComparison.Ordinal) || name.StartsWith("//", StringComparison.Ordinal) || name.StartsWith("-", StringComparison.Ordinal))
507507
{
508508
skippedCount++;
509509
continue;

src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs

Lines changed: 167 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,16 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverload(
215215
mustRemainAsPointer = true;
216216
}
217217

218-
// For compat with how out/ref parameters used to be generated, leave out/ref parameters as pointers if we're not tring to improve them.
218+
if (isOptional && isIn && !isOut && externParam.Type is PointerTypeSyntax { ElementType: QualifiedNameSyntax elementTypeSyntax } && elementTypeSyntax.Right.Identifier.ValueText == "NativeOverlapped")
219+
{
220+
// OVERLAPPED struct must always be passed by pointer. Currently "in" optional parameters are promoted to nullable which
221+
// means the structs get copied. Normally this is fine since these struct addresses don't matter, but in the case of OVERLAPPED
222+
// it does. Trying to change "in" optional parameters to not be wrapped in nullable is a lot of work and impact for unclear value
223+
// so just adding special handling for OVERLAPPED for now.
224+
mustRemainAsPointer = true;
225+
}
226+
227+
// For compat with how out/ref parameters used to be generated, leave out/ref parameters as pointers if we're not trying to improve them.
219228
if (isOptional && isOut && !isComOutPtr && !improvePointersToSpansAndRefs)
220229
{
221230
mustRemainAsPointer = true;
@@ -311,7 +320,7 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverload(
311320
else if (omittableOptionalParam && omitOptionalParams)
312321
{
313322
// Remove the optional out parameter and supply the default value for the type to the extern method.
314-
if (externParamModifier.Kind() is SyntaxKind.OutKeyword || externParamModifier.Kind() is SyntaxKind.RefKeyword)
323+
if (externParamModifier.Kind() is SyntaxKind.OutKeyword or SyntaxKind.RefKeyword)
315324
{
316325
if (externParam.Type is PointerTypeSyntax || externParam.Type is FunctionPointerTypeSyntax)
317326
{
@@ -841,6 +850,130 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverload(
841850
RefExpression(localName),
842851
RefExpression(nullRef)));
843852
}
853+
else if (this.options.AllowMarshaling && isOptional && isOut && !isArray && parameterTypeInfo is PointerTypeHandleInfo pointerTypeHandleInfo && this.IsInterface(pointerTypeHandleInfo.ElementType))
854+
{
855+
// In source generated COM we can improve certain Optional out parameters to marshalled, e.g. IWbemServices.GetObject has some
856+
// optional out parameters that need to be pointers in the ABI but in the optional overload they can be marshalled to ComWrappers.
857+
TypeSyntax interfaceTypeSyntax = pointerTypeHandleInfo.ElementType.ToTypeSyntax(parameterTypeSyntaxSettings, GeneratingElement.FriendlyOverload, null).Type;
858+
parameters[paramIndex] = parameters[paramIndex]
859+
.WithType(interfaceTypeSyntax.WithTrailingTrivia(TriviaList(Space)))
860+
.WithModifiers(TokenList(TokenWithSpace(isIn && isOut ? SyntaxKind.RefKeyword : (isIn ? SyntaxKind.InKeyword : SyntaxKind.OutKeyword))));
861+
862+
TypeSyntax nativeInterfaceTypeSyntax = ((PointerTypeSyntax)externParam.Type).ElementType;
863+
864+
if (!isIn)
865+
{
866+
// Not a ref so need to assign first so we can use "ref" on the param. Use Unsafe.SkipInit(out origName) so that we can handle null refs.
867+
leadingOutsideTryStatements.Add(
868+
ExpressionStatement(
869+
InvocationExpression(
870+
MemberAccessExpression(
871+
SyntaxKind.SimpleMemberAccessExpression,
872+
IdentifierName(nameof(Unsafe)),
873+
IdentifierName(nameof(Unsafe.SkipInit))),
874+
ArgumentList().AddArguments(Argument(origName).WithRefKindKeyword(Token(SyntaxKind.OutKeyword))))));
875+
}
876+
877+
// For both in & out, declare a local:
878+
// externParamTypeInterface* __origName_native = null;
879+
IdentifierNameSyntax nativeLocal = IdentifierName($"__{origName.Identifier.ValueText.Replace("@", string.Empty)}_native");
880+
leadingOutsideTryStatements.Add(
881+
LocalDeclarationStatement(VariableDeclaration(nativeInterfaceTypeSyntax)
882+
.AddVariables(VariableDeclarator(nativeLocal.Identifier)
883+
.WithInitializer(EqualsValueClause(LiteralExpression(SyntaxKind.NullLiteralExpression))))));
884+
885+
// bool __origName_present = !Unsafe.IsNullRef<TInterface>(origName);
886+
string paramPresent = $"__{origName.Identifier.ValueText.Replace("@", string.Empty)}_present";
887+
leadingOutsideTryStatements.Add(
888+
LocalDeclarationStatement(VariableDeclaration(PredefinedType(Token(SyntaxKind.BoolKeyword)))
889+
.AddVariables(VariableDeclarator(Identifier(paramPresent))
890+
.WithInitializer(EqualsValueClause(
891+
PrefixUnaryExpression(
892+
SyntaxKind.LogicalNotExpression,
893+
InvocationExpression(
894+
MemberAccessExpression(
895+
SyntaxKind.SimpleMemberAccessExpression,
896+
IdentifierName(nameof(Unsafe)),
897+
GenericName(nameof(Unsafe.IsNullRef), TypeArgumentList().AddArguments(interfaceTypeSyntax))),
898+
ArgumentList().AddArguments(Argument(RefExpression(origName))))))))));
899+
900+
// If it's an in parameter, assign the native local from the managed parameter.
901+
// __origName_native = (TNative)global::System.Runtime.InteropServices.Marshalling.ComInterfaceMarshaller<TInterface>.ConvertToUnmanaged(origName);
902+
// Also remember the marshalled in pointer in case the callee modifies in for ref params.
903+
// __origName_nativeIn = __origName_native;
904+
if (isIn)
905+
{
906+
ExpressionSyntax toNativeExpression = this.useSourceGenerators ?
907+
InvocationExpression(
908+
MemberAccessExpression(
909+
SyntaxKind.SimpleMemberAccessExpression,
910+
GenericName($"global::System.Runtime.InteropServices.Marshalling.ComInterfaceMarshaller", TypeArgumentList().AddArguments(interfaceTypeSyntax)),
911+
IdentifierName("ConvertToUnmanaged")),
912+
ArgumentList().AddArguments(Argument(origName))) :
913+
ParenthesizedExpression(ConditionalExpression(
914+
BinaryExpression(SyntaxKind.NotEqualsExpression, origName, LiteralExpression(SyntaxKind.NullLiteralExpression)),
915+
CastExpression(
916+
PointerType(PredefinedType(Token(SyntaxKind.VoidKeyword))),
917+
InvocationExpression(
918+
MemberAccessExpression(
919+
SyntaxKind.SimpleMemberAccessExpression,
920+
ParseTypeName($"global::System.Runtime.InteropServices.Marshal"),
921+
IdentifierName("GetIUnknownForObject")),
922+
ArgumentList().AddArguments(Argument(origName)))),
923+
LiteralExpression(SyntaxKind.NullLiteralExpression)));
924+
925+
leadingStatements.Add(
926+
IfStatement(
927+
IdentifierName(paramPresent),
928+
Block().AddStatements(
929+
ExpressionStatement(
930+
AssignmentExpression(
931+
SyntaxKind.SimpleAssignmentExpression,
932+
nativeLocal,
933+
CastExpression(
934+
nativeInterfaceTypeSyntax,
935+
toNativeExpression))))));
936+
}
937+
938+
// If it's an out parameter, assign the out parameter from the native local.
939+
// origName = global::System.Runtime.InteropServices.Marshalling.ComInterfaceMarshaller<TInterface>.ConvertToManaged(__origName_native);
940+
if (isOut)
941+
{
942+
ExpressionSyntax toManagedExpression = this.useSourceGenerators ?
943+
InvocationExpression(
944+
MemberAccessExpression(
945+
SyntaxKind.SimpleMemberAccessExpression,
946+
GenericName($"global::System.Runtime.InteropServices.Marshalling.ComInterfaceMarshaller", TypeArgumentList().AddArguments(interfaceTypeSyntax)),
947+
IdentifierName("ConvertToManaged")),
948+
ArgumentList().AddArguments(Argument(nativeLocal))) :
949+
ParenthesizedExpression(ConditionalExpression(
950+
BinaryExpression(SyntaxKind.NotEqualsExpression, nativeLocal, LiteralExpression(SyntaxKind.NullLiteralExpression)),
951+
CastExpression(interfaceTypeSyntax, InvocationExpression(
952+
MemberAccessExpression(
953+
SyntaxKind.SimpleMemberAccessExpression,
954+
ParseTypeName($"global::System.Runtime.InteropServices.Marshal"),
955+
IdentifierName("GetObjectForIUnknown")),
956+
ArgumentList().AddArguments(
957+
Argument(CastExpression(ParseName("nint"), nativeLocal))))),
958+
LiteralExpression(SyntaxKind.NullLiteralExpression)));
959+
960+
trailingStatements.Add(
961+
IfStatement(
962+
IdentifierName(paramPresent),
963+
Block().AddStatements(
964+
ExpressionStatement(
965+
AssignmentExpression(
966+
SyntaxKind.SimpleAssignmentExpression,
967+
origName,
968+
toManagedExpression)))));
969+
}
970+
971+
// Release the native pointers we have refs on.
972+
finallyStatements.Add(this.COMFreeNativePointerStatement(nativeLocal, interfaceTypeSyntax));
973+
974+
// If it's an in parameter, pass the native local as the argument.
975+
arguments[paramIndex] = arguments[paramIndex].WithExpression(ConditionalExpression(IdentifierName(paramPresent), PrefixUnaryExpression(SyntaxKind.AddressOfExpression, nativeLocal), LiteralExpression(SyntaxKind.NullLiteralExpression)));
976+
}
844977

845978
bool TryHandleCountParam(TypeSyntax elementType, bool nullableSource)
846979
{
@@ -1236,6 +1369,38 @@ bool TryHandleCountParam(TypeSyntax elementType, bool nullableSource)
12361369
return helper;
12371370
}
12381371

1372+
private StatementSyntax COMFreeNativePointerStatement(ExpressionSyntax nativePointer, TypeSyntax interfaceTypeSyntax)
1373+
{
1374+
if (this.useSourceGenerators)
1375+
{
1376+
// Release the nativeLocal via ComInterfaceMarshaller.Free.
1377+
return
1378+
ExpressionStatement(
1379+
InvocationExpression(
1380+
MemberAccessExpression(
1381+
SyntaxKind.SimpleMemberAccessExpression,
1382+
GenericName($"global::System.Runtime.InteropServices.Marshalling.ComInterfaceMarshaller", TypeArgumentList().AddArguments(interfaceTypeSyntax)),
1383+
IdentifierName("Free")),
1384+
ArgumentList().AddArguments(Argument(nativePointer))));
1385+
}
1386+
else
1387+
{
1388+
// Finally, release the nativeLocal via Marshal.Release.
1389+
return
1390+
IfStatement(
1391+
BinaryExpression(SyntaxKind.NotEqualsExpression, nativePointer, LiteralExpression(SyntaxKind.NullLiteralExpression)),
1392+
ExpressionStatement(
1393+
InvocationExpression(
1394+
MemberAccessExpression(
1395+
SyntaxKind.SimpleMemberAccessExpression,
1396+
ParseTypeName("global::System.Runtime.InteropServices.Marshal"),
1397+
IdentifierName("Release")),
1398+
ArgumentList().AddArguments(Argument(
1399+
CastExpression(ParseName("nint"), nativePointer))))))
1400+
.WithCloseParenToken(TokenWithLineFeed(SyntaxKind.CloseParenToken));
1401+
}
1402+
}
1403+
12391404
private class FriendlyMethodBookkeeping
12401405
{
12411406
public int NumSpanByteParameters { get; set; } = 0;

src/Microsoft.Windows.CsWin32/Generator.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1192,7 +1192,9 @@ internal void GetBaseTypeInfo(TypeDefinition typeDef, out StringHandle baseTypeN
11921192
throw new GenerationFailedException("Failed to parse template.");
11931193
}
11941194

1195-
specialDeclaration = specialDeclaration.WithAdditionalAnnotations(new SyntaxAnnotation(NamespaceContainerAnnotation, subNamespace));
1195+
specialDeclaration = specialDeclaration
1196+
.WithAdditionalAnnotations(new SyntaxAnnotation(NamespaceContainerAnnotation, subNamespace))
1197+
.AddAttributeLists(AttributeList().AddAttributes(GeneratedCodeAttribute));
11961198

11971199
this.volatileCode.AddSpecialType(specialName, specialDeclaration);
11981200
});

src/Microsoft.Windows.CsWin32/HandleTypeHandleInfo.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ internal override TypeSyntaxAndMarshaling ToTypeSyntax(TypeSyntaxSettings inputs
159159
return new TypeSyntaxAndMarshaling(IdentifierName(specialName));
160160
}
161161
}
162-
else if (useComSourceGenerators && simpleName is "VARIANT" && this.Generator.CanUseComVariant)
162+
else if (useComSourceGenerators && !inputs.AllowMarshaling && (simpleName is "VARIANT" or "VARIANT_unmanaged") && this.Generator.CanUseComVariant)
163163
{
164164
return new TypeSyntaxAndMarshaling(QualifiedName(ParseName("global::System.Runtime.InteropServices.Marshalling"), IdentifierName("ComVariant")));
165165
}

test/CsWin32Generator.Tests/CsWin32GeneratorTests.cs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ public async Task TestGenerateIDispatch()
3131
var methods = idispatchType.SelectMany(t => t.DescendantNodes().OfType<MethodDeclarationSyntax>());
3232
var method = Assert.Single(methods, m => m.Identifier.Text == "GetTypeInfoCount");
3333
Assert.Contains("(out uint pctinfo)", method.ParameterList.ToString());
34+
35+
var invokeMethods = methods.Where(m => m.Identifier.Text == "Invoke");
36+
Assert.All(invokeMethods, m => Assert.DoesNotContain("VARIANT_unmanaged", m.ParameterList.ToString()));
3437
}
3538

3639
[Fact]
@@ -197,10 +200,13 @@ public async Task PointerReturnValueIsPreserved()
197200
["DeviceIoControl", "DeviceIoControl", "SafeHandle hDevice, uint dwIoControlCode, ReadOnlySpan<byte> lpInBuffer, Span<byte> lpOutBuffer, out uint lpBytesReturned, global::System.Threading.NativeOverlapped* lpOverlapped"],
198201
["DeviceIoControl", "DeviceIoControl", "SafeHandle hDevice, uint dwIoControlCode, ReadOnlySpan<byte> lpInBuffer, Span<byte> lpOutBuffer, out uint lpBytesReturned, global::System.Threading.NativeOverlapped* lpOverlapped", true, "NativeMethods.IncludePointerOverloads.json"],
199202
["NtQueryObject", "NtQueryObject", "global::Windows.Win32.Foundation.HANDLE Handle, winmdroot.Foundation.OBJECT_INFORMATION_CLASS ObjectInformationClass, Span<byte> ObjectInformation, out uint ReturnLength"],
200-
// ["IWbemServices", "GetObject", "winmdroot.Foundation.BSTR, winmdroot.System.Wmi.WBEM_GENERIC_FLAG_TYPE, winmdroot.System.Wmi.IWbemContext, out winmdroot.System.Wmi.IWbemClassObject ppObject, out winmdroot.System.Wmi.IWbemCallResult ppCallResult"],
201203
["ITypeInfo", "GetFuncDesc", "uint index, out winmdroot.System.Com.FUNCDESC_unmanaged* ppFuncDesc"],
202204
["ITsSbResourcePluginStore", "EnumerateTargets", "winmdroot.Foundation.BSTR FarmName, winmdroot.Foundation.BSTR EnvName, winmdroot.System.RemoteDesktop.TS_SB_SORT_BY sortByFieldId, winmdroot.Foundation.BSTR sortyByPropName, ref uint pdwCount, out winmdroot.System.RemoteDesktop.ITsSbTarget_unmanaged** pVal"],
203205
["MFEnumDeviceSources", "MFEnumDeviceSources", "winmdroot.Media.MediaFoundation.IMFAttributes pAttributes, out winmdroot.Media.MediaFoundation.IMFActivate_unmanaged** pppSourceActivate, out uint pcSourceActivate"],
206+
// Check that GetObject optional parameters got an overload with marshalled interface types
207+
["IWbemServices", "GetObject", "this winmdroot.System.Wmi.IWbemServices @this, SafeHandle strObjectPath, winmdroot.System.Wmi.WBEM_GENERIC_FLAG_TYPE lFlags, winmdroot.System.Wmi.IWbemContext pCtx, ref winmdroot.System.Wmi.IWbemClassObject ppObject, ref winmdroot.System.Wmi.IWbemCallResult ppCallResult"],
208+
// NativeOverlapped should be pointer even when not [Retained] as in CancelIoEx.
209+
["CancelIoEx", "CancelIoEx", "SafeHandle hFile, global::System.Threading.NativeOverlapped* lpOverlapped"],
204210
];
205211

206212
[Theory]

0 commit comments

Comments
 (0)