Skip to content

Commit 88c671a

Browse files
authored
Merge pull request #184 from codeflash-ai/cache_discovered_tests
Cache discovered tests
2 parents 8774060 + dd4f3be commit 88c671a

File tree

6 files changed

+188
-100
lines changed

6 files changed

+188
-100
lines changed

codeflash/code_utils/code_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import ast
44
import os
5+
import shutil
56
import site
67
from functools import lru_cache
78
from pathlib import Path
@@ -118,4 +119,8 @@ def has_any_async_functions(code: str) -> bool:
118119

119120
def cleanup_paths(paths: list[Path]) -> None:
    """Remove every given file, symlink or directory tree, ignoring missing ones.

    Args:
        paths: Paths to delete. Falsy entries (e.g. ``None``) are skipped.

    Real directories are removed recursively with errors ignored. Everything
    else is unlinked, and a path that does not exist is silently ignored.
    """
    for path in paths:
        if not path:
            continue
        # Symlinks must be unlinked, never passed to rmtree: rmtree refuses
        # symlinks (and ignore_errors=True would hide that), and exists()/is_dir()
        # follow the link, so a dangling link would otherwise be left behind.
        if path.is_dir() and not path.is_symlink():
            shutil.rmtree(path, ignore_errors=True)
        else:
            path.unlink(missing_ok=True)

codeflash/code_utils/compat.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import sys
33
from pathlib import Path
44

5+
from platformdirs import user_config_dir
6+
57
# os-independent newline
68
# important for any user-facing output or files we write
79
# make sure to use this in f-strings e.g. f"some string{LF}"
@@ -12,3 +14,8 @@
1214
SAFE_SYS_EXECUTABLE: str = Path(sys.executable).as_posix()
1315

1416
IS_POSIX = os.name != "nt"
17+
18+
19+
# Per-user codeflash directory, resolved through platformdirs' user_config_dir.
# NOTE(review): the name says "cache" but this is the *config* location; ensure_exists=True
# presumably creates the directory when this module is imported — confirm against platformdirs docs.
codeflash_cache_dir = Path(user_config_dir(appname="codeflash", appauthor="codeflash-ai", ensure_exists=True))
20+
21+
# Location of the cache database file inside codeflash_cache_dir.
codeflash_cache_db = codeflash_cache_dir / "codeflash_cache.db"

codeflash/discovery/discover_unit_tests.py

Lines changed: 150 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from __future__ import annotations
22

3+
import hashlib
34
import os
45
import pickle
56
import re
7+
import sqlite3
68
import subprocess
79
import unittest
810
from collections import defaultdict
@@ -15,7 +17,7 @@
1517

1618
from codeflash.cli_cmds.console import console, logger, test_files_progress_bar
1719
from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
18-
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
20+
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE, codeflash_cache_db
1921
from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile, TestType
2022

2123
if TYPE_CHECKING:
@@ -37,13 +39,101 @@ class TestFunction:
3739
FUNCTION_NAME_REGEX = re.compile(r"([^.]+)\.([a-zA-Z0-9_]+)$")
3840

3941

