| 
 | 1 | +import unittest  | 
 | 2 | + | 
 | 3 | +from test.support import import_helper, os_helper, threading_helper  | 
 | 4 | +from test.support.threading_helper import run_concurrently  | 
 | 5 | + | 
 | 6 | +import threading  | 
 | 7 | + | 
 | 8 | +gdbm = import_helper.import_module("dbm.gnu")  | 
 | 9 | + | 
 | 10 | +NTHREADS = 10  | 
 | 11 | +KEY_PER_THREAD = 1000  | 
 | 12 | + | 
 | 13 | +gdbm_filename = "test_gdbm_file"  | 
 | 14 | + | 
 | 15 | + | 
 | 16 | +@threading_helper.requires_working_threading()  | 
 | 17 | +class TestGdbm(unittest.TestCase):  | 
 | 18 | +    def test_racing_dbm_gnu(self):  | 
 | 19 | +        def gdbm_multi_op_worker(db):  | 
 | 20 | +            # Each thread sets, gets, and iterates  | 
 | 21 | +            tid = threading.get_ident()  | 
 | 22 | + | 
 | 23 | +            # Insert keys  | 
 | 24 | +            for i in range(KEY_PER_THREAD):  | 
 | 25 | +                db[f"key_{tid}_{i}"] = f"value_{tid}_{i}"  | 
 | 26 | + | 
 | 27 | +            for i in range(KEY_PER_THREAD):  | 
 | 28 | +                # Keys and values are stored as bytes; encode values for  | 
 | 29 | +                # comparison  | 
 | 30 | +                key = f"key_{tid}_{i}"  | 
 | 31 | +                value = f"value_{tid}_{i}".encode()  | 
 | 32 | +                self.assertIn(key, db)  | 
 | 33 | +                self.assertEqual(db[key], value)  | 
 | 34 | +                self.assertEqual(db.get(key), value)  | 
 | 35 | +                self.assertIsNone(db.get("not_exist"))  | 
 | 36 | +                with self.assertRaises(KeyError):  | 
 | 37 | +                    db["not_exist"]  | 
 | 38 | + | 
 | 39 | +            # Iterate over the database keys and verify only those belonging  | 
 | 40 | +            # to this thread. Other threads may concurrently delete their keys.  | 
 | 41 | +            key_prefix = f"key_{tid}".encode()  | 
 | 42 | +            key = db.firstkey()  | 
 | 43 | +            key_count = 0  | 
 | 44 | +            while key:  | 
 | 45 | +                if key.startswith(key_prefix):  | 
 | 46 | +                    self.assertIn(key, db)  | 
 | 47 | +                    key_count += 1  | 
 | 48 | +                key = db.nextkey(key)  | 
 | 49 | + | 
 | 50 | +            # Can't assert key_count == KEY_PER_THREAD because concurrent  | 
 | 51 | +            # threads may insert or delete keys during iteration. This can  | 
 | 52 | +            # cause keys to be skipped or counted multiple times, making the  | 
 | 53 | +            # count unreliable.  | 
 | 54 | +            # See: https://www.gnu.org.ua/software/gdbm/manual/Sequential.html  | 
 | 55 | +            # self.assertEqual(key_count, KEY_PER_THREAD)  | 
 | 56 | + | 
 | 57 | +            # Delete this thread's keys  | 
 | 58 | +            for i in range(KEY_PER_THREAD):  | 
 | 59 | +                key = f"key_{tid}_{i}"  | 
 | 60 | +                del db[key]  | 
 | 61 | +                self.assertNotIn(key, db)  | 
 | 62 | +                with self.assertRaises(KeyError):  | 
 | 63 | +                    del db["not_exist"]  | 
 | 64 | + | 
 | 65 | +            # Re-insert keys  | 
 | 66 | +            for i in range(KEY_PER_THREAD):  | 
 | 67 | +                db[f"key_{tid}_{i}"] = f"value_{tid}_{i}"  | 
 | 68 | + | 
 | 69 | +        with os_helper.temp_dir() as tmpdirname:  | 
 | 70 | +            db = gdbm.open(f"{tmpdirname}/{gdbm_filename}", "c")  | 
 | 71 | +            run_concurrently(  | 
 | 72 | +                worker_func=gdbm_multi_op_worker, nthreads=NTHREADS, args=(db,)  | 
 | 73 | +            )  | 
 | 74 | +            self.assertEqual(len(db), NTHREADS * KEY_PER_THREAD)  | 
 | 75 | +            db.close()  | 
 | 76 | + | 
 | 77 | + | 
 | 78 | +if __name__ == "__main__":  | 
 | 79 | +    unittest.main()  | 
0 commit comments