Skip to content
This repository was archived by the owner on Jan 23, 2023. It is now read-only.

Commit d7c967a

Browse files
Don't rely on the built-in marshaller during activation. (#28073)
Relying on the built-in marshaller leverages the Class interface approach which doesn't work for some interface types (e.g. interfaces inheriting from IDispatch). This approach is wrong regardless of why given that COM dictates the returned value must be properly cast the specific interface vtable. Updated tests so they would have found this issue.
1 parent 4cf9136 commit d7c967a

File tree

2 files changed

+43
-31
lines changed

2 files changed

+43
-31
lines changed

src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ public interface IClassFactory
2525
void CreateInstance(
2626
[MarshalAs(UnmanagedType.Interface)] object? pUnkOuter,
2727
ref Guid riid,
28-
[MarshalAs(UnmanagedType.Interface)] out object? ppvObject);
28+
out IntPtr ppvObject);
2929

3030
void LockServer([MarshalAs(UnmanagedType.Bool)] bool fLock);
3131
}
@@ -51,7 +51,7 @@ internal interface IClassFactory2 : IClassFactory
5151
new void CreateInstance(
5252
[MarshalAs(UnmanagedType.Interface)] object? pUnkOuter,
5353
ref Guid riid,
54-
[MarshalAs(UnmanagedType.Interface)] out object? ppvObject);
54+
out IntPtr ppvObject);
5555

5656
new void LockServer([MarshalAs(UnmanagedType.Bool)] bool fLock);
5757

@@ -66,7 +66,7 @@ void CreateInstanceLic(
6666
[MarshalAs(UnmanagedType.Interface)] object? pUnkReserved,
6767
ref Guid riid,
6868
[MarshalAs(UnmanagedType.BStr)] string bstrKey,
69-
[MarshalAs(UnmanagedType.Interface)] out object ppvObject);
69+
out IntPtr ppvObject);
7070
}
7171

7272
[StructLayout(LayoutKind.Sequential)]
@@ -424,27 +424,31 @@ public static Type GetValidatedInterfaceType(Type classType, ref Guid riid, obje
424424
throw new InvalidCastException();
425425
}
426426

427-
public static void ValidateObjectIsMarshallableAsInterface(object obj, Type interfaceType)
427+
public static IntPtr GetObjectAsInterface(object obj, Type interfaceType)
428428
{
429-
// If the requested "interface type" is type object then return
430-
// because type object is always marshallable.
429+
// If the requested "interface type" is type object then return as IUnknown
431430
if (interfaceType == typeof(object))
432431
{
433-
return;
432+
return Marshal.GetIUnknownForObject(obj);
434433
}
435434

436435
Debug.Assert(interfaceType.IsInterface);
437436

438-
// The intent of this call is to validate the interface can be
437+
// The intent of this call is to get AND validate the interface can be
439438
// marshalled to native code. An exception will be thrown if the
440439
// type is unable to be marshalled to native code.
441440
// Scenarios where this is relevant:
442441
// - Interfaces that use Generics
443442
// - Interfaces that define implementation
444-
IntPtr ptr = Marshal.GetComInterfaceForObject(obj, interfaceType, CustomQueryInterfaceMode.Ignore);
443+
IntPtr interfaceMaybe = Marshal.GetComInterfaceForObject(obj, interfaceType, CustomQueryInterfaceMode.Ignore);
445444

446-
// Decrement the above 'Marshal.GetComInterfaceForObject()'
447-
Marshal.Release(ptr);
445+
if (interfaceMaybe == IntPtr.Zero)
446+
{
447+
// E_NOINTERFACE
448+
throw new InvalidCastException();
449+
}
450+
451+
return interfaceMaybe;
448452
}
449453

450454
public static object CreateAggregatedObject(object pUnkOuter, object comObject)
@@ -467,17 +471,17 @@ public static object CreateAggregatedObject(object pUnkOuter, object comObject)
467471
public void CreateInstance(
468472
[MarshalAs(UnmanagedType.Interface)] object? pUnkOuter,
469473
ref Guid riid,
470-
[MarshalAs(UnmanagedType.Interface)] out object? ppvObject)
474+
out IntPtr ppvObject)
471475
{
472476
Type interfaceType = BasicClassFactory.GetValidatedInterfaceType(_classType, ref riid, pUnkOuter);
473477

474-
ppvObject = Activator.CreateInstance(_classType)!;
478+
object obj = Activator.CreateInstance(_classType)!;
475479
if (pUnkOuter != null)
476480
{
477-
ppvObject = BasicClassFactory.CreateAggregatedObject(pUnkOuter, ppvObject);
481+
obj = BasicClassFactory.CreateAggregatedObject(pUnkOuter, obj);
478482
}
479483

480-
BasicClassFactory.ValidateObjectIsMarshallableAsInterface(ppvObject, interfaceType);
484+
ppvObject = BasicClassFactory.GetObjectAsInterface(obj, interfaceType);
481485
}
482486

