Skip to content
45 changes: 37 additions & 8 deletions Lib/test/test_shelve.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ def test_custom_serializer_and_deserializer(self):
def serializer(obj, protocol):
if isinstance(obj, (bytes, bytearray, str)):
if protocol == 5:
if isinstance(obj, bytearray):
return bytes(obj) # DBM backends expect bytes
return obj
return type(obj).__name__
elif isinstance(obj, array.array):
Expand Down Expand Up @@ -222,22 +224,44 @@ def deserializer(data):
s["array_data"], array_data.tobytes().decode()
)

def test_custom_incomplete_serializer_and_deserializer(self):
dbm_sqlite3 = import_helper.import_module("dbm.sqlite3")
def test_custom_incomplete_serializer(self):
os.mkdir(self.dirname)
self.addCleanup(os_helper.rmtree, self.dirname)

with self.assertRaises(dbm_sqlite3.error):
def serializer(obj, protocol=None):
pass
def serializer(obj, protocol=None):
pass

def deserializer(data):
return data.decode("utf-8")
def deserializer(data):
return data.decode("utf-8")

# Since the serializer returns None, dbm.error is raised
# by dbm.sqlite3 and TypeError is raised by other backends.
with self.assertRaises((TypeError, dbm.error)):
with shelve.open(self.fn, serializer=serializer,
deserializer=deserializer) as s:
s["foo"] = "bar"

def test_custom_invalid_serializer(self):
os.mkdir(self.dirname)
self.addCleanup(os_helper.rmtree, self.dirname)

def serializer(obj, protocol=None):
return ["value with invalid type"]
Copy link
Member

@picnixz picnixz Sep 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you want to test with None and with something that is not None, I would suggest using a parametrized test:

@support.subTests("serialized", [None, ["invalid type"]])
def test_custom_invalid_serializer(self, serialized):
    ...
    def serializer(obj, protocol=None):
        return serialized
    ...

and update the comment saying that the TypeError is due to the return
value of the serializer.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion! I've implemented the parameterized test approach as you recommended. I explored two different implementations and wanted to share the trade-offs:

Option 1: Using @subTests decorator (as you suggested)

@subTests("serialized", [None, ["invalid type"]])
def test_custom_invalid_serializer(self, serialized):
    # Create unique directory for each subtest to avoid conflicts
    test_dir = f"{self.dirname}_{id(serialized)}"
    os.mkdir(test_dir)
    self.addCleanup(os_helper.rmtree, test_dir)
    test_fn = os.path.join(test_dir, "shelftemp.db")
    
    def serializer(obj, protocol=None):
        return serialized
    # ... rest of test logic

Option 2: Using self.subTest() with a simple loop

def test_custom_invalid_serializer(self):
    os.mkdir(self.dirname)
    self.addCleanup(os_helper.rmtree, self.dirname)
    
    for serialized in [None, ["invalid type"]]:
        with self.subTest(serialized=serialized):
            def serializer(obj, protocol=None):
                return serialized
            # ... rest of test logic

Trade-offs:

  • Option 1: More explicit parameterization and complete isolation, but requires complex directory management
  • Option 2: Simpler, more readable, and uses standard unittest patterns, but subtests share the same directory

Which approach do you prefer?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use 1, it was meant for such cases.


def deserializer(data):
return data.decode("utf-8")

# Since the serializer returns None, dbm.error is raised
# by dbm.sqlite3 and TypeError is raised by other backends.
with self.assertRaises((TypeError, dbm.error)):
with shelve.open(self.fn, serializer=serializer,
deserializer=deserializer) as s:
s["foo"] = "bar"

def test_custom_incomplete_deserializer(self):
os.mkdir(self.dirname)
self.addCleanup(os_helper.rmtree, self.dirname)

def serializer(obj, protocol=None):
return type(obj).__name__.encode("utf-8")

Expand Down Expand Up @@ -352,7 +376,7 @@ def type_name_len(obj):
self.assertEqual(s["bytearray_data"], "bytearray")
self.assertEqual(s["array_data"], "array")

def test_custom_incomplete_serializer_and_deserializer_bsd_db_shelf(self):
def test_custom_incomplete_deserializer_bsd_db_shelf(self):
berkeleydb = import_helper.import_module("berkeleydb")
os.mkdir(self.dirname)
self.addCleanup(os_helper.rmtree, self.dirname)
Expand All @@ -370,6 +394,11 @@ def deserializer(data):
self.assertIsNone(s["foo"])
self.assertNotEqual(s["foo"], "bar")

def test_custom_incomplete_serializer_bsd_db_shelf(self):
berkeleydb = import_helper.import_module("berkeleydb")
os.mkdir(self.dirname)
self.addCleanup(os_helper.rmtree, self.dirname)

def serializer(obj, protocol=None):
pass

Expand Down
Loading