|  | 
|  | 1 | +import ctypes | 
|  | 2 | +import gc | 
|  | 3 | +import sys | 
|  | 4 | +import unittest | 
|  | 5 | +from ctypes import POINTER, byref, c_void_p | 
|  | 6 | +from ctypes.wintypes import BYTE, DWORD, WORD | 
|  | 7 | + | 
|  | 8 | +if sys.platform != "win32": | 
|  | 9 | +    raise unittest.SkipTest("Windows-specific test") | 
|  | 10 | + | 
|  | 11 | + | 
|  | 12 | +from _ctypes import COMError | 
|  | 13 | +from ctypes import HRESULT | 
|  | 14 | + | 
|  | 15 | + | 
|  | 16 | +COINIT_APARTMENTTHREADED = 0x2 | 
|  | 17 | +CLSCTX_SERVER = 5 | 
|  | 18 | +S_OK = 0 | 
|  | 19 | +OUT = 2 | 
|  | 20 | +TRUE = 1 | 
|  | 21 | +E_NOINTERFACE = -2147467262 | 
|  | 22 | + | 
|  | 23 | + | 
|  | 24 | +class GUID(ctypes.Structure): | 
|  | 25 | +    # https://learn.microsoft.com/en-us/windows/win32/api/guiddef/ns-guiddef-guid | 
|  | 26 | +    _fields_ = [ | 
|  | 27 | +        ("Data1", DWORD), | 
|  | 28 | +        ("Data2", WORD), | 
|  | 29 | +        ("Data3", WORD), | 
|  | 30 | +        ("Data4", BYTE * 8), | 
|  | 31 | +    ] | 
|  | 32 | + | 
|  | 33 | + | 
|  | 34 | +def create_proto_com_method(name, index, restype, *argtypes): | 
|  | 35 | +    proto = ctypes.WINFUNCTYPE(restype, *argtypes) | 
|  | 36 | + | 
|  | 37 | +    def make_method(*args): | 
|  | 38 | +        foreign_func = proto(index, name, *args) | 
|  | 39 | + | 
|  | 40 | +        def call(self, *args, **kwargs): | 
|  | 41 | +            return foreign_func(self, *args, **kwargs) | 
|  | 42 | + | 
|  | 43 | +        return call | 
|  | 44 | + | 
|  | 45 | +    return make_method | 
|  | 46 | + | 
|  | 47 | + | 
|  | 48 | +def create_guid(name): | 
|  | 49 | +    guid = GUID() | 
|  | 50 | +    # https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-clsidfromstring | 
|  | 51 | +    ole32.CLSIDFromString(name, byref(guid)) | 
|  | 52 | +    return guid | 
|  | 53 | + | 
|  | 54 | + | 
|  | 55 | +def is_equal_guid(guid1, guid2): | 
|  | 56 | +    # https://learn.microsoft.com/en-us/windows/win32/api/objbase/nf-objbase-isequalguid | 
|  | 57 | +    return ole32.IsEqualGUID(byref(guid1), byref(guid2)) | 
|  | 58 | + | 
|  | 59 | + | 
|  | 60 | +ole32 = ctypes.oledll.ole32 | 
|  | 61 | + | 
|  | 62 | +IID_IUnknown = create_guid("{00000000-0000-0000-C000-000000000046}") | 
|  | 63 | +IID_IStream = create_guid("{0000000C-0000-0000-C000-000000000046}") | 
|  | 64 | +IID_IPersist = create_guid("{0000010C-0000-0000-C000-000000000046}") | 
|  | 65 | +CLSID_ShellLink = create_guid("{00021401-0000-0000-C000-000000000046}") | 
|  | 66 | + | 
|  | 67 | +# https://learn.microsoft.com/en-us/windows/win32/api/unknwn/nf-unknwn-iunknown-queryinterface(refiid_void) | 
|  | 68 | +proto_query_interface = create_proto_com_method( | 
|  | 69 | +    "QueryInterface", 0, HRESULT, POINTER(GUID), POINTER(c_void_p) | 
|  | 70 | +) | 
|  | 71 | +# https://learn.microsoft.com/en-us/windows/win32/api/unknwn/nf-unknwn-iunknown-addref | 
|  | 72 | +proto_add_ref = create_proto_com_method("AddRef", 1, ctypes.c_long) | 
|  | 73 | +# https://learn.microsoft.com/en-us/windows/win32/api/unknwn/nf-unknwn-iunknown-release | 
|  | 74 | +proto_release = create_proto_com_method("Release", 2, ctypes.c_long) | 
|  | 75 | +# https://learn.microsoft.com/en-us/windows/win32/api/objidl/nf-objidl-ipersist-getclassid | 
|  | 76 | +proto_get_class_id = create_proto_com_method( | 
|  | 77 | +    "GetClassID", 3, HRESULT, POINTER(GUID) | 
|  | 78 | +) | 
|  | 79 | + | 
|  | 80 | + | 
|  | 81 | +class ForeignFunctionsThatWillCallComMethodsTests(unittest.TestCase): | 
|  | 82 | +    def setUp(self): | 
|  | 83 | +        # https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-coinitializeex | 
|  | 84 | +        ole32.CoInitializeEx(None, COINIT_APARTMENTTHREADED) | 
|  | 85 | + | 
|  | 86 | +    def tearDown(self): | 
|  | 87 | +        # https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-couninitialize | 
|  | 88 | +        ole32.CoUninitialize() | 
|  | 89 | +        gc.collect() | 
|  | 90 | + | 
|  | 91 | +    @staticmethod | 
|  | 92 | +    def create_shelllink_persist(typ): | 
|  | 93 | +        ppst = typ() | 
|  | 94 | +        # https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-cocreateinstance | 
|  | 95 | +        ole32.CoCreateInstance( | 
|  | 96 | +            byref(CLSID_ShellLink), | 
|  | 97 | +            None, | 
|  | 98 | +            CLSCTX_SERVER, | 
|  | 99 | +            byref(IID_IPersist), | 
|  | 100 | +            byref(ppst), | 
|  | 101 | +        ) | 
|  | 102 | +        return ppst | 
|  | 103 | + | 
|  | 104 | +    def test_without_paramflags_and_iid(self): | 
|  | 105 | +        class IUnknown(c_void_p): | 
|  | 106 | +            QueryInterface = proto_query_interface() | 
|  | 107 | +            AddRef = proto_add_ref() | 
|  | 108 | +            Release = proto_release() | 
|  | 109 | + | 
|  | 110 | +        class IPersist(IUnknown): | 
|  | 111 | +            GetClassID = proto_get_class_id() | 
|  | 112 | + | 
|  | 113 | +        ppst = self.create_shelllink_persist(IPersist) | 
|  | 114 | + | 
|  | 115 | +        clsid = GUID() | 
|  | 116 | +        hr_getclsid = ppst.GetClassID(byref(clsid)) | 
|  | 117 | +        self.assertEqual(S_OK, hr_getclsid) | 
|  | 118 | +        self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid)) | 
|  | 119 | + | 
|  | 120 | +        self.assertEqual(2, ppst.AddRef()) | 
|  | 121 | +        self.assertEqual(3, ppst.AddRef()) | 
|  | 122 | + | 
|  | 123 | +        punk = IUnknown() | 
|  | 124 | +        hr_qi = ppst.QueryInterface(IID_IUnknown, punk) | 
|  | 125 | +        self.assertEqual(S_OK, hr_qi) | 
|  | 126 | +        self.assertEqual(3, punk.Release()) | 
|  | 127 | + | 
|  | 128 | +        with self.assertRaises(OSError) as e: | 
|  | 129 | +            punk.QueryInterface(IID_IStream, IUnknown()) | 
|  | 130 | +        self.assertEqual(E_NOINTERFACE, e.exception.winerror) | 
|  | 131 | + | 
|  | 132 | +        self.assertEqual(2, ppst.Release()) | 
|  | 133 | +        self.assertEqual(1, ppst.Release()) | 
|  | 134 | +        self.assertEqual(0, ppst.Release()) | 
|  | 135 | + | 
|  | 136 | +    def test_with_paramflags_and_without_iid(self): | 
|  | 137 | +        class IUnknown(c_void_p): | 
|  | 138 | +            QueryInterface = proto_query_interface(None) | 
|  | 139 | +            AddRef = proto_add_ref() | 
|  | 140 | +            Release = proto_release() | 
|  | 141 | + | 
|  | 142 | +        class IPersist(IUnknown): | 
|  | 143 | +            GetClassID = proto_get_class_id(((OUT, "pClassID"),)) | 
|  | 144 | + | 
|  | 145 | +        ppst = self.create_shelllink_persist(IPersist) | 
|  | 146 | + | 
|  | 147 | +        clsid = ppst.GetClassID() | 
|  | 148 | +        self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid)) | 
|  | 149 | + | 
|  | 150 | +        punk = IUnknown() | 
|  | 151 | +        hr_qi = ppst.QueryInterface(IID_IUnknown, punk) | 
|  | 152 | +        self.assertEqual(S_OK, hr_qi) | 
|  | 153 | +        self.assertEqual(1, punk.Release()) | 
|  | 154 | + | 
|  | 155 | +        with self.assertRaises(OSError) as e: | 
|  | 156 | +            ppst.QueryInterface(IID_IStream, IUnknown()) | 
|  | 157 | +        self.assertEqual(E_NOINTERFACE, e.exception.winerror) | 
|  | 158 | + | 
|  | 159 | +        self.assertEqual(0, ppst.Release()) | 
|  | 160 | + | 
|  | 161 | +    def test_with_paramflags_and_iid(self): | 
|  | 162 | +        class IUnknown(c_void_p): | 
|  | 163 | +            QueryInterface = proto_query_interface(None, IID_IUnknown) | 
|  | 164 | +            AddRef = proto_add_ref() | 
|  | 165 | +            Release = proto_release() | 
|  | 166 | + | 
|  | 167 | +        class IPersist(IUnknown): | 
|  | 168 | +            GetClassID = proto_get_class_id(((OUT, "pClassID"),), IID_IPersist) | 
|  | 169 | + | 
|  | 170 | +        ppst = self.create_shelllink_persist(IPersist) | 
|  | 171 | + | 
|  | 172 | +        clsid = ppst.GetClassID() | 
|  | 173 | +        self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid)) | 
|  | 174 | + | 
|  | 175 | +        punk = IUnknown() | 
|  | 176 | +        hr_qi = ppst.QueryInterface(IID_IUnknown, punk) | 
|  | 177 | +        self.assertEqual(S_OK, hr_qi) | 
|  | 178 | +        self.assertEqual(1, punk.Release()) | 
|  | 179 | + | 
|  | 180 | +        with self.assertRaises(COMError) as e: | 
|  | 181 | +            ppst.QueryInterface(IID_IStream, IUnknown()) | 
|  | 182 | +        self.assertEqual(E_NOINTERFACE, e.exception.hresult) | 
|  | 183 | + | 
|  | 184 | +        self.assertEqual(0, ppst.Release()) | 
|  | 185 | + | 
|  | 186 | + | 
|  | 187 | +if __name__ == '__main__': | 
|  | 188 | +    unittest.main() | 
0 commit comments