Skip to content

Commit 905d34e

Browse files
committed
implement test caching
1 parent b6bbeec commit 905d34e

File tree

2 files changed

+154
-67
lines changed

2 files changed

+154
-67
lines changed

codeflash/code_utils/compat.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import sys
33
from pathlib import Path
44

5+
from platformdirs import user_config_dir
6+
57
# os-independent newline
68
# important for any user-facing output or files we write
79
# make sure to use this in f-strings e.g. f"some string{LF}"
@@ -12,3 +14,6 @@
1214
SAFE_SYS_EXECUTABLE: str = Path(sys.executable).as_posix()
1315

1416
IS_POSIX = os.name != "nt"
17+
18+
19+
codeflash_cache_dir = Path(user_config_dir(appname="codeflash", appauthor="codeflash-ai", ensure_exists=True))

codeflash/discovery/discover_unit_tests.py

Lines changed: 149 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from __future__ import annotations
22

3+
import hashlib
34
import os
45
import pickle
56
import re
7+
import sqlite3
68
import subprocess
79
import unittest
810
from collections import defaultdict
@@ -15,7 +17,7 @@
1517

1618
from codeflash.cli_cmds.console import console, logger, test_files_progress_bar
1719
from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
18-
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
20+
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE, codeflash_cache_dir
1921
from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile, TestType
2022

2123
if TYPE_CHECKING:
@@ -37,13 +39,100 @@ class TestFunction:
3739
FUNCTION_NAME_REGEX = re.compile(r"([^.]+)\.([a-zA-Z0-9_]+)$")
3840

3941

42+
class TestsCache:
43+
def __init__(self) -> None:
44+
self.connection = sqlite3.connect(codeflash_cache_dir / "tests_cache.db")
45+
self.cur = self.connection.cursor()
46+
47+
self.cur.execute(
48+
"""
49+
CREATE TABLE IF NOT EXISTS discovered_tests(
50+
file_path TEXT,
51+
file_hash TEXT,
52+
qualified_name_with_modules_from_root TEXT,
53+
function_name TEXT,
54+
test_class TEXT,
55+
test_function TEXT,
56+
test_type TEXT,
57+
line_number INTEGER,
58+
col_number INTEGER
59+
)
60+
"""
61+
)
62+
self.cur.execute(
63+
"""
64+
CREATE INDEX IF NOT EXISTS idx_discovered_tests_file_path_hash
65+
ON discovered_tests (file_path, file_hash)
66+
"""
67+
)
68+
self._memory_cache = {}
69+
70+
def insert_test(
71+
self,
72+
file_path: str,
73+
file_hash: str,
74+
qualified_name_with_modules_from_root: str,
75+
function_name: str,
76+
test_class: str,
77+
test_function: str,
78+
test_type: TestType,
79+
line_number: int,
80+
col_number: int,
81+
) -> None:
82+
test_type_value = test_type.value if hasattr(test_type, "value") else test_type
83+
self.cur.execute(
84+
"INSERT INTO discovered_tests VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
85+
(
86+
file_path,
87+
file_hash,
88+
qualified_name_with_modules_from_root,
89+
function_name,
90+
test_class,
91+
test_function,
92+
test_type_value,
93+
line_number,
94+
col_number,
95+
),
96+
)
97+
self.connection.commit()
98+
99+
def get_tests_for_file(self, file_path: str, file_hash: str) -> list[FunctionCalledInTest]:
100+
cache_key = (file_path, file_hash)
101+
if cache_key in self._memory_cache:
102+
return self._memory_cache[cache_key]
103+
self.cur.execute("SELECT * FROM discovered_tests WHERE file_path = ? AND file_hash = ?", (file_path, file_hash))
104+
result = [
105+
FunctionCalledInTest(
106+
tests_in_file=TestsInFile(
107+
test_file=Path(row[0]), test_class=row[4], test_function=row[5], test_type=TestType(int(row[6]))
108+
),
109+
position=CodePosition(line_no=row[7], col_no=row[8]),
110+
)
111+
for row in self.cur.fetchall()
112+
]
113+
self._memory_cache[cache_key] = result
114+
return result
115+
116+
@staticmethod
117+
def compute_file_hash(path: str) -> str:
118+
h = hashlib.md5(usedforsecurity=False)
119+
with Path(path).open("rb") as f:
120+
while True:
121+
chunk = f.read(8192)
122+
if not chunk:
123+
break
124+
h.update(chunk)
125+
return h.hexdigest()
126+
127+
def close(self) -> None:
128+
self.cur.close()
129+
self.connection.close()
130+
131+
40132
def discover_unit_tests(
41133
cfg: TestConfig, discover_only_these_tests: list[Path] | None = None
42134
) -> dict[str, list[FunctionCalledInTest]]:
43-
framework_strategies: dict[str, Callable] = {
44-
"pytest": discover_tests_pytest,
45-
"unittest": discover_tests_unittest,
46-
}
135+
framework_strategies: dict[str, Callable] = {"pytest": discover_tests_pytest, "unittest": discover_tests_unittest}
47136
strategy = framework_strategies.get(cfg.test_framework, None)
48137
if not strategy:
49138
error_message = f"Unsupported test framework: {cfg.test_framework}"
@@ -54,7 +143,7 @@ def discover_unit_tests(
54143

55144
def discover_tests_pytest(
56145
cfg: TestConfig, discover_only_these_tests: list[Path] | None = None
57-
) -> dict[str, list[FunctionCalledInTest]]:
146+
) -> dict[Path, list[FunctionCalledInTest]]:
58147
tests_root = cfg.tests_root
59148
project_root = cfg.project_root_path
60149

