Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions Lib/shelve.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,19 @@ def __init__(self, dict, protocol=None, writeback=False,
self.serializer = serializer
self.deserializer = deserializer

@staticmethod
def _validate_serialized_value(serialized_value, original_value):
Copy link
Member

@picnixz picnixz Sep 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Validation also happens at the C-extension level (in Modules/_dbmmodule.c), so you should change the message there as well.

if (serialized_value is None or
not isinstance(serialized_value, (bytes, str))):
if serialized_value is None:
invalid_type = "None"
else:
invalid_type = type(serialized_value).__name__
msg = (f"Serializer returned {invalid_type} for value "
f"{original_value!r} But database values must be "
Copy link
Contributor

@aisk aisk Sep 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick, 'But' should be lowercase unless it follows a period:

Suggested change
f"{original_value!r} But database values must be "
f"{original_value!r}, but database values must be "

Adding a period before it is another option.

f"bytes or str, not {invalid_type}")
Comment on lines +117 to +119
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

invalid_type appears twice in the message. This is unnecessary.

The repr of original_value can be very large, so it should not be included in the error message. The error most likely does not depend on the original value.

raise ShelveError(msg)

def __iter__(self):
for k in self.dict.keys():
yield k.decode(self.keyencoding)
Expand Down Expand Up @@ -135,6 +148,7 @@ def __setitem__(self, key, value):
if self.writeback:
self.cache[key] = value
serialized_value = self.serializer(value, self._protocol)
self._validate_serialized_value(serialized_value, value)
self.dict[key.encode(self.keyencoding)] = serialized_value

def __delitem__(self, key):
Expand Down
80 changes: 78 additions & 2 deletions Lib/test/test_shelve.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ def test_custom_serializer_and_deserializer(self):
def serializer(obj, protocol):
if isinstance(obj, (bytes, bytearray, str)):
if protocol == 5:
if isinstance(obj, bytearray):
return bytes(obj) # DBM backends expect bytes
return obj
return type(obj).__name__
elif isinstance(obj, array.array):
Expand Down Expand Up @@ -223,11 +225,10 @@ def deserializer(data):
)

def test_custom_incomplete_serializer_and_deserializer(self):
dbm_sqlite3 = import_helper.import_module("dbm.sqlite3")
os.mkdir(self.dirname)
self.addCleanup(os_helper.rmtree, self.dirname)

with self.assertRaises(dbm_sqlite3.error):
with self.assertRaises(shelve.ShelveError):
def serializer(obj, protocol=None):
pass

Expand Down Expand Up @@ -430,6 +431,81 @@ def setUp(self):
dbm._defaultmod = self.dbm_mod


class TestShelveValidation(unittest.TestCase):
    """Check how Shelf treats serializer output on every dbm backend.

    Each test walks the modules yielded by dbm_iterator() and runs the
    per-backend work inside its own subTest, so a failure in one backend
    is reported by name and does not prevent the others from running.
    """

    dirname = os_helper.TESTFN
    fname = os.path.join(dirname, os_helper.TESTFN)

    def setup_test_dir(self):
        """Replace the scratch directory with a fresh, empty one."""
        os_helper.rmtree(self.dirname)
        os.mkdir(self.dirname)

    def setUp(self):
        # The tests reassign dbm._defaultmod; put the previous value back
        # once the test is over.
        self.addCleanup(setattr, dbm, "_defaultmod", dbm._defaultmod)
        os.mkdir(self.dirname)
        self.addCleanup(os_helper.rmtree, self.dirname)

    def test_serializer_unsupported_return_type(self):
        """A serializer returning None or an int raises ShelveError."""
        def int_serializer(obj, protocol=None):
            return 3

        def none_serializer(obj, protocol=None):
            return None

        def deserializer(data):
            if isinstance(data, bytes):
                return data.decode("utf-8")
            else:
                return data

        # (serializer, type name expected in the error message)
        cases = (
            (none_serializer, "None"),
            (int_serializer, "int"),
        )

        for module in dbm_iterator():
            with self.subTest(module=module.__name__):
                self.setup_test_dir()
                dbm._defaultmod = module
                # Create the database file with this backend up front; the
                # whichdb() check confirms it is detected as that backend.
                with module.open(self.fname, "c"):
                    pass
                self.assertEqual(module.__name__, dbm.whichdb(self.fname))

                for serializer, type_name in cases:
                    with shelve.open(self.fname, serializer=serializer,
                                     deserializer=deserializer) as s:
                        with self.assertRaises(shelve.ShelveError) as cm:
                            s["key"] = "value"
                        # Checked after the assertRaises block: a statement
                        # placed after the raising line inside it would
                        # never execute.
                        self.assertEqual(
                            f"Serializer returned {type_name} for value "
                            f"'value' But database values must be bytes "
                            f"or str, not {type_name}",
                            f"{cm.exception}")

    def test_shelve_type_compatibility(self):
        """Values of common built-in types round-trip through a Shelf."""
        samples = {
            "string": "hello",
            "bytes": b"world",
            "number": 42,
            "list": [1, 2, 3],
            "dict": {"key": "value"},
            "set": {1, 2, 3},
            "tuple": (1, 2, 3),
            "complex": 1 + 2j,
            "bytearray": bytearray(b"test"),
            "array": array.array("i", [1, 2, 3]),
        }

        for module in dbm_iterator():
            with self.subTest(module=module.__name__):
                self.setup_test_dir()
                dbm._defaultmod = module
                with shelve.Shelf(module.open(self.fname, "c")) as shelf:
                    # Store everything first, then read everything back.
                    for key, value in samples.items():
                        shelf[key] = value
                    for key, value in samples.items():
                        self.assertEqual(shelf[key], value, key)

