diff --git a/Lib/test/test_shelve.py b/Lib/test/test_shelve.py index 64609ab9dd9a62..5f6a030e018f96 100644 --- a/Lib/test/test_shelve.py +++ b/Lib/test/test_shelve.py @@ -5,7 +5,7 @@ import pickle import os -from test.support import import_helper, os_helper +from test.support import import_helper, os_helper, subTests from collections.abc import MutableMapping from test.test_dbm import dbm_iterator @@ -173,6 +173,8 @@ def test_custom_serializer_and_deserializer(self): def serializer(obj, protocol): if isinstance(obj, (bytes, bytearray, str)): if protocol == 5: + if isinstance(obj, bytearray): + return bytes(obj) # DBM backends expect bytes return obj return type(obj).__name__ elif isinstance(obj, array.array): @@ -222,22 +224,31 @@ def deserializer(data): s["array_data"], array_data.tobytes().decode() ) - def test_custom_incomplete_serializer_and_deserializer(self): - dbm_sqlite3 = import_helper.import_module("dbm.sqlite3") - os.mkdir(self.dirname) - self.addCleanup(os_helper.rmtree, self.dirname) + @subTests("serialized", [None, ["invalid type"]]) + def test_custom_invalid_serializer(self, serialized): + test_dir = f"{self.dirname}_{id(serialized)}" + os.mkdir(test_dir) + self.addCleanup(os_helper.rmtree, test_dir) + test_fn = os.path.join(test_dir, "shelftemp.db") - with self.assertRaises(dbm_sqlite3.error): - def serializer(obj, protocol=None): - pass + def serializer(obj, protocol=None): + return serialized - def deserializer(data): - return data.decode("utf-8") + def deserializer(data): + return data.decode("utf-8") - with shelve.open(self.fn, serializer=serializer, + # Since the serializer returns an invalid type or None, + # dbm.error is raised by dbm.sqlite3 and TypeError is raised + # by other backends. + with self.assertRaises((TypeError, dbm.error)): + with shelve.open(test_fn, serializer=serializer, deserializer=deserializer) as s: s["foo"] = "bar" + def test_custom_incomplete_deserializer(self): + os.mkdir(self.dirname) + self.addCleanup(os_helper.rmtree, self.dirname) + def serializer(obj, protocol=None): return type(obj).__name__.encode("utf-8") @@ -352,7 +363,7 @@ def type_name_len(obj): self.assertEqual(s["bytearray_data"], "bytearray") self.assertEqual(s["array_data"], "array") - def test_custom_incomplete_serializer_and_deserializer_bsd_db_shelf(self): + def test_custom_incomplete_deserializer_bsd_db_shelf(self): berkeleydb = import_helper.import_module("berkeleydb") os.mkdir(self.dirname) self.addCleanup(os_helper.rmtree, self.dirname) @@ -370,6 +381,11 @@ def deserializer(data): self.assertIsNone(s["foo"]) self.assertNotEqual(s["foo"], "bar") + def test_custom_incomplete_serializer_bsd_db_shelf(self): + berkeleydb = import_helper.import_module("berkeleydb") + os.mkdir(self.dirname) + self.addCleanup(os_helper.rmtree, self.dirname) + def serializer(obj, protocol=None): pass