diff --git a/Lib/pickle.py b/Lib/pickle.py index ed8138beb908ee..3c1a6c1de20acb 100644 --- a/Lib/pickle.py +++ b/Lib/pickle.py @@ -548,10 +548,11 @@ def save(self, obj, save_persistent_id=True): self.framer.commit_frame() # Check for persistent id (defined by a subclass) - pid = self.persistent_id(obj) - if pid is not None and save_persistent_id: - self.save_pers(pid) - return + if save_persistent_id: + pid = self.persistent_id(obj) + if pid is not None: + self.save_pers(pid) + return # Check the memo x = self.memo.get(id(obj)) diff --git a/Lib/test/test_pickle.py b/Lib/test/test_pickle.py index c84e507cdf645f..95dde0520aef9a 100644 --- a/Lib/test/test_pickle.py +++ b/Lib/test/test_pickle.py @@ -244,6 +244,30 @@ def persistent_load(subself, pid): unpickler = PersUnpickler(io.BytesIO(self.dumps('abc', proto))) self.assertEqual(unpickler.load(), 'abc') + def test_pickler_instance_attribute(self): + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + f = io.BytesIO() + pickler = self.pickler(f, proto) + called = [] + def persistent_id(obj): + called.append(obj) + return obj + pickler.persistent_id = persistent_id + pickler.dump('abc') + self.assertEqual(called, ['abc']) + self.assertEqual(self.loads(f.getvalue()), 'abc') + + def test_unpickler_instance_attribute(self): + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + unpickler = self.unpickler(io.BytesIO(self.dumps('abc', proto))) + called = [] + def persistent_load(pid): + called.append(pid) + return pid + unpickler.persistent_load = persistent_load + self.assertEqual(unpickler.load(), 'abc') + self.assertEqual(called, ['abc']) + class PyPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests, unittest.TestCase): pickler_class = pickle._Pickler @@ -368,17 +392,20 @@ class SizeofTests(unittest.TestCase): def test_pickler(self): basesize = support.calcobjsize('6P2n3i2n3i2P') + P = struct.calcsize('P') p = _pickle.Pickler(io.BytesIO()) self.assertEqual(object.__sizeof__(p), basesize) MT_size = struct.calcsize('3nP0n') ME_size = struct.calcsize('Pn0P') check = self.check_sizeof check(p, basesize + + 2 * P + # Managed dict MT_size + 8 * ME_size + # Minimal memo table size. sys.getsizeof(b'x'*4096)) # Minimal write buffer size. for i in range(6): p.dump(chr(i)) check(p, basesize + + 2 * P + # Managed dict MT_size + 32 * ME_size + # Size of memo table required to # save references to 6 objects. 0) # Write buffer is cleared after every dump(). @@ -395,6 +422,7 @@ def test_unpickler(self): encoding=encoding, errors=errors) self.assertEqual(object.__sizeof__(u), basesize) check(u, basesize + + 2 * P + # Managed dict 32 * P + # Minimal memo table size. len(encoding) + 1 + len(errors) + 1) @@ -404,7 +432,7 @@ def check_unpickler(data, memo_size, marks_size): u = unpickler(io.BytesIO(dump), encoding='ASCII', errors='strict') u.load() - check(u, stdsize + memo_size * P + marks_size * n) + check(u, stdsize + 2 * P + memo_size * P + marks_size * n) check_unpickler(0, 32, 0) # 20 is minimal non-empty mark stack size. @@ -427,7 +455,7 @@ def recurse(deep): u = unpickler(io.BytesIO(pickle.dumps('a', 0)), encoding='ASCII', errors='strict') u.load() - check(u, stdsize + 32 * P + 2 + 1) + check(u, stdsize + 2 * P + 32 * P + 2 + 1) ALT_IMPORT_MAPPING = { diff --git a/Misc/NEWS.d/next/Library/2024-10-19-11-06-06.gh-issue-125631.BlhVvR.rst b/Misc/NEWS.d/next/Library/2024-10-19-11-06-06.gh-issue-125631.BlhVvR.rst new file mode 100644 index 00000000000000..e870abbf87803a --- /dev/null +++ b/Misc/NEWS.d/next/Library/2024-10-19-11-06-06.gh-issue-125631.BlhVvR.rst @@ -0,0 +1,4 @@ +Restore ability to set :attr:`~pickle.Pickler.persistent_id` and +:attr:`~pickle.Unpickler.persistent_load` attributes of instances of the +:class:`!Pickler` and :class:`!Unpickler` classes in the :mod:`pickle` +module. diff --git a/Modules/_pickle.c b/Modules/_pickle.c index b2bd9545c1b130..4a27ebb9ae22d6 100644 --- a/Modules/_pickle.c +++ b/Modules/_pickle.c @@ -5120,7 +5120,7 @@ static PyType_Spec pickler_type_spec = { .name = "_pickle.Pickler", .basicsize = sizeof(PicklerObject), .flags = (Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC | - Py_TPFLAGS_IMMUTABLETYPE), + Py_TPFLAGS_IMMUTABLETYPE | Py_TPFLAGS_MANAGED_DICT), .slots = pickler_type_slots, }; @@ -7585,7 +7585,7 @@ static PyType_Spec unpickler_type_spec = { .name = "_pickle.Unpickler", .basicsize = sizeof(UnpicklerObject), .flags = (Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC | - Py_TPFLAGS_IMMUTABLETYPE), + Py_TPFLAGS_IMMUTABLETYPE | Py_TPFLAGS_MANAGED_DICT), .slots = unpickler_type_slots, };