from test import mapping_tests

for proto in range(pickle.HIGHEST_PROTOCOL + 1):
Expand Down
10 changes: 6 additions & 4 deletions Modules/_dbmmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,9 @@ dbm_ass_sub_lock_held(PyObject *self, PyObject *v, PyObject *w)
dbmobject *dp = dbmobject_CAST(self);

if ( !PyArg_Parse(v, "s#", &krec.dptr, &tmp_size) ) {
PyErr_SetString(PyExc_TypeError,
"dbm mappings have bytes or string keys only");
PyErr_Format(PyExc_TypeError,
"dbm key returned %.100s for value %R But database keys must be bytes or str, not %.100s",
Py_TYPE(v)->tp_name, v, Py_TYPE(v)->tp_name);
return -1;
}
_dbm_state *state = PyType_GetModuleState(Py_TYPE(dp));
Expand All @@ -258,8 +259,9 @@ dbm_ass_sub_lock_held(PyObject *self, PyObject *v, PyObject *w)
}
} else {
if ( !PyArg_Parse(w, "s#", &drec.dptr, &tmp_size) ) {
PyErr_SetString(PyExc_TypeError,
"dbm mappings have bytes or string elements only");
PyErr_Format(PyExc_TypeError,
"dbm value returned %.100s for value %R But database values must be bytes or str, not %.100s",
Py_TYPE(w)->tp_name, w, Py_TYPE(w)->tp_name);
return -1;
}
drec.dsize = tmp_size;
Expand Down
9 changes: 5 additions & 4 deletions Modules/_gdbmmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ parse_datum(PyObject *o, datum *d, const char *failmsg)
Py_ssize_t size;
if (!PyArg_Parse(o, "s#", &d->dptr, &size)) {
if (failmsg != NULL) {
PyErr_SetString(PyExc_TypeError, failmsg);
PyErr_Format(PyExc_TypeError, failmsg, Py_TYPE(o)->tp_name, o, Py_TYPE(o)->tp_name);
}
return 0;
}
Expand Down Expand Up @@ -318,11 +318,12 @@ static int
gdbm_ass_sub_lock_held(PyObject *op, PyObject *v, PyObject *w)
{
datum krec, drec;
const char *failmsg = "gdbm mappings have bytes or string indices only";
const char *key_failmsg = "dbm key returned %.100s for value %R But database keys must be bytes or str, not %.100s";
const char *value_failmsg = "dbm value returned %.100s for value %R But database keys must be bytes or str, not %.100s";
gdbmobject *dp = _gdbmobject_CAST(op);
_gdbm_state *state = PyType_GetModuleState(Py_TYPE(dp));

if (!parse_datum(v, &krec, failmsg)) {
if (!parse_datum(v, &krec, key_failmsg)) {
return -1;
}
if (dp->di_dbm == NULL) {
Expand All @@ -343,7 +344,7 @@ gdbm_ass_sub_lock_held(PyObject *op, PyObject *v, PyObject *w)
}
}
else {
if (!parse_datum(w, &drec, failmsg)) {
if (!parse_datum(w, &drec, value_failmsg)) {
return -1;
}
errno = 0;
Expand Down
Loading