Skip to content

Commit d3581dd

Browse files
committed
Update discover_unit_tests.py
1 parent 777939d commit d3581dd

File tree

1 file changed

+107
-58
lines changed

1 file changed

+107
-58
lines changed

codeflash/discovery/discover_unit_tests.py

Lines changed: 107 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -50,31 +50,48 @@ class TestFunction:
5050

5151
class TestsCache:
5252
def __init__(self) -> None:
53-
self.connection = sqlite3.connect(codeflash_cache_db)
53+
# Enable WAL mode and proper concurrent access
54+
self.connection = sqlite3.connect(
55+
codeflash_cache_db,
56+
timeout=30.0, # 30 second timeout for database locks
57+
check_same_thread=False, # Allow use from multiple threads
58+
)
59+
# Enable WAL mode for better concurrent access
60+
self.connection.execute("PRAGMA journal_mode=WAL")
61+
self.connection.execute("PRAGMA synchronous=NORMAL")
62+
self.connection.execute("PRAGMA cache_size=10000")
63+
self.connection.execute("PRAGMA temp_store=MEMORY")
64+
5465
logger.debug(f"Connected to tests cache database at {codeflash_cache_db}")
5566
self.cur = self.connection.cursor()
5667

57-
self.cur.execute(
58-
"""
59-
CREATE TABLE IF NOT EXISTS discovered_tests(
60-
file_path TEXT,
61-
file_hash TEXT,
62-
qualified_name_with_modules_from_root TEXT,
63-
function_name TEXT,
64-
test_class TEXT,
65-
test_function TEXT,
66-
test_type TEXT,
67-
line_number INTEGER,
68-
col_number INTEGER
68+
try:
69+
self.cur.execute(
70+
"""
71+
CREATE TABLE IF NOT EXISTS discovered_tests(
72+
file_path TEXT,
73+
file_hash TEXT,
74+
qualified_name_with_modules_from_root TEXT,
75+
function_name TEXT,
76+
test_class TEXT,
77+
test_function TEXT,
78+
test_type TEXT,
79+
line_number INTEGER,
80+
col_number INTEGER,
81+
UNIQUE(file_path, file_hash, qualified_name_with_modules_from_root, test_function, test_class, line_number, col_number)
82+
)
83+
"""
6984
)
70-
"""
71-
)
72-
self.cur.execute(
73-
"""
74-
CREATE INDEX IF NOT EXISTS idx_discovered_tests_file_path_hash
75-
ON discovered_tests (file_path, file_hash)
76-
"""
77-
)
85+
self.cur.execute(
86+
"""
87+
CREATE INDEX IF NOT EXISTS idx_discovered_tests_file_path_hash
88+
ON discovered_tests (file_path, file_hash)
89+
"""
90+
)
91+
self.connection.commit()
92+
except sqlite3.OperationalError as e:
93+
logger.info(f"Database initialization warning (likely concurrent access): {e}")
94+
7895
self._memory_cache = {}
7996