42+
class TestsCache:
    """SQLite-backed cache of discovered tests, keyed by test file path and content hash.

    Each row of ``discovered_tests`` records one call site: the test (file,
    class, function, type) and the position at which it references a project
    function identified by ``qualified_name_with_modules_from_root``.

    Instances can be used as context managers; leaving the ``with`` block
    closes the cursor and the connection.
    """

    def __init__(self) -> None:
        """Open the cache database and create the table and index if they are missing."""
        self.connection = sqlite3.connect(codeflash_cache_db)
        self.cur = self.connection.cursor()

        self.cur.execute(
            """
            CREATE TABLE IF NOT EXISTS discovered_tests(
                file_path TEXT,
                file_hash TEXT,
                qualified_name_with_modules_from_root TEXT,
                function_name TEXT,
                test_class TEXT,
                test_function TEXT,
                test_type TEXT,
                line_number INTEGER,
                col_number INTEGER
            )
            """
        )
        self.cur.execute(
            """
            CREATE INDEX IF NOT EXISTS idx_discovered_tests_file_path_hash
            ON discovered_tests (file_path, file_hash)
            """
        )
        # (file_path, file_hash) -> rows already materialised by get_tests_for_file.
        self._memory_cache = {}

    def __enter__(self) -> TestsCache:
        return self

    def __exit__(self, *exc_info: object) -> None:
        self.close()

    def insert_test(
        self,
        file_path: str,
        file_hash: str,
        qualified_name_with_modules_from_root: str,
        function_name: str,
        test_class: str,
        test_function: str,
        test_type: TestType,
        line_number: int,
        col_number: int,
    ) -> None:
        """Store one discovered test reference and commit it.

        Rows recorded for ``file_path`` under a *different* hash are stale and
        are removed first. Rows with the current hash are kept, so several
        tests from the same file can be cached. (Deleting every row for the
        file on each insert would leave only the last test in the cache.)

        ``test_type`` may be an enum member (its ``.value`` is stored) or a raw value.
        """
        self.cur.execute(
            "DELETE FROM discovered_tests WHERE file_path = ? AND file_hash != ?", (file_path, file_hash)
        )
        test_type_value = test_type.value if hasattr(test_type, "value") else test_type
        self.cur.execute(
            "INSERT INTO discovered_tests VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
            (
                file_path,
                file_hash,
                qualified_name_with_modules_from_root,
                function_name,
                test_class,
                test_function,
                test_type_value,
                line_number,
                col_number,
            ),
        )
        self.connection.commit()
        # A lookup made before this insert may have memoised an outdated (possibly empty) result.
        self._memory_cache.pop((file_path, file_hash), None)

    def get_tests_for_file(self, file_path: str, file_hash: str) -> list[FunctionCalledInTest]:
        """Return the cached test references for ``file_path`` at ``file_hash``.

        The result is an empty list when nothing is cached for that exact
        path/hash pair. Results are memoised per instance and returned in
        insertion (rowid) order.
        """
        cache_key = (file_path, file_hash)
        if cache_key in self._memory_cache:
            return self._memory_cache[cache_key]
        # Columns are named explicitly so the mapping below does not depend on table column order.
        self.cur.execute(
            "SELECT file_path, test_class, test_function, test_type, line_number, col_number "
            "FROM discovered_tests WHERE file_path = ? AND file_hash = ? ORDER BY rowid",
            (file_path, file_hash),
        )
        result = [
            FunctionCalledInTest(
                tests_in_file=TestsInFile(
                    test_file=Path(row[0]), test_class=row[1], test_function=row[2], test_type=TestType(int(row[3]))
                ),
                position=CodePosition(line_no=row[4], col_no=row[5]),
            )
            for row in self.cur.fetchall()
        ]
        self._memory_cache[cache_key] = result
        return result

    @staticmethod
    def compute_file_hash(path: str | Path) -> str:
        """Return the hex SHA-256 digest of the file at ``path``, read in 8 KiB chunks."""
        # The digest only detects content changes, hence usedforsecurity=False.
        h = hashlib.sha256(usedforsecurity=False)
        with Path(path).open("rb") as f:
            for chunk in iter(lambda: f.read(8192), b""):
                h.update(chunk)
        return h.hexdigest()

    def close(self) -> None:
        """Close the cursor and the database connection."""
        self.cur.close()
        self.connection.close()
131+
132+
40133
def discover_unit_tests(
41134
cfg: TestConfig, discover_only_these_tests: list[Path] | None = None
42135
) -> dict[str, list[FunctionCalledInTest]]:
43-
framework_strategies: dict[str, Callable] = {
44-
"pytest": discover_tests_pytest,
45-
"unittest": discover_tests_unittest,
46-
}
136+
framework_strategies: dict[str, Callable] = {"pytest": discover_tests_pytest, "unittest": discover_tests_unittest}
47137
strategy = framework_strategies.get(cfg.test_framework, None)
48138
if not strategy:
49139
error_message = f"Unsupported test framework: {cfg.test_framework}"
@@ -54,7 +144,7 @@ def discover_unit_tests(
54144

55145
def discover_tests_pytest(
56146
cfg: TestConfig, discover_only_these_tests: list[Path] | None = None
57-
) -> dict[str, list[FunctionCalledInTest]]:
147+
) -> dict[Path, list[FunctionCalledInTest]]:
58148
tests_root = cfg.tests_root
59149
project_root = cfg.project_root_path
60150

@@ -91,17 +181,15 @@ def discover_tests_pytest(
91181
)
92182