483487
public void LockServer([MarshalAs(UnmanagedType.Bool)] bool fLock)
@@ -502,7 +506,7 @@ public LicenseClassFactory(Guid clsid, Type classType)
502506
public void CreateInstance(
503507
[MarshalAs(UnmanagedType.Interface)] object? pUnkOuter,
504508
ref Guid riid,
505-
[MarshalAs(UnmanagedType.Interface)] out object? ppvObject)
509+
out IntPtr ppvObject)
506510
{
507511
CreateInstanceInner(pUnkOuter, ref riid, key: null, isDesignTime: true, out ppvObject);
508512
}
@@ -535,7 +539,7 @@ public void CreateInstanceLic(
535539
[MarshalAs(UnmanagedType.Interface)] object? pUnkReserved,
536540
ref Guid riid,
537541
[MarshalAs(UnmanagedType.BStr)] string bstrKey,
538-
[MarshalAs(UnmanagedType.Interface)] out object ppvObject)
542+
out IntPtr ppvObject)
539543
{
540544
Debug.Assert(pUnkReserved == null);
541545
CreateInstanceInner(pUnkOuter, ref riid, bstrKey, isDesignTime: false, out ppvObject);
@@ -546,17 +550,17 @@ private void CreateInstanceInner(
546550
ref Guid riid,
547551
string? key,
548552
bool isDesignTime,
549-
out object ppvObject)
553+
out IntPtr ppvObject)
550554
{
551555
Type interfaceType = BasicClassFactory.GetValidatedInterfaceType(_classType, ref riid, pUnkOuter);
552556

553-
ppvObject = _licenseProxy.AllocateAndValidateLicense(_classType, key, isDesignTime);
557+
object obj = _licenseProxy.AllocateAndValidateLicense(_classType, key, isDesignTime);
554558
if (pUnkOuter != null)
555559
{
556-
ppvObject = BasicClassFactory.CreateAggregatedObject(pUnkOuter, ppvObject);
560+
obj = BasicClassFactory.CreateAggregatedObject(pUnkOuter, obj);
557561
}
558562

559-
BasicClassFactory.ValidateObjectIsMarshallableAsInterface(ppvObject, interfaceType);
563+
ppvObject = BasicClassFactory.GetObjectAsInterface(obj, interfaceType);
560564
}
561565
}
562566
}

tests/src/Interop/COM/Activator/Program.cs

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,11 @@ static void ValidateAssemblyIsolation()
106106

107107
var factory = (IClassFactory)ComActivator.GetClassFactoryForType(cxt);
108108

109-
object svr;
110-
factory.CreateInstance(null, ref iid, out svr);
111-
typeCFromAssemblyA = (Type)((IGetTypeFromC)svr).GetTypeFromC();
109+
IntPtr svrRaw;
110+
factory.CreateInstance(null, ref iid, out svrRaw);
111+
var svr = (IGetTypeFromC)Marshal.GetObjectForIUnknown(svrRaw);
112+
Marshal.Release(svrRaw);
113+
typeCFromAssemblyA = (Type)svr.GetTypeFromC();
112114
}
113115

114116
using (HostPolicyMock.Mock_corehost_resolve_component_dependencies(
@@ -128,9 +130,11 @@ static void ValidateAssemblyIsolation()
128130

129131
var factory = (IClassFactory)ComActivator.GetClassFactoryForType(cxt);
130132

131-
object svr;
132-
factory.CreateInstance(null, ref iid, out svr);
133-
typeCFromAssemblyB = (Type)((IGetTypeFromC)svr).GetTypeFromC();
133+
IntPtr svrRaw;
134+
factory.CreateInstance(null, ref iid, out svrRaw);
135+
var svr = (IGetTypeFromC)Marshal.GetObjectForIUnknown(svrRaw);
136+
Marshal.Release(svrRaw);
137+
typeCFromAssemblyB = (Type)svr.GetTypeFromC();
134138
}
135139

136140
Assert.AreNotEqual(typeCFromAssemblyA, typeCFromAssemblyB, "Types should be from different AssemblyLoadContexts");
@@ -172,8 +176,10 @@ static void ValidateUserDefinedRegistrationCallbacks()
172176

173177
var factory = (IClassFactory)ComActivator.GetClassFactoryForType(cxt);
174178

175-
object svr;
176-
factory.CreateInstance(null, ref iid, out svr);
179+
IntPtr svrRaw;
180+
factory.CreateInstance(null, ref iid, out svrRaw);
181+
var svr = Marshal.GetObjectForIUnknown(svrRaw);
182+
Marshal.Release(svrRaw);
177183

178184
var inst = (IValidateRegistrationCallbacks)svr;
179185
Assert.IsFalse(inst.DidRegister());
@@ -209,8 +215,10 @@ static void ValidateUserDefinedRegistrationCallbacks()
209215

210216
var factory = (IClassFactory)ComActivator.GetClassFactoryForType(cxt);
211217

212-
object svr;
213-
factory.CreateInstance(null, ref iid, out svr);
218+
IntPtr svrRaw;
219+
factory.CreateInstance(null, ref iid, out svrRaw);
220+
var svr = Marshal.GetObjectForIUnknown(svrRaw);
221+
Marshal.Release(svrRaw);
214222

215223
var inst = (IValidateRegistrationCallbacks)svr;
216224
cxt.InterfaceId = Guid.Empty;

0 commit comments

Comments
 (0)