Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 115 additions & 17 deletions Lib/test/test_ctypes/test_win32_com_foreign_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
raise unittest.SkipTest("Windows-specific test")


from _ctypes import COMError
from _ctypes import COMError, CopyComPointer
from ctypes import HRESULT


Expand Down Expand Up @@ -78,6 +78,19 @@ def is_equal_guid(guid1, guid2):
)


def create_shelllink_persist(typ):
ppst = typ()
# https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-cocreateinstance
ole32.CoCreateInstance(
byref(CLSID_ShellLink),
None,
CLSCTX_SERVER,
byref(IID_IPersist),
byref(ppst),
)
return ppst


class ForeignFunctionsThatWillCallComMethodsTests(unittest.TestCase):
def setUp(self):
# https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-coinitializeex
Expand All @@ -88,19 +101,6 @@ def tearDown(self):
ole32.CoUninitialize()
gc.collect()

@staticmethod
def create_shelllink_persist(typ):
ppst = typ()
# https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-cocreateinstance
ole32.CoCreateInstance(
byref(CLSID_ShellLink),
None,
CLSCTX_SERVER,
byref(IID_IPersist),
byref(ppst),
)
return ppst

def test_without_paramflags_and_iid(self):
class IUnknown(c_void_p):
QueryInterface = proto_query_interface()
Expand All @@ -110,7 +110,7 @@ class IUnknown(c_void_p):
class IPersist(IUnknown):
GetClassID = proto_get_class_id()

ppst = self.create_shelllink_persist(IPersist)
ppst = create_shelllink_persist(IPersist)

clsid = GUID()
hr_getclsid = ppst.GetClassID(byref(clsid))
Expand Down Expand Up @@ -142,7 +142,7 @@ class IUnknown(c_void_p):
class IPersist(IUnknown):
GetClassID = proto_get_class_id(((OUT, "pClassID"),))

ppst = self.create_shelllink_persist(IPersist)
ppst = create_shelllink_persist(IPersist)

clsid = ppst.GetClassID()
self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid))
Expand All @@ -167,7 +167,7 @@ class IUnknown(c_void_p):
class IPersist(IUnknown):
GetClassID = proto_get_class_id(((OUT, "pClassID"),), IID_IPersist)

ppst = self.create_shelllink_persist(IPersist)
ppst = create_shelllink_persist(IPersist)

clsid = ppst.GetClassID()
self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid))
Expand All @@ -184,5 +184,103 @@ class IPersist(IUnknown):
self.assertEqual(0, ppst.Release())


class CopyComPointerTests(unittest.TestCase):
def setUp(self):
ole32.CoInitializeEx(None, COINIT_APARTMENTTHREADED)

class IUnknown(c_void_p):
QueryInterface = proto_query_interface(None, IID_IUnknown)
AddRef = proto_add_ref()
Release = proto_release()

class IPersist(IUnknown):
GetClassID = proto_get_class_id(((OUT, "pClassID"),), IID_IPersist)

self.IUnknown = IUnknown
self.IPersist = IPersist

def tearDown(self):
ole32.CoUninitialize()
gc.collect()

def test_both_are_null(self):
src = self.IPersist()
dst = self.IPersist()

hr = CopyComPointer(src, byref(dst))

self.assertEqual(S_OK, hr)

self.assertIsNone(src.value)
self.assertIsNone(dst.value)

def test_src_is_nonnull_and_dest_is_null(self):
# The reference count of the COM pointer created by `CoCreateInstance`
# is initially 1.
src = create_shelllink_persist(self.IPersist)
dst = self.IPersist()

# `CopyComPointer` calls `AddRef` explicitly in the C implementation.
# The refcount of `src` is incremented from 1 to 2 here.
hr = CopyComPointer(src, byref(dst))

self.assertEqual(S_OK, hr)
self.assertEqual(src.value, dst.value)

# This indicates that the refcount was 2 before the `Release` call.
self.assertEqual(1, src.Release())

clsid = dst.GetClassID()
self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid))

self.assertEqual(0, dst.Release())

def test_src_is_null_and_dest_is_nonnull(self):
src = self.IPersist()
dst_orig = create_shelllink_persist(self.IPersist)
dst = self.IPersist()
CopyComPointer(dst_orig, byref(dst))
self.assertEqual(1, dst_orig.Release())

clsid = dst.GetClassID()
self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid))

# This does NOT affects the refcount of `dst_orig`.
hr = CopyComPointer(src, byref(dst))

self.assertEqual(S_OK, hr)
self.assertIsNone(dst.value)

with self.assertRaises(ValueError):
dst.GetClassID() # NULL COM pointer access

# This indicates that the refcount was 1 before the `Release` call.
self.assertEqual(0, dst_orig.Release())

def test_both_are_nonnull(self):
src = create_shelllink_persist(self.IPersist)
dst_orig = create_shelllink_persist(self.IPersist)
dst = self.IPersist()
CopyComPointer(dst_orig, byref(dst))
self.assertEqual(1, dst_orig.Release())

self.assertEqual(dst.value, dst_orig.value)
self.assertNotEqual(src.value, dst.value)

hr = CopyComPointer(src, byref(dst))

self.assertEqual(S_OK, hr)
self.assertEqual(src.value, dst.value)
self.assertNotEqual(dst.value, dst_orig.value)

self.assertEqual(1, src.Release())

clsid = dst.GetClassID()
self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid))

self.assertEqual(0, dst.Release())
self.assertEqual(0, dst_orig.Release())


if __name__ == '__main__':
unittest.main()
Loading