93183
elif 0 <= exitcode <= 5:
94-
logger.warning(
95-
f"Failed to collect tests. Pytest Exit code: {exitcode}={ExitCode(exitcode).name}"
96-
)
184+
logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}={ExitCode(exitcode).name}")
97185
else:
98186
logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}")
99187
console.rule()
100188
else:
101189
logger.debug(f"Pytest collection exit code: {exitcode}")
102190
if pytest_rootdir is not None:
103191
cfg.tests_project_rootdir = Path(pytest_rootdir)
104-
file_to_test_map = defaultdict(list)
192+
file_to_test_map: dict[Path, list[FunctionCalledInTest]] = defaultdict(list)
105193
for test in tests:
106194
if "__replay_test" in test["test_file"]:
107195
test_type = TestType.REPLAY_TEST
@@ -116,10 +204,7 @@ def discover_tests_pytest(
116204
test_function=test["test_function"],
117205
test_type=test_type,
118206
)
119-
if (
120-
discover_only_these_tests
121-
and test_obj.test_file not in discover_only_these_tests
122-
):
207+
if discover_only_these_tests and test_obj.test_file not in discover_only_these_tests:
123208
continue
124209
file_to_test_map[test_obj.test_file].append(test_obj)
125210
# Within these test files, find the project functions they are referring to and return their names/locations
@@ -128,7 +213,7 @@ def discover_tests_pytest(
128213

129214
def discover_tests_unittest(
130215
cfg: TestConfig, discover_only_these_tests: list[str] | None = None
131-
) -> dict[str, list[FunctionCalledInTest]]:
216+
) -> dict[Path, list[FunctionCalledInTest]]:
132217
tests_root: Path = cfg.tests_root
133218
loader: unittest.TestLoader = unittest.TestLoader()
134219
tests: unittest.TestSuite = loader.discover(str(tests_root))
@@ -144,8 +229,7 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None:
144229
_test_module_path = Path(_test_module.replace(".", os.sep)).with_suffix(".py")
145230
_test_module_path = tests_root / _test_module_path
146231
if not _test_module_path.exists() or (
147-
discover_only_these_tests
148-
and str(_test_module_path) not in discover_only_these_tests
232+
discover_only_these_tests and str(_test_module_path) not in discover_only_these_tests
149233
):
150234
return None
151235
if "__replay_test" in str(_test_module_path):
@@ -172,9 +256,7 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None:
172256
if not hasattr(test, "_testMethodName") and hasattr(test, "_tests"):
173257
for test_2 in test._tests:
174258
if not hasattr(test_2, "_testMethodName"):
175-
logger.warning(
176-
f"Didn't find tests for {test_2}"
177-
) # it goes deeper?
259+
logger.warning(f"Didn't find tests for {test_2}") # it goes deeper?
178260
continue
179261
details = get_test_details(test_2)
180262
if details is not None:
@@ -195,19 +277,35 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | N
195277

196278

197279
def process_test_files(
198-
file_to_test_map: dict[str, list[TestsInFile]], cfg: TestConfig
280+
file_to_test_map: dict[Path, list[TestsInFile]], cfg: TestConfig
199281
) -> dict[str, list[FunctionCalledInTest]]:
200282
project_root_path = cfg.project_root_path
201283
test_framework = cfg.test_framework
284+
202285
function_to_test_map = defaultdict(set)
203286
jedi_project = jedi.Project(path=project_root_path)
204287
goto_cache = {}
288+
tests_cache = TestsCache()
205289

