@@ -173,6 +173,8 @@ def test_custom_serializer_and_deserializer(self):
173173 def serializer (obj , protocol ):
174174 if isinstance (obj , (bytes , bytearray , str )):
175175 if protocol == 5 :
176+ if isinstance (obj , bytearray ):
177+ return bytes (obj ) # DBM backends expect bytes
176178 return obj
177179 return type (obj ).__name__
178180 elif isinstance (obj , array .array ):
@@ -223,11 +225,10 @@ def deserializer(data):
223225 )
224226
225227 def test_custom_incomplete_serializer_and_deserializer (self ):
226- dbm_sqlite3 = import_helper .import_module ("dbm.sqlite3" )
227228 os .mkdir (self .dirname )
228229 self .addCleanup (os_helper .rmtree , self .dirname )
229230
230- with self .assertRaises (dbm_sqlite3 . error ):
231+ with self .assertRaises (shelve . ShelveError ):
231232 def serializer (obj , protocol = None ):
232233 pass
233234
@@ -430,6 +431,81 @@ def setUp(self):
430431 dbm ._defaultmod = self .dbm_mod
431432
432433
434+ class TestShelveValidation (unittest .TestCase ):
435+ dirname = os_helper .TESTFN
436+ fname = os .path .join (dirname , os_helper .TESTFN )
437+
438+ def setup_test_dir (self ):
439+ os_helper .rmtree (self .dirname )
440+ os .mkdir (self .dirname )
441+
442+ def setUp (self ):
443+ self .addCleanup (setattr , dbm , "_defaultmod" , dbm ._defaultmod )
444+ os .mkdir (self .dirname )
445+ self .addCleanup (os_helper .rmtree , self .dirname )
446+
447+ def test_serializer_unsupported_return_type (self ):
448+ def int_serializer (obj , protocol = None ):
449+ return 3
450+
451+ def none_serializer (obj , protocol = None ):
452+ return None
453+
454+ def deserializer (data ):
455+ if isinstance (data , bytes ):
456+ return data .decode ("utf-8" )
457+ else :
458+ return data
459+
460+ for module in dbm_iterator ():
461+ self .setup_test_dir ()
462+ dbm ._defaultmod = module
463+ with module .open (self .fname , "c" ):
464+ pass
465+ self .assertEqual (module .__name__ , dbm .whichdb (self .fname ))
466+
467+ with shelve .open (self .fname , serializer = none_serializer ,
468+ deserializer = deserializer ) as s :
469+ with self .assertRaises (shelve .ShelveError ) as cm :
470+ s ["key" ] = "value"
471+ self .assertEqual ("Serializer returned None for value 'value' "
472+ "But database values must be bytes or str, not None" ,
473+ f"{ cm .exception } " )
474+
475+ with shelve .open (self .fname , serializer = int_serializer ,
476+ deserializer = deserializer ,) as s :
477+ with self .assertRaises (shelve .ShelveError ) as cm :
478+ s ["key" ] = "value"
479+ self .assertEqual ("Serializer returned int for value 'value' "
480+ "But database values must be bytes or str, not int" ,
481+ f"{ cm .exception } " )
482+
483+ def test_shelve_type_compatibility (self ):
484+ for module in dbm_iterator ():
485+ self .setup_test_dir ()
486+ dbm ._defaultmod = module
487+ with shelve .Shelf (module .open (self .fname , "c" )) as shelf :
488+ shelf ["string" ] = "hello"
489+ shelf ["bytes" ] = b"world"
490+ shelf ["number" ] = 42
491+ shelf ["list" ] = [1 , 2 , 3 ]
492+ shelf ["dict" ] = {"key" : "value" }
493+ shelf ["set" ] = {1 , 2 , 3 }
494+ shelf ["tuple" ] = (1 , 2 , 3 )
495+ shelf ["complex" ] = 1 + 2j
496+ shelf ["bytearray" ] = bytearray (b"test" )
497+ shelf ["array" ] = array .array ("i" , [1 , 2 , 3 ])
498+ self .assertEqual (shelf ["string" ], "hello" )
499+ self .assertEqual (shelf ["bytes" ], b"world" )
500+ self .assertEqual (shelf ["number" ], 42 )
501+ self .assertEqual (shelf ["list" ], [1 , 2 , 3 ])
502+ self .assertEqual (shelf ["dict" ], {"key" : "value" })
503+ self .assertEqual (shelf ["set" ], {1 , 2 , 3 })
504+ self .assertEqual (shelf ["tuple" ], (1 , 2 , 3 ))
505+ self .assertEqual (shelf ["complex" ], 1 + 2j )
506+ self .assertEqual (shelf ["bytearray" ], bytearray (b"test" ))
507+ self .assertEqual (shelf ["array" ], array .array ("i" , [1 , 2 , 3 ]))
508+
433509from test import mapping_tests
434510
435511for proto in range (pickle .HIGHEST_PROTOCOL + 1 ):
0 commit comments