@@ -91,17 +180,15 @@ def discover_tests_pytest(
91180
)
92181

93182
elif 0 <= exitcode <= 5:
94-
logger.warning(
95-
f"Failed to collect tests. Pytest Exit code: {exitcode}={ExitCode(exitcode).name}"
96-
)
183+
logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}={ExitCode(exitcode).name}")
97184
else:
98185
logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}")
99186
console.rule()
100187
else:
101188
logger.debug(f"Pytest collection exit code: {exitcode}")
102189
if pytest_rootdir is not None:
103190
cfg.tests_project_rootdir = Path(pytest_rootdir)
104-
file_to_test_map = defaultdict(list)
191+
file_to_test_map: dict[Path, list[FunctionCalledInTest]] = defaultdict(list)
105192
for test in tests:
106193
if "__replay_test" in test["test_file"]:
107194
test_type = TestType.REPLAY_TEST
@@ -116,10 +203,7 @@ def discover_tests_pytest(
116203
test_function=test["test_function"],
117204
test_type=test_type,
118205
)
119-
if (
120-
discover_only_these_tests
121-
and test_obj.test_file not in discover_only_these_tests
122-
):
206+
if discover_only_these_tests and test_obj.test_file not in discover_only_these_tests:
123207
continue
124208
file_to_test_map[test_obj.test_file].append(test_obj)
125209
# Within these test files, find the project functions they are referring to and return their names/locations
@@ -128,7 +212,7 @@ def discover_tests_pytest(
128212

129213
def discover_tests_unittest(
130214
cfg: TestConfig, discover_only_these_tests: list[str] | None = None
131-
) -> dict[str, list[FunctionCalledInTest]]:
215+
) -> dict[Path, list[FunctionCalledInTest]]:
132216
tests_root: Path = cfg.tests_root
133217
loader: unittest.TestLoader = unittest.TestLoader()
134218
tests: unittest.TestSuite = loader.discover(str(tests_root))
@@ -144,8 +228,7 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None:
144228
_test_module_path = Path(_test_module.replace(".", os.sep)).with_suffix(".py")
145229
_test_module_path = tests_root / _test_module_path
146230
if not _test_module_path.exists() or (
147-
discover_only_these_tests
148-
and str(_test_module_path) not in discover_only_these_tests
231+
discover_only_these_tests and str(_test_module_path) not in discover_only_these_tests
149232
):
150233
return None
151234
if "__replay_test" in str(_test_module_path):
@@ -172,9 +255,7 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None:
172255
if not hasattr(test, "_testMethodName") and hasattr(test, "_tests"):
173256
for test_2 in test._tests:
174257
if not hasattr(test_2, "_testMethodName"):
175-
logger.warning(
176-
f"Didn't find tests for {test_2}"
177-
) # it goes deeper?
258+
logger.warning(f"Didn't find tests for {test_2}") # it goes deeper?
178259
continue
179260
details = get_test_details(test_2)
180261
if details is not None:
@@ -195,19 +276,35 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | N
195276

196277

197278
def process_test_files(
198-
file_to_test_map: dict[str, list[TestsInFile]], cfg: TestConfig
279+
file_to_test_map: dict[Path, list[TestsInFile]], cfg: TestConfig
199280
) -> dict[str, list[FunctionCalledInTest]]:
200281
project_root_path = cfg.project_root_path
201282
test_framework = cfg.test_framework
283+
202284
function_to_test_map = defaultdict(set)
203285
jedi_project = jedi.Project(path=project_root_path)
204286
goto_cache = {}
287+
tests_cache = TestsCache()
205288

