@@ -50,31 +50,48 @@ class TestFunction:
5050
5151class TestsCache :
5252 def __init__ (self ) -> None :
53- self .connection = sqlite3 .connect (codeflash_cache_db )
53+ # Enable WAL mode and proper concurrent access
54+ self .connection = sqlite3 .connect (
55+ codeflash_cache_db ,
56+ timeout = 30.0 , # 30 second timeout for database locks
57+ check_same_thread = False , # Allow use from multiple threads
58+ )
59+ # Enable WAL mode for better concurrent access
60+ self .connection .execute ("PRAGMA journal_mode=WAL" )
61+ self .connection .execute ("PRAGMA synchronous=NORMAL" )
62+ self .connection .execute ("PRAGMA cache_size=10000" )
63+ self .connection .execute ("PRAGMA temp_store=MEMORY" )
64+
5465 logger .debug (f"Connected to tests cache database at { codeflash_cache_db } " )
5566 self .cur = self .connection .cursor ()
5667
57- self .cur .execute (
58- """
59- CREATE TABLE IF NOT EXISTS discovered_tests(
60- file_path TEXT,
61- file_hash TEXT,
62- qualified_name_with_modules_from_root TEXT,
63- function_name TEXT,
64- test_class TEXT,
65- test_function TEXT,
66- test_type TEXT,
67- line_number INTEGER,
68- col_number INTEGER
68+ try :
69+ self .cur .execute (
70+ """
71+ CREATE TABLE IF NOT EXISTS discovered_tests(
72+ file_path TEXT,
73+ file_hash TEXT,
74+ qualified_name_with_modules_from_root TEXT,
75+ function_name TEXT,
76+ test_class TEXT,
77+ test_function TEXT,
78+ test_type TEXT,
79+ line_number INTEGER,
80+ col_number INTEGER,
81+ UNIQUE(file_path, file_hash, qualified_name_with_modules_from_root, test_function, test_class, line_number, col_number)
82+ )
83+ """
6984 )
70- """
71- )
72- self .cur .execute (
73- """
74- CREATE INDEX IF NOT EXISTS idx_discovered_tests_file_path_hash
75- ON discovered_tests (file_path, file_hash)
76- """
77- )
85+ self .cur .execute (
86+ """
87+ CREATE INDEX IF NOT EXISTS idx_discovered_tests_file_path_hash
88+ ON discovered_tests (file_path, file_hash)
89+ """
90+ )
91+ self .connection .commit ()
92+ except sqlite3 .OperationalError as e :
93+ logger .info (f"Database initialization warning (likely concurrent access): { e } " )
94+
7895 self ._memory_cache = {}
7996
8097 def insert_test (
@@ -90,38 +107,53 @@ def insert_test(
90107 col_number : int ,
91108 ) -> None :
92109 test_type_value = test_type .value if hasattr (test_type , "value" ) else test_type
93- self .cur .execute (
94- "INSERT INTO discovered_tests VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)" ,
95- (
96- file_path ,
97- file_hash ,
98- qualified_name_with_modules_from_root ,
99- function_name ,
100- test_class ,
101- test_function ,
102- test_type_value ,
103- line_number ,
104- col_number ,
105- ),
106- )
107- self .connection .commit ()
110+ try :
111+ self .cur .execute (
112+ "INSERT OR IGNORE INTO discovered_tests VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)" ,
113+ (
114+ file_path ,
115+ file_hash ,
116+ qualified_name_with_modules_from_root ,
117+ function_name ,
118+ test_class ,
119+ test_function ,
120+ test_type_value ,
121+ line_number ,
122+ col_number ,
123+ ),
124+ )
125+ self .connection .commit ()
126+ except sqlite3 .OperationalError as e :
127+ logger .info (f"Database insert warning (likely concurrent access): { e } " )
128+ except sqlite3 .DatabaseError as e :
129+ logger .info (f"Database error during insert: { e } " )
108130
109131 def get_tests_for_file (self , file_path : str , file_hash : str ) -> list [FunctionCalledInTest ]:
110132 cache_key = (file_path , file_hash )
111133 if cache_key in self ._memory_cache :
112134 return self ._memory_cache [cache_key ]
113- self .cur .execute ("SELECT * FROM discovered_tests WHERE file_path = ? AND file_hash = ?" , (file_path , file_hash ))
114- result = [
115- FunctionCalledInTest (
116- tests_in_file = TestsInFile (
117- test_file = Path (row [0 ]), test_class = row [4 ], test_function = row [5 ], test_type = TestType (int (row [6 ]))
118- ),
119- position = CodePosition (line_no = row [7 ], col_no = row [8 ]),
135+
136+ try :
137+ self .cur .execute (
138+ "SELECT * FROM discovered_tests WHERE file_path = ? AND file_hash = ?" , (file_path , file_hash )
120139 )
121- for row in self .cur .fetchall ()
122- ]
123- self ._memory_cache [cache_key ] = result
124- return result
140+ result = [
141+ FunctionCalledInTest (
142+ tests_in_file = TestsInFile (
143+ test_file = Path (row [0 ]), test_class = row [4 ], test_function = row [5 ], test_type = TestType (int (row [6 ]))
144+ ),
145+ position = CodePosition (line_no = row [7 ], col_no = row [8 ]),
146+ )
147+ for row in self .cur .fetchall ()
148+ ]
149+ self ._memory_cache [cache_key ] = result
150+ return result # noqa: TRY300
151+ except sqlite3 .OperationalError as e :
152+ logger .info (f"Database query warning (likely concurrent access): { e } " )
153+ return []
154+ except sqlite3 .DatabaseError as e :
155+ logger .info (f"Database error during query: { e } " )
156+ return []
125157
126158 @staticmethod
127159 def compute_file_hash (path : str ) -> str :
@@ -135,8 +167,13 @@ def compute_file_hash(path: str) -> str:
135167 return h .hexdigest ()
136168
137169 def close (self ) -> None :
138- self .cur .close ()
139- self .connection .close ()
170+ try :
171+ if self .cur :
172+ self .cur .close ()
173+ if self .connection :
174+ self .connection .close ()
175+ except sqlite3 .Error as e :
176+ logger .info (f"Database close warning: { e } " )
140177
141178
142179def discover_unit_tests (
@@ -475,15 +512,27 @@ def process_test_files(
475512
476513 # Process cached files first
477514 for test_file , (_functions , cached_tests , file_hash ) in cached_files .items ():
478- cur = tests_cache .cur
479- cur .execute (
480- "SELECT qualified_name_with_modules_from_root FROM discovered_tests WHERE file_path = ? AND file_hash = ?" ,
481- (str (test_file ), file_hash ),
482- )
483- qualified_names = [row [0 ] for row in cur .fetchall ()]
484- for cached_test , qualified_name in zip (cached_tests , qualified_names ):
485- function_to_test_map [qualified_name ].add (cached_test )
486- total_count += len (cached_tests )
515+ try :
516+ cur = tests_cache .cur
517+ cur .execute (
518+ "SELECT qualified_name_with_modules_from_root FROM discovered_tests WHERE file_path = ? AND file_hash = ?" ,
519+ (str (test_file ), file_hash ),
520+ )
521+ qualified_names = [row [0 ] for row in cur .fetchall ()]
522+
523+ if len (qualified_names ) == len (cached_tests ):
524+ for cached_test , qualified_name in zip (cached_tests , qualified_names ):
525+ function_to_test_map [qualified_name ].add (cached_test )
526+ total_count += len (cached_tests )
527+ else :
528+ logger .info (
529+ f"Cache mismatch for { test_file } : expected { len (cached_tests )} names, got { len (qualified_names )} "
530+ )
531+ uncached_files [test_file ] = _functions
532+ except sqlite3 .Error as e :
533+ logger .info (f"Database error accessing cached data for { test_file } : { e } " )
534+ uncached_files [test_file ] = _functions
535+
487536 progress .advance (task_id )
488537
489538 if len (uncached_files ) == 1 or max_workers == 1 :
0 commit comments