8097
def insert_test(
@@ -90,38 +107,53 @@ def insert_test(
90107
col_number: int,
91108
) -> None:
92109
test_type_value = test_type.value if hasattr(test_type, "value") else test_type
93-
self.cur.execute(
94-
"INSERT INTO discovered_tests VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
95-
(
96-
file_path,
97-
file_hash,
98-
qualified_name_with_modules_from_root,
99-
function_name,
100-
test_class,
101-
test_function,
102-
test_type_value,
103-
line_number,
104-
col_number,
105-
),
106-
)
107-
self.connection.commit()
110+
try:
111+
self.cur.execute(
112+
"INSERT OR IGNORE INTO discovered_tests VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
113+
(
114+
file_path,
115+
file_hash,
116+
qualified_name_with_modules_from_root,
117+
function_name,
118+
test_class,
119+
test_function,
120+
test_type_value,
121+
line_number,
122+
col_number,
123+
),
124+
)
125+
self.connection.commit()
126+
except sqlite3.OperationalError as e:
127+
logger.info(f"Database insert warning (likely concurrent access): {e}")
128+
except sqlite3.DatabaseError as e:
129+
logger.info(f"Database error during insert: {e}")
108130

109131
def get_tests_for_file(self, file_path: str, file_hash: str) -> list[FunctionCalledInTest]:
110132
cache_key = (file_path, file_hash)
111133
if cache_key in self._memory_cache:
112134
return self._memory_cache[cache_key]
113-
self.cur.execute("SELECT * FROM discovered_tests WHERE file_path = ? AND file_hash = ?", (file_path, file_hash))
114-
result = [
115-
FunctionCalledInTest(
116-
tests_in_file=TestsInFile(
117-
test_file=Path(row[0]), test_class=row[4], test_function=row[5], test_type=TestType(int(row[6]))
118-
),
119-
position=CodePosition(line_no=row[7], col_no=row[8]),
135+
136+
try:
137+
self.cur.execute(
138+
"SELECT * FROM discovered_tests WHERE file_path = ? AND file_hash = ?", (file_path, file_hash)
120139
)
121-
for row in self.cur.fetchall()
122-
]
123-
self._memory_cache[cache_key] = result
124-
return result
140+
result = [
141+
FunctionCalledInTest(
142+
tests_in_file=TestsInFile(
143+
test_file=Path(row[0]), test_class=row[4], test_function=row[5], test_type=TestType(int(row[6]))
144+
),
145+
position=CodePosition(line_no=row[7], col_no=row[8]),
146+
)
147+
for row in self.cur.fetchall()
148+
]
149+
self._memory_cache[cache_key] = result
150+
return result # noqa: TRY300
151+
except sqlite3.OperationalError as e:
152+
logger.info(f"Database query warning (likely concurrent access): {e}")
153+
return []
154+
except sqlite3.DatabaseError as e:
155+
logger.info(f"Database error during query: {e}")
156+
return []
125157

126158
@staticmethod
127159
def compute_file_hash(path: str) -> str:
@@ -135,8 +167,13 @@ def compute_file_hash(path: str) -> str:
135167
return h.hexdigest()
136168

137169
def close(self) -> None:
138-
self.cur.close()
139-
self.connection.close()
170+
try:
171+
if self.cur:
172+
self.cur.close()
173+
if self.connection:
174+
self.connection.close()
175+
except sqlite3.Error as e:
176+
logger.info(f"Database close warning: {e}")
140177

141178

142179
def discover_unit_tests(
@@ -475,15 +512,27 @@ def process_test_files(
475512

476513
# Process cached files first
477514
for test_file, (_functions, cached_tests, file_hash) in cached_files.items():
478-
cur = tests_cache.cur
479-
cur.execute(
480-
"SELECT qualified_name_with_modules_from_root FROM discovered_tests WHERE file_path = ? AND file_hash = ?",
481-
(str(test_file), file_hash),
482-
)
483-
qualified_names = [row[0] for row in cur.fetchall()]
484-
for cached_test, qualified_name in zip(cached_tests, qualified_names):
485-
function_to_test_map[qualified_name].add(cached_test)
486-
total_count += len(cached_tests)
515+
try:
516+
cur = tests_cache.cur
517+
cur.execute(
518+
"SELECT qualified_name_with_modules_from_root FROM discovered_tests WHERE file_path = ? AND file_hash = ?",
519+
(str(test_file), file_hash),
520+
)
521+
qualified_names = [row[0] for row in cur.fetchall()]
522+
523+
if len(qualified_names) == len(cached_tests):
524+
for cached_test, qualified_name in zip(cached_tests, qualified_names):
525+
function_to_test_map[qualified_name].add(cached_test)
526+
total_count += len(cached_tests)
527+
else:
528+
logger.info(
529+
f"Cache mismatch for {test_file}: expected {len(cached_tests)} names, got {len(qualified_names)}"
530+
)
531+
uncached_files[test_file] = _functions
532+
except sqlite3.Error as e:
533+
logger.info(f"Database error accessing cached data for {test_file}: {e}")
534+
uncached_files[test_file] = _functions
535+
487536
progress.advance(task_id)
488537

489538
if len(uncached_files) == 1 or max_workers == 1:

0 commit comments

Comments
 (0)