From 24ee255e166c576008a9f55c82d42f4a1a028fee Mon Sep 17 00:00:00 2001 From: furkanonder Date: Wed, 10 Sep 2025 22:18:13 +0300 Subject: [PATCH 1/2] Enhance shelve serializer validation with descriptive error messages --- Lib/shelve.py | 14 ++++++++ Lib/test/test_shelve.py | 80 +++++++++++++++++++++++++++++++++++++++-- 2 files changed, 92 insertions(+), 2 deletions(-) diff --git a/Lib/shelve.py b/Lib/shelve.py index 1010be1e09d702..dc293deb1b5a8a 100644 --- a/Lib/shelve.py +++ b/Lib/shelve.py @@ -106,6 +106,19 @@ def __init__(self, dict, protocol=None, writeback=False, self.serializer = serializer self.deserializer = deserializer + @staticmethod + def _validate_serialized_value(serialized_value, original_value): + if (serialized_value is None or + not isinstance(serialized_value, (bytes, str))): + if serialized_value is None: + invalid_type = "None" + else: + invalid_type = type(serialized_value).__name__ + msg = (f"Serializer returned {invalid_type} for value " + f"{original_value!r} But database values must be " + f"bytes or str, not {invalid_type}") + raise ShelveError(msg) + def __iter__(self): for k in self.dict.keys(): yield k.decode(self.keyencoding) @@ -135,6 +148,7 @@ def __setitem__(self, key, value): if self.writeback: self.cache[key] = value serialized_value = self.serializer(value, self._protocol) + self._validate_serialized_value(serialized_value, value) self.dict[key.encode(self.keyencoding)] = serialized_value def __delitem__(self, key): diff --git a/Lib/test/test_shelve.py b/Lib/test/test_shelve.py index 64609ab9dd9a62..b0a0dbf3baaf21 100644 --- a/Lib/test/test_shelve.py +++ b/Lib/test/test_shelve.py @@ -173,6 +173,8 @@ def test_custom_serializer_and_deserializer(self): def serializer(obj, protocol): if isinstance(obj, (bytes, bytearray, str)): if protocol == 5: + if isinstance(obj, bytearray): + return bytes(obj) # DBM backends expect bytes return obj return type(obj).__name__ elif isinstance(obj, array.array): @@ -223,11 +225,10 @@ def 
deserializer(data): ) def test_custom_incomplete_serializer_and_deserializer(self): - dbm_sqlite3 = import_helper.import_module("dbm.sqlite3") os.mkdir(self.dirname) self.addCleanup(os_helper.rmtree, self.dirname) - with self.assertRaises(dbm_sqlite3.error): + with self.assertRaises(shelve.ShelveError): def serializer(obj, protocol=None): pass @@ -430,6 +431,81 @@ def setUp(self): dbm._defaultmod = self.dbm_mod +class TestShelveValidation(unittest.TestCase): + dirname = os_helper.TESTFN + fname = os.path.join(dirname, os_helper.TESTFN) + + def setup_test_dir(self): + os_helper.rmtree(self.dirname) + os.mkdir(self.dirname) + + def setUp(self): + self.addCleanup(setattr, dbm, "_defaultmod", dbm._defaultmod) + os.mkdir(self.dirname) + self.addCleanup(os_helper.rmtree, self.dirname) + + def test_serializer_unsupported_return_type(self): + def int_serializer(obj, protocol=None): + return 3 + + def none_serializer(obj, protocol=None): + return None + + def deserializer(data): + if isinstance(data, bytes): + return data.decode("utf-8") + else: + return data + + for module in dbm_iterator(): + self.setup_test_dir() + dbm._defaultmod = module + with module.open(self.fname, "c"): + pass + self.assertEqual(module.__name__, dbm.whichdb(self.fname)) + + with shelve.open(self.fname, serializer=none_serializer, + deserializer=deserializer) as s: + with self.assertRaises(shelve.ShelveError) as cm: + s["key"] = "value" + self.assertEqual("Serializer returned None for value 'value' " + "But database values must be bytes or str, not None", + f"{cm.exception}") + + with shelve.open(self.fname, serializer=int_serializer, + deserializer=deserializer,) as s: + with self.assertRaises(shelve.ShelveError) as cm: + s["key"] = "value" + self.assertEqual("Serializer returned int for value 'value' " + "But database values must be bytes or str, not int", + f"{cm.exception}") + + def test_shelve_type_compatibility(self): + for module in dbm_iterator(): + self.setup_test_dir() + dbm._defaultmod = 
module + with shelve.Shelf(module.open(self.fname, "c")) as shelf: + shelf["string"] = "hello" + shelf["bytes"] = b"world" + shelf["number"] = 42 + shelf["list"] = [1, 2, 3] + shelf["dict"] = {"key": "value"} + shelf["set"] = {1, 2, 3} + shelf["tuple"] = (1, 2, 3) + shelf["complex"] = 1 + 2j + shelf["bytearray"] = bytearray(b"test") + shelf["array"] = array.array("i", [1, 2, 3]) + self.assertEqual(shelf["string"], "hello") + self.assertEqual(shelf["bytes"], b"world") + self.assertEqual(shelf["number"], 42) + self.assertEqual(shelf["list"], [1, 2, 3]) + self.assertEqual(shelf["dict"], {"key": "value"}) + self.assertEqual(shelf["set"], {1, 2, 3}) + self.assertEqual(shelf["tuple"], (1, 2, 3)) + self.assertEqual(shelf["complex"], 1 + 2j) + self.assertEqual(shelf["bytearray"], bytearray(b"test")) + self.assertEqual(shelf["array"], array.array("i", [1, 2, 3])) + from test import mapping_tests for proto in range(pickle.HIGHEST_PROTOCOL + 1): From 563e204faf6376e34651b94ffb1422ce8a96b18e Mon Sep 17 00:00:00 2001 From: furkanonder Date: Mon, 22 Sep 2025 23:28:57 +0300 Subject: [PATCH 2/2] Enhance dbm module error messages with descriptive type information --- Modules/_dbmmodule.c | 10 ++++++---- Modules/_gdbmmodule.c | 9 +++++---- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/Modules/_dbmmodule.c b/Modules/_dbmmodule.c index 0cd0f043de453d..68002e05d7ab6f 100644 --- a/Modules/_dbmmodule.c +++ b/Modules/_dbmmodule.c @@ -236,8 +236,9 @@ dbm_ass_sub_lock_held(PyObject *self, PyObject *v, PyObject *w) dbmobject *dp = dbmobject_CAST(self); if ( !PyArg_Parse(v, "s#", &krec.dptr, &tmp_size) ) { - PyErr_SetString(PyExc_TypeError, - "dbm mappings have bytes or string keys only"); + PyErr_Format(PyExc_TypeError, + "dbm key returned %.100s for value %R But database keys must be bytes or str, not %.100s", + Py_TYPE(v)->tp_name, v, Py_TYPE(v)->tp_name); return -1; } _dbm_state *state = PyType_GetModuleState(Py_TYPE(dp)); @@ -263,8 +264,9 @@ 
dbm_ass_sub_lock_held(PyObject *self, PyObject *v, PyObject *w) } } else { if ( !PyArg_Parse(w, "s#", &drec.dptr, &tmp_size) ) { - PyErr_SetString(PyExc_TypeError, - "dbm mappings have bytes or string elements only"); + PyErr_Format(PyExc_TypeError, + "dbm value returned %.100s for value %R But database values must be bytes or str, not %.100s", + Py_TYPE(w)->tp_name, w, Py_TYPE(w)->tp_name); return -1; } drec.dsize = tmp_size; diff --git a/Modules/_gdbmmodule.c b/Modules/_gdbmmodule.c index 6a4939512b22fc..40f25e74789b2e 100644 --- a/Modules/_gdbmmodule.c +++ b/Modules/_gdbmmodule.c @@ -248,7 +248,7 @@ parse_datum(PyObject *o, datum *d, const char *failmsg) Py_ssize_t size; if (!PyArg_Parse(o, "s#", &d->dptr, &size)) { if (failmsg != NULL) { - PyErr_SetString(PyExc_TypeError, failmsg); + PyErr_Format(PyExc_TypeError, failmsg, Py_TYPE(o)->tp_name, o, Py_TYPE(o)->tp_name); } return 0; } @@ -324,11 +324,12 @@ static int gdbm_ass_sub_lock_held(PyObject *op, PyObject *v, PyObject *w) { datum krec, drec; - const char *failmsg = "gdbm mappings have bytes or string indices only"; + const char *key_failmsg = "dbm key returned %.100s for value %R But database keys must be bytes or str, not %.100s"; + const char *value_failmsg = "dbm value returned %.100s for value %R But database values must be bytes or str, not %.100s"; gdbmobject *dp = _gdbmobject_CAST(op); _gdbm_state *state = PyType_GetModuleState(Py_TYPE(dp)); - if (!parse_datum(v, &krec, failmsg)) { + if (!parse_datum(v, &krec, key_failmsg)) { return -1; } if (dp->di_dbm == NULL) { @@ -349,7 +350,7 @@ gdbm_ass_sub_lock_held(PyObject *op, PyObject *v, PyObject *w) } } else { - if (!parse_datum(w, &drec, failmsg)) { + if (!parse_datum(w, &drec, value_failmsg)) { return -1; } errno = 0;