206-
with test_files_progress_bar(
207-
total=len(file_to_test_map), description="Processing test files"
208-
) as (progress, task_id):
209-
290+
with test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as (
291+
progress,
292+
task_id,
293+
):
210294
for test_file, functions in file_to_test_map.items():
295+
file_hash = TestsCache.compute_file_hash(test_file)
296+
cached_tests = tests_cache.get_tests_for_file(str(test_file), file_hash)
297+
if cached_tests:
298+
self_cur = tests_cache.cur
299+
self_cur.execute(
300+
"SELECT qualified_name_with_modules_from_root FROM discovered_tests WHERE file_path = ? AND file_hash = ?",
301+
(str(test_file), file_hash),
302+
)
303+
qualified_names = [row[0] for row in self_cur.fetchall()]
304+
for cached, qualified_name in zip(cached_tests, qualified_names):
305+
function_to_test_map[qualified_name].add(cached)
306+
progress.advance(task_id)
307+
continue
308+
211309
try:
212310
script = jedi.Script(path=test_file, project=jedi_project)
213311
test_functions = set()
@@ -216,12 +314,8 @@ def process_test_files(
216314
all_defs = script.get_names(all_scopes=True, definitions=True)
217315
all_names_top = script.get_names(all_scopes=True)
218316

219-
top_level_functions = {
220-
name.name: name for name in all_names_top if name.type == "function"
221-
}
222-
top_level_classes = {
223-
name.name: name for name in all_names_top if name.type == "class"
224-
}
317+
top_level_functions = {name.name: name for name in all_names_top if name.type == "function"}
318+
top_level_classes = {name.name: name for name in all_names_top if name.type == "class"}
225319
except Exception as e:
226320
logger.debug(f"Failed to get jedi script for {test_file}: {e}")
227321
progress.advance(task_id)
@@ -230,36 +324,18 @@ def process_test_files(
230324
if test_framework == "pytest":
231325
for function in functions:
232326
if "[" in function.test_function:
233-
function_name = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(
234-
function.test_function
235-
)[0]
236-
parameters = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(
237-
function.test_function
238-
)[1]
327+
function_name = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[0]
328+
parameters = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[1]
239329
if function_name in top_level_functions:
240330
test_functions.add(
241-
TestFunction(
242-
function_name,
243-
function.test_class,
244-
parameters,
245-
function.test_type,
246-
)
331+
TestFunction(function_name, function.test_class, parameters, function.test_type)
247332
)
248333
elif function.test_function in top_level_functions:
249334
test_functions.add(
250-
TestFunction(
251-
function.test_function,
252-
function.test_class,
253-
None,
254-
function.test_type,
255-
)
256-
)
257-
elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX.match(
258-
function.test_function
259-
):
260-
base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX.sub(
261-
"", function.test_function
335+
TestFunction(function.test_function, function.test_class, None, function.test_type)
262336
)
337+
elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX.match(function.test_function):
338+
base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX.sub("", function.test_function)
263339
if base_name in top_level_functions:
264340
test_functions.add(
265341
TestFunction(
@@ -283,9 +359,7 @@ def process_test_files(
283359
and f".{matched_name}." in def_name.full_name
284360
):
285361
for function in functions_to_search:
286-
(is_parameterized, new_function, parameters) = (
287-
discover_parameters_unittest(function)
288-
)
362+
(is_parameterized, new_function, parameters) = discover_parameters_unittest(function)
289363

290364
if is_parameterized and new_function == def_name.name:
291365
test_functions.add(
@@ -329,9 +403,7 @@ def process_test_files(
329403
if cache_key in goto_cache:
330404
definition = goto_cache[cache_key]
331405
else:
332-
definition = name.goto(
333-
follow_imports=True, follow_builtin_imports=False
334-
)
406+
definition = name.goto(follow_imports=True, follow_builtin_imports=False)
335407
goto_cache[cache_key] = definition
336408
except Exception as e:
337409
logger.debug(str(e))
@@ -358,11 +430,23 @@ def process_test_files(
358430
if test_framework == "unittest":
359431
scope_test_function += "_" + scope_parameters
360432

361-
full_name_without_module_prefix = definition[
362-
0
363-
].full_name.replace(definition[0].module_name + ".", "", 1)
433+
full_name_without_module_prefix = definition[0].full_name.replace(
434+
definition[0].module_name + ".", "", 1
435+
)
364436
qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}"
365437

438+
tests_cache.insert_test(
439+
file_path=str(test_file),
440+
file_hash=file_hash,
441+
qualified_name_with_modules_from_root=qualified_name_with_modules_from_root,
442+
function_name=scope,
443+
test_class=scope_test_class,
444+
test_function=scope_test_function,
445+
test_type=test_type,
446+
line_number=name.line,
447+
col_number=name.column,
448+
)
449+
366450
function_to_test_map[qualified_name_with_modules_from_root].add(
367451
FunctionCalledInTest(
368452
tests_in_file=TestsInFile(
@@ -371,12 +455,11 @@ def process_test_files(
371455
test_function=scope_test_function,
372456
test_type=test_type,
373457
),
374-
position=CodePosition(
375-
line_no=name.line, col_no=name.column
376-
),
458+
position=CodePosition(line_no=name.line, col_no=name.column),
377459
)
378460
)
379461

380462
progress.advance(task_id)
381463

464+
tests_cache.close()
382465
return {function: list(tests) for function, tests in function_to_test_map.items()}

0 commit comments

Comments
 (0)