@@ -215,7 +215,16 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverload(
215215 mustRemainAsPointer = true ;
216216 }
217217
218- // For compat with how out/ref parameters used to be generated, leave out/ref parameters as pointers if we're not tring to improve them.
218+ if ( isOptional && isIn && ! isOut && externParam . Type is PointerTypeSyntax { ElementType : QualifiedNameSyntax elementTypeSyntax } && elementTypeSyntax . Right . Identifier . ValueText == "NativeOverlapped" )
219+ {
220+ // OVERLAPPED struct must always be passed by pointer. Currently "in" optional parameters are promoted to nullable which
221+ // means the structs get copied. Normally this is fine since these struct addresses don't matter, but in the case of OVERLAPPED
222+ // it does. Trying to change "in" optional parameters to not be wrapped in nullable is a lot of work and impact for unclear value
223+ // so just adding special handling for OVERLAPPED for now.
224+ mustRemainAsPointer = true ;
225+ }
226+
227+ // For compat with how out/ref parameters used to be generated, leave out/ref parameters as pointers if we're not trying to improve them.
219228 if ( isOptional && isOut && ! isComOutPtr && ! improvePointersToSpansAndRefs )
220229 {
221230 mustRemainAsPointer = true ;
@@ -311,7 +320,7 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverload(
311320 else if ( omittableOptionalParam && omitOptionalParams )
312321 {
313322 // Remove the optional out parameter and supply the default value for the type to the extern method.
314- if ( externParamModifier . Kind ( ) is SyntaxKind . OutKeyword || externParamModifier . Kind ( ) is SyntaxKind . RefKeyword )
323+ if ( externParamModifier . Kind ( ) is SyntaxKind . OutKeyword or SyntaxKind . RefKeyword )
315324 {
316325 if ( externParam . Type is PointerTypeSyntax || externParam . Type is FunctionPointerTypeSyntax )
317326 {
@@ -841,6 +850,130 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverload(
841850 RefExpression ( localName ) ,
842851 RefExpression ( nullRef ) ) ) ;
843852 }
853+ else if ( this . options . AllowMarshaling && isOptional && isOut && ! isArray && parameterTypeInfo is PointerTypeHandleInfo pointerTypeHandleInfo && this . IsInterface ( pointerTypeHandleInfo . ElementType ) )
854+ {
855+ // In source generated COM we can improve certain Optional out parameters to marshalled, e.g. IWbemServices.GetObject has some
856+ // optional out parameters that need to be pointers in the ABI but in the optional overload they can be marshalled to ComWrappers.
857+ TypeSyntax interfaceTypeSyntax = pointerTypeHandleInfo . ElementType . ToTypeSyntax ( parameterTypeSyntaxSettings , GeneratingElement . FriendlyOverload , null ) . Type ;
858+ parameters [ paramIndex ] = parameters [ paramIndex ]
859+ . WithType ( interfaceTypeSyntax . WithTrailingTrivia ( TriviaList ( Space ) ) )
860+ . WithModifiers ( TokenList ( TokenWithSpace ( isIn && isOut ? SyntaxKind . RefKeyword : ( isIn ? SyntaxKind . InKeyword : SyntaxKind . OutKeyword ) ) ) ) ;
861+
862+ TypeSyntax nativeInterfaceTypeSyntax = ( ( PointerTypeSyntax ) externParam . Type ) . ElementType ;
863+
864+ if ( ! isIn )
865+ {
866+ // Not a ref so need to assign first so we can use "ref" on the param. Use Unsafe.SkipInit(out origName) so that we can handle null refs.
867+ leadingOutsideTryStatements . Add (
868+ ExpressionStatement (
869+ InvocationExpression (
870+ MemberAccessExpression (
871+ SyntaxKind . SimpleMemberAccessExpression ,
872+ IdentifierName ( nameof ( Unsafe ) ) ,
873+ IdentifierName ( nameof ( Unsafe . SkipInit ) ) ) ,
874+ ArgumentList ( ) . AddArguments ( Argument ( origName ) . WithRefKindKeyword ( Token ( SyntaxKind . OutKeyword ) ) ) ) ) ) ;
875+ }
876+
877+ // For both in & out, declare a local:
878+ // externParamTypeInterface* __origName_native = null;
879+ IdentifierNameSyntax nativeLocal = IdentifierName ( $ "__{ origName . Identifier . ValueText . Replace ( "@" , string . Empty ) } _native") ;
880+ leadingOutsideTryStatements . Add (
881+ LocalDeclarationStatement ( VariableDeclaration ( nativeInterfaceTypeSyntax )
882+ . AddVariables ( VariableDeclarator ( nativeLocal . Identifier )
883+ . WithInitializer ( EqualsValueClause ( LiteralExpression ( SyntaxKind . NullLiteralExpression ) ) ) ) ) ) ;
884+
885+ // bool __origName_present = !Unsafe.IsNullRef<TInterface>(origName);
886+ string paramPresent = $ "__{ origName . Identifier . ValueText . Replace ( "@" , string . Empty ) } _present";
887+ leadingOutsideTryStatements . Add (
888+ LocalDeclarationStatement ( VariableDeclaration ( PredefinedType ( Token ( SyntaxKind . BoolKeyword ) ) )
889+ . AddVariables ( VariableDeclarator ( Identifier ( paramPresent ) )
890+ . WithInitializer ( EqualsValueClause (
891+ PrefixUnaryExpression (
892+ SyntaxKind . LogicalNotExpression ,
893+ InvocationExpression (
894+ MemberAccessExpression (
895+ SyntaxKind . SimpleMemberAccessExpression ,
896+ IdentifierName ( nameof ( Unsafe ) ) ,
897+ GenericName ( nameof ( Unsafe . IsNullRef ) , TypeArgumentList ( ) . AddArguments ( interfaceTypeSyntax ) ) ) ,
898+ ArgumentList ( ) . AddArguments ( Argument ( RefExpression ( origName ) ) ) ) ) ) ) ) ) ) ;
899+
900+ // If it's an in parameter, assign the native local from the managed parameter.
901+ // __origName_native = (TNative)global::System.Runtime.InteropServices.Marshalling.ComInterfaceMarshaller<TInterface>.ConvertToUnmanaged(origName);
902+ // Also remember the marshalled in pointer in case the callee modifies in for ref params.
903+ // __origName_nativeIn = __origName_native;
904+ if ( isIn )
905+ {
906+ ExpressionSyntax toNativeExpression = this . useSourceGenerators ?
907+ InvocationExpression (
908+ MemberAccessExpression (
909+ SyntaxKind . SimpleMemberAccessExpression ,
910+ GenericName ( $ "global::System.Runtime.InteropServices.Marshalling.ComInterfaceMarshaller", TypeArgumentList ( ) . AddArguments ( interfaceTypeSyntax ) ) ,
911+ IdentifierName ( "ConvertToUnmanaged" ) ) ,
912+ ArgumentList ( ) . AddArguments ( Argument ( origName ) ) ) :
913+ ParenthesizedExpression ( ConditionalExpression (
914+ BinaryExpression ( SyntaxKind . NotEqualsExpression , origName , LiteralExpression ( SyntaxKind . NullLiteralExpression ) ) ,
915+ CastExpression (
916+ PointerType ( PredefinedType ( Token ( SyntaxKind . VoidKeyword ) ) ) ,
917+ InvocationExpression (
918+ MemberAccessExpression (
919+ SyntaxKind . SimpleMemberAccessExpression ,
920+ ParseTypeName ( $ "global::System.Runtime.InteropServices.Marshal") ,
921+ IdentifierName ( "GetIUnknownForObject" ) ) ,
922+ ArgumentList ( ) . AddArguments ( Argument ( origName ) ) ) ) ,
923+ LiteralExpression ( SyntaxKind . NullLiteralExpression ) ) ) ;
924+
925+ leadingStatements . Add (
926+ IfStatement (
927+ IdentifierName ( paramPresent ) ,
928+ Block ( ) . AddStatements (
929+ ExpressionStatement (
930+ AssignmentExpression (
931+ SyntaxKind . SimpleAssignmentExpression ,
932+ nativeLocal ,
933+ CastExpression (
934+ nativeInterfaceTypeSyntax ,
935+ toNativeExpression ) ) ) ) ) ) ;
936+ }
937+
938+ // If it's an out parameter, assign the out parameter from the native local.
939+ // origName = global::System.Runtime.InteropServices.Marshalling.ComInterfaceMarshaller<TInterface>.ConvertToManaged(__origName_native);
940+ if ( isOut )
941+ {
942+ ExpressionSyntax toManagedExpression = this . useSourceGenerators ?
943+ InvocationExpression (
944+ MemberAccessExpression (
945+ SyntaxKind . SimpleMemberAccessExpression ,
946+ GenericName ( $ "global::System.Runtime.InteropServices.Marshalling.ComInterfaceMarshaller", TypeArgumentList ( ) . AddArguments ( interfaceTypeSyntax ) ) ,
947+ IdentifierName ( "ConvertToManaged" ) ) ,
948+ ArgumentList ( ) . AddArguments ( Argument ( nativeLocal ) ) ) :
949+ ParenthesizedExpression ( ConditionalExpression (
950+ BinaryExpression ( SyntaxKind . NotEqualsExpression , nativeLocal , LiteralExpression ( SyntaxKind . NullLiteralExpression ) ) ,
951+ CastExpression ( interfaceTypeSyntax , InvocationExpression (
952+ MemberAccessExpression (
953+ SyntaxKind . SimpleMemberAccessExpression ,
954+ ParseTypeName ( $ "global::System.Runtime.InteropServices.Marshal") ,
955+ IdentifierName ( "GetObjectForIUnknown" ) ) ,
956+ ArgumentList ( ) . AddArguments (
957+ Argument ( CastExpression ( ParseName ( "nint" ) , nativeLocal ) ) ) ) ) ,
958+ LiteralExpression ( SyntaxKind . NullLiteralExpression ) ) ) ;
959+
960+ trailingStatements . Add (
961+ IfStatement (
962+ IdentifierName ( paramPresent ) ,
963+ Block ( ) . AddStatements (
964+ ExpressionStatement (
965+ AssignmentExpression (
966+ SyntaxKind . SimpleAssignmentExpression ,
967+ origName ,
968+ toManagedExpression ) ) ) ) ) ;
969+ }
970+
971+ // Release the native pointers we have refs on.
972+ finallyStatements . Add ( this . COMFreeNativePointerStatement ( nativeLocal , interfaceTypeSyntax ) ) ;
973+
974+ // If it's an in parameter, pass the native local as the argument.
975+ arguments [ paramIndex ] = arguments [ paramIndex ] . WithExpression ( ConditionalExpression ( IdentifierName ( paramPresent ) , PrefixUnaryExpression ( SyntaxKind . AddressOfExpression , nativeLocal ) , LiteralExpression ( SyntaxKind . NullLiteralExpression ) ) ) ;
976+ }
844977
845978 bool TryHandleCountParam ( TypeSyntax elementType , bool nullableSource )
846979 {
@@ -1236,6 +1369,38 @@ bool TryHandleCountParam(TypeSyntax elementType, bool nullableSource)
12361369 return helper ;
12371370 }
12381371
1372+ private StatementSyntax COMFreeNativePointerStatement ( ExpressionSyntax nativePointer , TypeSyntax interfaceTypeSyntax )
1373+ {
1374+ if ( this . useSourceGenerators )
1375+ {
1376+ // Release the nativeLocal via ComInterfaceMarshaller.Free.
1377+ return
1378+ ExpressionStatement (
1379+ InvocationExpression (
1380+ MemberAccessExpression (
1381+ SyntaxKind . SimpleMemberAccessExpression ,
1382+ GenericName ( $ "global::System.Runtime.InteropServices.Marshalling.ComInterfaceMarshaller", TypeArgumentList ( ) . AddArguments ( interfaceTypeSyntax ) ) ,
1383+ IdentifierName ( "Free" ) ) ,
1384+ ArgumentList ( ) . AddArguments ( Argument ( nativePointer ) ) ) ) ;
1385+ }
1386+ else
1387+ {
1388+ // Finally, release the nativeLocal via Marshal.Release.
1389+ return
1390+ IfStatement (
1391+ BinaryExpression ( SyntaxKind . NotEqualsExpression , nativePointer , LiteralExpression ( SyntaxKind . NullLiteralExpression ) ) ,
1392+ ExpressionStatement (
1393+ InvocationExpression (
1394+ MemberAccessExpression (
1395+ SyntaxKind . SimpleMemberAccessExpression ,
1396+ ParseTypeName ( "global::System.Runtime.InteropServices.Marshal" ) ,
1397+ IdentifierName ( "Release" ) ) ,
1398+ ArgumentList ( ) . AddArguments ( Argument (
1399+ CastExpression ( ParseName ( "nint" ) , nativePointer ) ) ) ) ) )
1400+ . WithCloseParenToken ( TokenWithLineFeed ( SyntaxKind . CloseParenToken ) ) ;
1401+ }
1402+ }
1403+
12391404 private class FriendlyMethodBookkeeping
12401405 {
12411406 public int NumSpanByteParameters { get ; set ; } = 0 ;
0 commit comments