Skip to content

Commit 8158b01

Browse files
authored
Add an OLE proxy mock (#13115)
This adds manual CCW creation for emulating OLE proxy behavior for unit tests. This will allow testing closer to actual runtime behavior. MockOleServices now uses the proxy.
1 parent edc442e commit 8158b01

File tree

6 files changed

+366
-23
lines changed

6 files changed

+366
-23
lines changed
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System.Runtime.CompilerServices;
5+
using System.Runtime.InteropServices;
6+
7+
namespace Windows.Win32.System.Com;
8+
9+
internal unsafe partial struct IUnknown
10+
{
11+
/// <summary>
12+
/// Manual COM Callable Wrapper for <see cref="IUnknown"/>.
13+
/// </summary>
14+
/// <remarks>
15+
/// <para>
16+
/// This is for test and debug scenarios only. It should not be used directly in the product.
17+
/// </para>
18+
/// </remarks>
19+
/// <devdoc>
20+
/// This is a simplified version of what <see cref="ComWrappers"/> does. It is useful when we want to manage
21+
/// our own <see cref="IUnknown.QueryInterface(Guid*, void**)"/> handling for debugging and testing purposes.
22+
/// </devdoc>
23+
internal static class CCW
24+
{
25+
private static readonly Vtbl* s_vtable = AllocateVTable();
26+
27+
private static unsafe Vtbl* AllocateVTable()
28+
{
29+
// Allocate and create a singular VTable for this type projection.
30+
Vtbl* vtable = (Vtbl*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(CCW), sizeof(Vtbl));
31+
32+
// IUnknown
33+
vtable->QueryInterface_1 = &QueryInterface;
34+
vtable->AddRef_2 = &AddRef;
35+
vtable->Release_3 = &Release;
36+
return vtable;
37+
}
38+
39+
/// <inheritdoc cref="CCW"/>
40+
/// <summary>
41+
/// Creates a manual COM Callable Wrapper for the given <paramref name="object"/>.
42+
/// </summary>
43+
public static unsafe IUnknown* Create(Interface @object) =>
44+
(IUnknown*)Lifetime<Vtbl, Interface>.Allocate(@object, s_vtable);
45+
46+
[UnmanagedCallersOnly(CallConvs = [typeof(CallConvStdcall)])]
47+
private static unsafe HRESULT QueryInterface(IUnknown* @this, Guid* iid, void* ppObject)
48+
{
49+
if (iid is null || ppObject is null)
50+
{
51+
return HRESULT.E_POINTER;
52+
}
53+
54+
if (iid->Equals(IID_Guid))
55+
{
56+
ppObject = @this;
57+
}
58+
else
59+
{
60+
ppObject = null;
61+
return HRESULT.E_NOINTERFACE;
62+
}
63+
64+
Lifetime<Vtbl, Interface>.AddRef(@this);
65+
return HRESULT.S_OK;
66+
}
67+
68+
[UnmanagedCallersOnly(CallConvs = [typeof(CallConvStdcall)])]
69+
private static unsafe uint AddRef(IUnknown* @this) => Lifetime<Vtbl, Interface>.AddRef(@this);
70+
71+
[UnmanagedCallersOnly(CallConvs = [typeof(CallConvStdcall)])]
72+
private static unsafe uint Release(IUnknown* @this) => Lifetime<Vtbl, Interface>.Release(@this);
73+
}
74+
}

src/System.Private.Windows.Core/src/Windows/Win32/System/Com/IUnknown.cs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,8 @@ internal unsafe partial struct IUnknown
88
// https://github.com/microsoft/CsWin32/issues/724
99
internal interface Interface
1010
{
11-
// Can't do this yet as the generated members aren't public. Creating it anyway to help constrain our
12-
// helpers a bit more.
13-
// https://github.com/microsoft/CsWin32/issues/723
14-
15-
// internal unsafe HRESULT QueryInterface(Guid* riid, void** ppvObject);
16-
// internal uint AddRef();
17-
// internal uint Release();
11+
internal unsafe HRESULT QueryInterface(Guid* riid, void** ppvObject);
12+
internal uint AddRef();
13+
internal uint Release();
1814
}
1915
}
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System.Runtime.InteropServices;
5+
6+
namespace Windows.Win32.System.Com;
7+
8+
/// <summary>
9+
/// Struct that handles managed object COM projection lifetime management.
10+
/// </summary>
11+
/// <typeparam name="TVTable">
12+
/// Struct that repesents the VTable for a COM pointer.
13+
/// </typeparam>
14+
/// <typeparam name="TObject">
15+
/// The type of object being projected.
16+
/// </typeparam>
17+
internal unsafe struct Lifetime<TVTable, TObject> where TVTable : unmanaged
18+
{
19+
private TVTable* _vtable;
20+
private void* _handle;
21+
private uint _refCount;
22+
23+
public static unsafe uint AddRef(void* @this) =>
24+
Interlocked.Increment(ref ((Lifetime<TVTable, TObject>*)@this)->_refCount);
25+
26+
public static unsafe uint Release(void* @this)
27+
{
28+
var lifetime = (Lifetime<TVTable, TObject>*)@this;
29+
Debug.Assert(lifetime->_refCount > 0);
30+
uint count = Interlocked.Decrement(ref lifetime->_refCount);
31+
if (count == 0)
32+
{
33+
GCHandle.FromIntPtr((nint)lifetime->_handle).Free();
34+
Marshal.FreeCoTaskMem((nint)lifetime);
35+
}
36+
37+
return count;
38+
}
39+
40+
/// <summary>
41+
/// Allocate a lifetime wrapper for the given <paramref name="object"/> with the given
42+
/// <paramref name="vtable"/>.
43+
/// </summary>
44+
/// <remarks>
45+
/// <para>
46+
/// This creates a <see cref="GCHandle"/> to root the <paramref name="object"/> until ref
47+
/// counting has gone to zero.
48+
/// </para>
49+
/// <para>
50+
/// The <paramref name="vtable"/> should be fixed, typically as a static. Com calls always
51+
/// include the "this" pointer as the first argument.
52+
/// </para>
53+
/// </remarks>
54+
public static unsafe Lifetime<TVTable, TObject>* Allocate(TObject @object, TVTable* vtable)
55+
{
56+
var wrapper = (Lifetime<TVTable, TObject>*)Marshal.AllocCoTaskMem(sizeof(Lifetime<TVTable, TObject>));
57+
58+
// Create the wrapper instance.
59+
wrapper->_vtable = vtable;
60+
wrapper->_handle = (void*)GCHandle.ToIntPtr(GCHandle.Alloc(@object));
61+
wrapper->_refCount = 1;
62+
63+
return wrapper;
64+
}
65+
66+
/// <summary>
67+
/// Get the object associated with this lifetime.
68+
/// </summary>
69+
/// <param name="this">
70+
/// The passed back "this" pointer that originally came from <see cref="Allocate(TObject, TVTable*)"/>.
71+
/// </param>
72+
/// <returns>The object associated with this lifetime, if any.</returns>
73+
/// <exception cref="InvalidOperationException">The handle was freed.</exception>
74+
public static TObject? GetObject(void* @this)
75+
{
76+
var lifetime = (Lifetime<TVTable, TObject>*)@this;
77+
return (TObject?)GCHandle.FromIntPtr((IntPtr)lifetime->_handle).Target;
78+
}
79+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System.Runtime.CompilerServices;
5+
using System.Runtime.InteropServices;
6+
using Windows.Win32.Foundation;
7+
using Windows.Win32.System.Com;
8+
using Lifetime = Windows.Win32.System.Com.Lifetime<Windows.Win32.System.Com.IDataObject.Vtbl, System.Private.Windows.Ole.DataObjectProxy>;
9+
10+
namespace System.Private.Windows.Ole;
11+
12+
/// <summary>
13+
/// Emulates an OLE proxy for unit testing purposes.
14+
/// </summary>
15+
internal unsafe class DataObjectProxy : IDataObject.Interface, IDisposable
16+
{
17+
// Agile for ensured cleanup
18+
private readonly AgileComPointer<IDataObject> _agileOriginal;
19+
private readonly IDataObject* _original;
20+
private readonly AgileComPointer<IDataObject> _agileProxy;
21+
public IDataObject* Proxy { get; }
22+
23+
public DataObjectProxy(IDataObject* original)
24+
{
25+
// Don't track disposal, we depend on finalization for testing.
26+
27+
_original = original;
28+
_agileOriginal = new(
29+
#if DEBUG
30+
original, takeOwnership: true, trackDisposal: false
31+
#else
32+
original, takeOwnership: true
33+
#endif
34+
);
35+
36+
Proxy = CCW.Create(this);
37+
38+
_agileProxy = new(
39+
#if DEBUG
40+
Proxy, takeOwnership: true, trackDisposal: false
41+
#else
42+
Proxy, takeOwnership: true
43+
#endif
44+
);
45+
}
46+
47+
public HRESULT GetData(FORMATETC* pformatetcIn, STGMEDIUM* pmedium) => _original->GetData(pformatetcIn, pmedium);
48+
public HRESULT GetDataHere(FORMATETC* pformatetc, STGMEDIUM* pmedium) => _original->GetDataHere(pformatetc, pmedium);
49+
public HRESULT QueryGetData(FORMATETC* pformatetc) => _original->QueryGetData(pformatetc);
50+
public HRESULT GetCanonicalFormatEtc(FORMATETC* pformatectIn, FORMATETC* pformatetcOut) => _original->GetCanonicalFormatEtc(pformatectIn, pformatetcOut);
51+
public HRESULT SetData(FORMATETC* pformatetc, STGMEDIUM* pmedium, BOOL fRelease) => _original->SetData(pformatetc, pmedium, fRelease);
52+
public HRESULT EnumFormatEtc(uint dwDirection, IEnumFORMATETC** ppenumFormatEtc) => _original->EnumFormatEtc(dwDirection, ppenumFormatEtc);
53+
public HRESULT DAdvise(FORMATETC* pformatetc, uint advf, IAdviseSink* pAdvSink, uint* pdwConnection) => _original->DAdvise(pformatetc, advf, pAdvSink, pdwConnection);
54+
public HRESULT DUnadvise(uint dwConnection) => _original->DUnadvise(dwConnection);
55+
public HRESULT EnumDAdvise(IEnumSTATDATA** ppenumAdvise) => _original->EnumDAdvise(ppenumAdvise);
56+
57+
public void Dispose()
58+
{
59+
_agileOriginal.Dispose();
60+
_agileProxy.Dispose();
61+
}
62+
63+
internal static class CCW
64+
{
65+
private static readonly IDataObject.Vtbl* s_vtable = AllocateVTable();
66+
67+
private static unsafe IDataObject.Vtbl* AllocateVTable()
68+
{
69+
// Allocate and create a singular VTable for this type projection.
70+
IDataObject.Vtbl* vtable = (IDataObject.Vtbl*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(CCW), sizeof(IDataObject.Vtbl));
71+
72+
// IUnknown
73+
vtable->QueryInterface_1 = &QueryInterface;
74+
vtable->AddRef_2 = &AddRef;
75+
vtable->Release_3 = &Release;
76+
vtable->GetData_4 = &GetData;
77+
vtable->GetDataHere_5 = &GetDataHere;
78+
vtable->QueryGetData_6 = &QueryGetData;
79+
vtable->GetCanonicalFormatEtc_7 = &GetCanonicalFormatEtc;
80+
vtable->SetData_8 = &SetData;
81+
vtable->EnumFormatEtc_9 = &EnumFormatEtc;
82+
vtable->DAdvise_10 = &DAdvise;
83+
vtable->DUnadvise_11 = &DUnadvise;
84+
vtable->EnumDAdvise_12 = &EnumDAdvise;
85+
return vtable;
86+
}
87+
88+
/// <summary>
89+
/// Creates a manual COM Callable Wrapper for the given <paramref name="object"/>.
90+
/// </summary>
91+
public static unsafe IDataObject* Create(DataObjectProxy @object) =>
92+
(IDataObject*)Lifetime.Allocate(@object, s_vtable);
93+
94+
[UnmanagedCallersOnly(CallConvs = [typeof(CallConvStdcall)])]
95+
private static unsafe HRESULT QueryInterface(IDataObject* @this, Guid* iid, void** ppObject)
96+
{
97+
if (iid is null || ppObject is null)
98+
{
99+
return HRESULT.E_POINTER;
100+
}
101+
102+
if (iid->Equals(IDataObject.IID_Guid) || iid->Equals(IUnknown.IID_Guid))
103+
{
104+
*ppObject = @this;
105+
Lifetime.AddRef(@this);
106+
return HRESULT.S_OK;
107+
}
108+
109+
*ppObject = null;
110+
DataObjectProxy? proxy = Lifetime.GetObject(@this);
111+
if (proxy is null)
112+
{
113+
return HRESULT.E_NOINTERFACE;
114+
}
115+
116+
// Unwrap our "proxy" object by calling the the original object. This should roughly match the
117+
// OLE proxy behavior which returns it's own pointer for the IID_IUnknown and IID_IDataObject interfaces.
118+
return proxy._original->QueryInterface(iid, ppObject);
119+
}
120+
121+
[UnmanagedCallersOnly(CallConvs = [typeof(CallConvStdcall)])]
122+
private static unsafe uint AddRef(IDataObject* @this) => Lifetime.AddRef(@this);
123+
124+
[UnmanagedCallersOnly(CallConvs = [typeof(CallConvStdcall)])]
125+
private static unsafe uint Release(IDataObject* @this) => Lifetime.Release(@this);
126+
127+
[UnmanagedCallersOnly(CallConvs = [typeof(CallConvStdcall)])]
128+
private static unsafe HRESULT GetData(IDataObject* @this, FORMATETC* pFormatetc, STGMEDIUM* pMedium) =>
129+
Lifetime.GetObject(@this)?.GetData(pFormatetc, pMedium) ?? HRESULT.COR_E_OBJECTDISPOSED;
130+
131+
[UnmanagedCallersOnly(CallConvs = [typeof(CallConvStdcall)])]
132+
private static unsafe HRESULT GetDataHere(IDataObject* @this, FORMATETC* pFormatetc, STGMEDIUM* pMedium) =>
133+
Lifetime.GetObject(@this)?.GetDataHere(pFormatetc, pMedium) ?? HRESULT.COR_E_OBJECTDISPOSED;
134+
135+
[UnmanagedCallersOnly(CallConvs = [typeof(CallConvStdcall)])]
136+
private static unsafe HRESULT QueryGetData(IDataObject* @this, FORMATETC* pFormatetc) =>
137+
Lifetime.GetObject(@this)?.QueryGetData(pFormatetc) ?? HRESULT.COR_E_OBJECTDISPOSED;
138+
139+
[UnmanagedCallersOnly(CallConvs = [typeof(CallConvStdcall)])]
140+
private static unsafe HRESULT GetCanonicalFormatEtc(IDataObject* @this, FORMATETC* pFormatetcIn, FORMATETC* pFormatetcOut) =>
141+
Lifetime.GetObject(@this)?.GetCanonicalFormatEtc(pFormatetcIn, pFormatetcOut) ?? HRESULT.COR_E_OBJECTDISPOSED;
142+
143+
[UnmanagedCallersOnly(CallConvs = [typeof(CallConvStdcall)])]
144+
private static unsafe HRESULT SetData(IDataObject* @this, FORMATETC* pFormatetc, STGMEDIUM* pMedium, BOOL fRelease) =>
145+
Lifetime.GetObject(@this)?.SetData(pFormatetc, pMedium, fRelease) ?? HRESULT.COR_E_OBJECTDISPOSED;
146+
147+
[UnmanagedCallersOnly(CallConvs = [typeof(CallConvStdcall)])]
148+
private static unsafe HRESULT EnumFormatEtc(IDataObject* @this, uint dwDirection, IEnumFORMATETC** ppEnumFormatEtc) =>
149+
Lifetime.GetObject(@this)?.EnumFormatEtc(dwDirection, ppEnumFormatEtc) ?? HRESULT.COR_E_OBJECTDISPOSED;
150+
151+
[UnmanagedCallersOnly(CallConvs = [typeof(CallConvStdcall)])]
152+
private static unsafe HRESULT DAdvise(IDataObject* @this, FORMATETC* pFormatetc, uint advf, IAdviseSink* pAdvSink, uint* pdwConnection) =>
153+
Lifetime.GetObject(@this)?.DAdvise(pFormatetc, advf, pAdvSink, pdwConnection) ?? HRESULT.COR_E_OBJECTDISPOSED;
154+
155+
[UnmanagedCallersOnly(CallConvs = [typeof(CallConvStdcall)])]
156+
private static unsafe HRESULT DUnadvise(IDataObject* @this, uint dwConnection) =>
157+
Lifetime.GetObject(@this)?.DUnadvise(dwConnection) ?? HRESULT.COR_E_OBJECTDISPOSED;
158+
159+
[UnmanagedCallersOnly(CallConvs = [typeof(CallConvStdcall)])]
160+
private static unsafe HRESULT EnumDAdvise(IDataObject* @this, IEnumSTATDATA** ppEnumAdvise) =>
161+
Lifetime.GetObject(@this)?.EnumDAdvise(ppEnumAdvise) ?? HRESULT.COR_E_OBJECTDISPOSED;
162+
}
163+
}

0 commit comments

Comments
 (0)