diff --git a/Lib/test/test_ctypes/test_pointers.py b/Lib/test/test_ctypes/test_pointers.py index a8d243a45de0f4..792a3687dba93b 100644 --- a/Lib/test/test_ctypes/test_pointers.py +++ b/Lib/test/test_ctypes/test_pointers.py @@ -5,7 +5,7 @@ import unittest from ctypes import (CDLL, CFUNCTYPE, Structure, POINTER, pointer, _Pointer, - byref, sizeof, + addressof, byref, sizeof, c_void_p, c_char_p, c_byte, c_ubyte, c_short, c_ushort, c_int, c_uint, c_long, c_ulong, c_longlong, c_ulonglong, @@ -472,6 +472,105 @@ class C(Structure): ptr.set_type(c_int) self.assertIs(ptr._type_, c_int) + def test_pointer_lifecycle_basic(self): + i = c_long(1010) + p = pointer(i) + self.assertEqual(p[0], 1010) + self.assertIsNone(p._b_base_) + self.assertEqual(addressof(i), addressof(p.contents)) + + def test_pointer_lifecycle_set_contents(self): + i = c_long(2020) + p = pointer(c_long(1010)) + p.contents = i + self.assertEqual(p[0], 2020) + self.assertIsNone(p._b_base_) + self.assertEqual(addressof(i), addressof(p.contents)) + + def test_pointer_lifecycle_set_pointer_contents(self): + i = c_long(3030) + p = pointer(c_long(1010)) + pointer(p).contents.contents = i + self.assertEqual(p.contents.value, 3030) + self.assertEqual(addressof(i), addressof(p.contents)) + + def test_pointer_lifecycle_array_set_contents(self): + arr_type = POINTER(c_long) * 3 + arr_obj = arr_type() + i = c_long(300300) + arr_obj[0] = pointer(c_long(100100)) + arr_obj[1] = pointer(c_long(200200)) + arr_obj[2] = pointer(i) + self.assertEqual(arr_obj[0].contents.value, 100100) + self.assertEqual(arr_obj[1].contents.value, 200200) + self.assertEqual(arr_obj[2].contents.value, 300300) + self.assertEqual(addressof(i), addressof(arr_obj[2].contents)) + + def test_pointer_lifecycle_array_set_pointer_contents(self): + arr_type = POINTER(c_long) * 3 + arr_obj = arr_type() + i = c_long(200003) + arr_obj[0].contents = c_long(100001) + arr_obj[1].contents = c_long(200002) + arr_obj[2].contents = i + self.assertEqual(arr_obj[0].contents.value, 100001) + self.assertEqual(arr_obj[1].contents.value, 200002) + self.assertEqual(arr_obj[2].contents.value, 200003) + self.assertEqual(addressof(i), addressof(arr_obj[2].contents)) + + def test_pointer_lifecycle_array_set_pointer_contents_pointer(self): + arr_type = POINTER(c_long) * 3 + arr_obj = arr_type() + i = c_long(200003) + pointer(arr_obj[0]).contents.contents = c_long(100001) + pointer(arr_obj[1]).contents.contents = c_long(200002) + pointer(arr_obj[2]).contents.contents = i + self.assertEqual(arr_obj[0].contents.value, 100001) + self.assertEqual(arr_obj[1].contents.value, 200002) + self.assertEqual(arr_obj[2].contents.value, 200003) + self.assertEqual(addressof(i), addressof(arr_obj[2].contents)) + + def test_pointer_lifecycle_struct_set_contents(self): + class S(Structure): + _fields_ = (("s", POINTER(c_long)),) + s = S(s=pointer(c_long(1111111))) + s.s.contents = c_long(2222222) + self.assertEqual(s.s.contents.value, 2222222) + + def test_pointer_lifecycle_struct_set_contents_pointer(self): + class S(Structure): + _fields_ = (("s", POINTER(c_long)),) + s = S(s=pointer(c_long(1111111))) + pointer(s.s).contents.contents = c_long(2222222) + self.assertEqual(s.s.contents.value, 2222222) + + def test_pointer_lifecycle_struct_set_pointer_contents(self): + class S(Structure): + _fields_ = (("s", POINTER(c_long)),) + s = S(s=pointer(c_long(1111111))) + s.s = pointer(c_long(3333333)) + self.assertEqual(s.s.contents.value, 3333333) + + def test_pointer_lifecycle_struct_with_extra_field(self): + class U(Structure): + _fields_ = ( + ("s", POINTER(c_long)), + ("u", c_long), + ) + u = U(s=pointer(c_long(1010101))) + u.s.contents = c_long(202020202) + self.assertEqual(u.s.contents.value, 202020202) + + def test_pointer_lifecycle_struct_with_extra_field_pointer(self): + class U(Structure): + _fields_ = ( + ("s", POINTER(c_uint)), + ("u", c_uint), + ) + u = U(s=pointer(c_uint(1010101))) + pointer(u.s).contents.contents = c_uint(202020202) + self.assertEqual(u.s.contents.value, 202020202) + if __name__ == '__main__': unittest.main() diff --git a/Modules/_ctypes/_ctypes.c b/Modules/_ctypes/_ctypes.c index 4bd3e380b3bc4b..6f46234c23a34b 100644 --- a/Modules/_ctypes/_ctypes.c +++ b/Modules/_ctypes/_ctypes.c @@ -2958,7 +2958,7 @@ KeepRef_lock_held(CDataObject *target, Py_ssize_t index, PyObject *keep) CDataObject *ob; PyObject *key; -/* Optimization: no need to store None */ + /* Optimization: no need to store None */ if (keep == Py_None) { Py_DECREF(Py_None); return 0; @@ -5724,8 +5724,34 @@ Pointer_set_contents_lock_held(PyObject *op, PyObject *value, void *closure) pointer instance has b_length set to 2 instead of 1, and we set 'value' itself as the second item of the b_objects list, additionally. */ + + CDataObject * root = self->b_base; + /* perhaps, this is a bit excessive: if we have are in a chain of pointers + that starts with non-pointer (e.g. a union), can we consider the current + pointer to be "detached" from this chain? */ + while (root != NULL && root->b_base != NULL) { + root = root->b_base; + } + + /* If the b_base is NULL now or if we are a part of chain of pointers fully + modeled within ctypes, AND the value is a pointer, array, struct or union, + we just override the b_base. */ + if ((root == NULL || PyType_IsSubtype(Py_TYPE(root), st->PyCPointer_Type)) && + (PyType_IsSubtype(Py_TYPE(value), st->PyCPointer_Type) || + PyType_IsSubtype(Py_TYPE(value), st->PyCArray_Type) || + PyType_IsSubtype(Py_TYPE(value), st->Struct_Type) || + PyType_IsSubtype(Py_TYPE(value), st->Union_Type)) + ) { + Py_XSETREF(self->b_base, (CDataObject *) Py_NewRef(value)); + return 0; // no need to add `value` to `keep` objects - it's in b_base + } + + /* If we are a part of chain of pointers that is not fully modeled within + ctypes, (or modeled in a complex way, e.g., with arrays and structures), + then everything should be covered by keepref logic bellow */ + Py_INCREF(value); - if (-1 == KeepRef(self, 1, value)) + if (-1 == KeepRef_lock_held(self, 1, value)) return -1; keep = GetKeepedObjects(dst); @@ -5733,7 +5759,7 @@ Pointer_set_contents_lock_held(PyObject *op, PyObject *value, void *closure) return -1; Py_INCREF(keep); - return KeepRef(self, 0, keep); + return KeepRef_lock_held(self, 0, keep); } static int