206-
with test_files_progress_bar(
207-
total=len(file_to_test_map), description="Processing test files"
208-
) as (progress, task_id):
209-
289+
with test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as (
290+
progress,
291+
task_id,
292+
):
210293
for test_file, functions in file_to_test_map.items():
294+
file_hash = TestsCache.compute_file_hash(test_file)
295+
cached_tests = tests_cache.get_tests_for_file(str(test_file), file_hash)
296+
if cached_tests:
297+
self_cur = tests_cache.cur
298+
self_cur.execute(
299+
"SELECT qualified_name_with_modules_from_root FROM discovered_tests WHERE file_path = ? AND file_hash = ?",
300+
(str(test_file), file_hash),
301+
)
302+
qualified_names = [row[0] for row in self_cur.fetchall()]
303+
for cached, qualified_name in zip(cached_tests, qualified_names):
304+
function_to_test_map[qualified_name].add(cached)
305+
progress.advance(task_id)
306+
continue
307+
211308
try:
212309
script = jedi.Script(path=test_file, project=jedi_project)
213310
test_functions = set()
@@ -216,12 +313,8 @@ def process_test_files(
216313
all_defs = script.get_names(all_scopes=True, definitions=True)
217314
all_names_top = script.get_names(all_scopes=True)
218315

219-
top_level_functions = {
220-
name.name: name for name in all_names_top if name.type == "function"
221-
}
222-
top_level_classes = {
223-
name.name: name for name in all_names_top if name.type == "class"
224-
}
316+
top_level_functions = {name.name: name for name in all_names_top if name.type == "function"}
317+
top_level_classes = {name.name: name for name in all_names_top if name.type == "class"}
225318
except Exception as e:
226319
logger.debug(f"Failed to get jedi script for {test_file}: {e}")
227320
progress.advance(task_id)
@@ -230,36 +323,18 @@ def process_test_files(
230323
if test_framework == "pytest":
231324
for function in functions:
232325
if "[" in function.test_function:
233-
function_name = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(
234-
function.test_function
235-
)[0]
236-
parameters = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(
237-
function.test_function
238-
)[1]
326+
function_name = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[0]
327+
parameters = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[1]
239328
if function_name in top_level_functions:
240329
test_functions.add(
241-
TestFunction(
242-
function_name,
243-
function.test_class,
244-
parameters,
245-
function.test_type,
246-
)
330+
TestFunction(function_name, function.test_class, parameters, function.test_type)
247331
)
248332
elif function.test_function in top_level_functions:
249333
test_functions.add(
250-
TestFunction(
251-
function.test_function,
252-
function.test_class,
253-
None,
254-
function.test_type,
255-
)
256-
)
257-
elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX.match(
258-
function.test_function
259-
):
260-
base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX.sub(
261-
"", function.test_function
334+
TestFunction(function.test_function, function.test_class, None, function.test_type)
262335
)
336+
elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX.match(function.test_function):
337+
base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX.sub("", function.test_function)
263338
if base_name in top_level_functions:
264339
test_functions.add(
265340
TestFunction(
@@ -283,9 +358,7 @@ def process_test_files(
283358
and f".{matched_name}." in def_name.full_name
284359
):
285360
for function in functions_to_search:
286-
(is_parameterized, new_function, parameters) = (
287-
discover_parameters_unittest(function)
288-
)
361+
(is_parameterized, new_function, parameters) = discover_parameters_unittest(function)
289362

290363
if is_parameterized and new_function == def_name.name:
291364
test_functions.add(
@@ -329,9 +402,7 @@ def process_test_files(
329402
if cache_key in goto_cache:
330403
definition = goto_cache[cache_key]
331404
else:
332-
definition = name.goto(
333-
follow_imports=True, follow_builtin_imports=False
334-
)
405+
definition = name.goto(follow_imports=True, follow_builtin_imports=False)
335406
goto_cache[cache_key] = definition
336407
except Exception as e:
337408
logger.debug(str(e))
@@ -358,11 +429,23 @@ def process_test_files(
358429
if test_framework == "unittest":
359430
scope_test_function += "_" + scope_parameters
360431

361-
full_name_without_module_prefix = definition[
362-
0
363-
].full_name.replace(definition[0].module_name + ".", "", 1)
432+
full_name_without_module_prefix = definition[0].full_name.replace(
433+
definition[0].module_name + ".", "", 1
434+
)
364435
qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}"
365436

437+
tests_cache.insert_test(
438+
file_path=str(test_file),
439+
file_hash=file_hash,
440+
qualified_name_with_modules_from_root=qualified_name_with_modules_from_root,
441+
function_name=scope,
442+
test_class=scope_test_class,
443+
test_function=scope_test_function,
444+
test_type=test_type,
445+
line_number=name.line,
446+
col_number=name.column,
447+
)
448+
366449
function_to_test_map[qualified_name_with_modules_from_root].add(
367450
FunctionCalledInTest(
368451
tests_in_file=TestsInFile(
@@ -371,12 +454,11 @@ def process_test_files(
371454
test_function=scope_test_function,
372455
test_type=test_type,
373456
),
374-
position=CodePosition(
375-
line_no=name.line, col_no=name.column
376-
),
457+
position=CodePosition(line_no=name.line, col_no=name.column),
377458
)
378459
)
379460

380461
progress.advance(task_id)
381462

463+
tests_cache.close()
382464
return {function: list(tests) for function, tests in function_to_test_map.items()}

0 commit comments

Comments
 (0)