Skip to content

Commit 24ee255

Browse files
committed
Enhance shelve serializer validation with descriptive error messages
1 parent d8a9466 commit 24ee255

File tree

2 files changed

+92
-2
lines changed

2 files changed

+92
-2
lines changed

Lib/shelve.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,19 @@ def __init__(self, dict, protocol=None, writeback=False,
106106
self.serializer = serializer
107107
self.deserializer = deserializer
108108

109+
@staticmethod
110+
def _validate_serialized_value(serialized_value, original_value):
111+
if (serialized_value is None or
112+
not isinstance(serialized_value, (bytes, str))):
113+
if serialized_value is None:
114+
invalid_type = "None"
115+
else:
116+
invalid_type = type(serialized_value).__name__
117+
msg = (f"Serializer returned {invalid_type} for value "
118+
f"{original_value!r} But database values must be "
119+
f"bytes or str, not {invalid_type}")
120+
raise ShelveError(msg)
121+
109122
def __iter__(self):
110123
for k in self.dict.keys():
111124
yield k.decode(self.keyencoding)
@@ -135,6 +148,7 @@ def __setitem__(self, key, value):
135148
if self.writeback:
136149
self.cache[key] = value
137150
serialized_value = self.serializer(value, self._protocol)
151+
self._validate_serialized_value(serialized_value, value)
138152
self.dict[key.encode(self.keyencoding)] = serialized_value
139153

140154
def __delitem__(self, key):

Lib/test/test_shelve.py

Lines changed: 78 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,8 @@ def test_custom_serializer_and_deserializer(self):
173173
def serializer(obj, protocol):
174174
if isinstance(obj, (bytes, bytearray, str)):
175175
if protocol == 5:
176+
if isinstance(obj, bytearray):
177+
return bytes(obj) # DBM backends expect bytes
176178
return obj
177179
return type(obj).__name__
178180
elif isinstance(obj, array.array):
@@ -223,11 +225,10 @@ def deserializer(data):
223225
)
224226

225227
def test_custom_incomplete_serializer_and_deserializer(self):
226-
dbm_sqlite3 = import_helper.import_module("dbm.sqlite3")
227228
os.mkdir(self.dirname)
228229
self.addCleanup(os_helper.rmtree, self.dirname)
229230

230-
with self.assertRaises(dbm_sqlite3.error):
231+
with self.assertRaises(shelve.ShelveError):
231232
def serializer(obj, protocol=None):
232233
pass
233234

@@ -430,6 +431,81 @@ def setUp(self):
430431
dbm._defaultmod = self.dbm_mod
431432

432433

434+
class TestShelveValidation(unittest.TestCase):
435+
dirname = os_helper.TESTFN
436+
fname = os.path.join(dirname, os_helper.TESTFN)
437+
438+
def setup_test_dir(self):
439+
os_helper.rmtree(self.dirname)
440+
os.mkdir(self.dirname)
441+
442+
def setUp(self):
443+
self.addCleanup(setattr, dbm, "_defaultmod", dbm._defaultmod)
444+
os.mkdir(self.dirname)
445+
self.addCleanup(os_helper.rmtree, self.dirname)
446+
447+
def test_serializer_unsupported_return_type(self):
448+
def int_serializer(obj, protocol=None):
449+
return 3
450+
451+
def none_serializer(obj, protocol=None):
452+
return None
453+
454+
def deserializer(data):
455+
if isinstance(data, bytes):
456+
return data.decode("utf-8")
457+
else:
458+
return data
459+
460+
for module in dbm_iterator():
461+
self.setup_test_dir()
462+
dbm._defaultmod = module
463+
with module.open(self.fname, "c"):
464+
pass
465+
self.assertEqual(module.__name__, dbm.whichdb(self.fname))
466+
467+
with shelve.open(self.fname, serializer=none_serializer,
468+
deserializer=deserializer) as s:
469+
with self.assertRaises(shelve.ShelveError) as cm:
470+
s["key"] = "value"
471+
self.assertEqual("Serializer returned None for value 'value' "
472+
"But database values must be bytes or str, not None",
473+
f"{cm.exception}")
474+
475+
with shelve.open(self.fname, serializer=int_serializer,
476+
deserializer=deserializer,) as s:
477+
with self.assertRaises(shelve.ShelveError) as cm:
478+
s["key"] = "value"
479+
self.assertEqual("Serializer returned int for value 'value' "
480+
"But database values must be bytes or str, not int",
481+
f"{cm.exception}")
482+
483+
def test_shelve_type_compatibility(self):
484+
for module in dbm_iterator():
485+
self.setup_test_dir()
486+
dbm._defaultmod = module
487+
with shelve.Shelf(module.open(self.fname, "c")) as shelf:
488+
shelf["string"] = "hello"
489+
shelf["bytes"] = b"world"
490+
shelf["number"] = 42
491+
shelf["list"] = [1, 2, 3]
492+
shelf["dict"] = {"key": "value"}
493+
shelf["set"] = {1, 2, 3}
494+
shelf["tuple"] = (1, 2, 3)
495+
shelf["complex"] = 1 + 2j
496+
shelf["bytearray"] = bytearray(b"test")
497+
shelf["array"] = array.array("i", [1, 2, 3])
498+
self.assertEqual(shelf["string"], "hello")
499+
self.assertEqual(shelf["bytes"], b"world")
500+
self.assertEqual(shelf["number"], 42)
501+
self.assertEqual(shelf["list"], [1, 2, 3])
502+
self.assertEqual(shelf["dict"], {"key": "value"})
503+
self.assertEqual(shelf["set"], {1, 2, 3})
504+
self.assertEqual(shelf["tuple"], (1, 2, 3))
505+
self.assertEqual(shelf["complex"], 1 + 2j)
506+
self.assertEqual(shelf["bytearray"], bytearray(b"test"))
507+
self.assertEqual(shelf["array"], array.array("i", [1, 2, 3]))
508+
433509
from test import mapping_tests
434510

435511
for proto in range(pickle.HIGHEST_PROTOCOL + 1):

0 commit comments

Comments
 (0)