Skip to content

Commit 948bfed

Browse files
authored
Merge pull request #1852 from codeflash-ai/cf-1846-port-perf-improvements
perf: cache jedi project, batch test cache writes, fix Windows relative_to bug
2 parents 506eb44 + 7cf183e commit 948bfed

File tree

4 files changed

+217
-114
lines changed

4 files changed

+217
-114
lines changed

codeflash/discovery/discover_unit_tests.py

Lines changed: 71 additions & 56 deletions
Original file line number | Diff line number | Diff line change
@@ -136,6 +136,8 @@ def __init__(self, project_root_path: Path) -> None:
136136
)
137137

138138
self.memory_cache = {}
139+
self.pending_rows: list[tuple[str, str, str, str, str, str, int | TestType, int, int]] = []
140+
self.writes_enabled = True
139141

140142
def insert_test(
141143
self,
@@ -150,10 +152,8 @@ def insert_test(
150152
col_number: int,
151153
) -> None:
152154
test_type_value = test_type.value if hasattr(test_type, "value") else test_type
153-
self.cur.execute(
154-
"INSERT INTO discovered_tests VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
155+
self.pending_rows.append(
155156
(
156-
self.project_root_path,
157157
file_path,
158158
file_hash,
159159
qualified_name_with_modules_from_root,
@@ -163,9 +163,26 @@ def insert_test(
163163
test_type_value,
164164
line_number,
165165
col_number,
166-
),
166+
)
167167
)
168-
self.connection.commit()
168+
169+
def flush(self) -> None:
    """Write every buffered row to the ``discovered_tests`` table in one batch.

    Rows queued in ``self.pending_rows`` are prefixed with
    ``self.project_root_path`` and inserted with a single ``executemany``
    followed by one commit.

    If the insert or the commit raises ``sqlite3.OperationalError``, the
    failure is logged at debug level, ``self.writes_enabled`` is set to
    ``False`` and the open transaction is rolled back so that a partially
    applied batch cannot be committed later through the same connection.
    Other exception types propagate to the caller.

    The buffer is always emptied before returning, whether or not the rows
    were persisted.
    """
    if not self.pending_rows:
        return
    if not self.writes_enabled:
        # Writes were switched off (e.g. by an earlier failed flush): drop the
        # buffered rows instead of retrying against the database.
        self.pending_rows.clear()
        return
    try:
        self.cur.executemany(
            "INSERT INTO discovered_tests VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
            [(self.project_root_path, *row) for row in self.pending_rows],
        )
        self.connection.commit()
    except sqlite3.OperationalError as e:
        logger.debug(f"Failed to persist discovered test cache, disabling cache writes: {e}")
        self.writes_enabled = False
        # executemany may have inserted some rows before failing; those rows
        # stay in the implicit transaction until it is rolled back.
        try:
            self.connection.rollback()
        except sqlite3.Error as rollback_error:
            # Best effort only: the cache is optional, so never let cleanup fail the caller.
            logger.debug(f"Failed to roll back discovered test cache transaction: {rollback_error}")
    finally:
        self.pending_rows.clear()
169186

170187
def get_function_to_test_map_for_file(
171188
self, file_path: str, file_hash: str
@@ -212,6 +229,7 @@ def compute_file_hash(path: Path) -> str:
212229
return h.hexdigest()
213230

214231
def close(self) -> None:
    """Flush any buffered rows, then close the cursor and the connection.

    The cursor and connection are closed even when ``flush()`` (or closing
    the cursor) raises, so a failed final write cannot leak the database
    handle. The original exception still propagates to the caller.
    """
    try:
        self.flush()
    finally:
        try:
            self.cur.close()
        finally:
            self.connection.close()
217235

@@ -849,6 +867,10 @@ def process_test_files(
849867
function_to_test_map = defaultdict(set)
850868
num_discovered_tests = 0
851869
num_discovered_replay_tests = 0
870+
functions_to_optimize_by_name: dict[str, list[FunctionToOptimize]] = defaultdict(list)
871+
if functions_to_optimize:
872+
for function_to_optimize in functions_to_optimize:
873+
functions_to_optimize_by_name[function_to_optimize.function_name].append(function_to_optimize)
852874

853875
# Set up sys_path for Jedi to resolve imports correctly
854876
import sys
@@ -891,8 +913,8 @@ def process_test_files(
891913
test_functions = set()
892914

893915
all_names = script.get_names(all_scopes=True, references=True)
894-
all_defs = script.get_names(all_scopes=True, definitions=True)
895916
all_names_top = script.get_names(all_scopes=True)
917+
all_defs = [name for name in all_names if name.is_definition()]
896918

897919
top_level_functions = {name.name: name for name in all_names_top if name.type == "function"}
898920
top_level_classes = {name.name: name for name in all_names_top if name.type == "class"}
@@ -967,10 +989,9 @@ def process_test_files(
967989

968990
test_function_names_set = set(test_functions_by_name.keys())
969991
relevant_names = []
970-
971-
names_with_full_name = [name for name in all_names if name.full_name is not None]
972-
973-
for name in names_with_full_name:
992+
for name in all_names:
993+
if name.full_name is None:
994+
continue
974995
match = FUNCTION_NAME_REGEX.search(name.full_name)
975996
if match and match.group(1) in test_function_names_set:
976997
relevant_names.append((name, match.group(1)))
@@ -985,56 +1006,49 @@ def process_test_files(
9851006
if not definition or definition[0].type != "function":
9861007
# Fallback: Try to match against functions_to_optimize when Jedi can't resolve
9871008
# This handles cases where Jedi fails with pytest fixtures
988-
if functions_to_optimize and name.name:
989-
for func_to_opt in functions_to_optimize:
990-
# Check if this unresolved name matches a function we're looking for
991-
if func_to_opt.function_name == name.name:
992-
# Check if the test file imports the class/module containing this function
993-
qualified_name_with_modules = func_to_opt.qualified_name_with_modules_from_root(
994-
project_root_path
995-
)
1009+
if functions_to_optimize_by_name and name.name:
1010+
for func_to_opt in functions_to_optimize_by_name.get(name.name, []):
1011+
qualified_name_with_modules = func_to_opt.qualified_name_with_modules_from_root(
1012+
project_root_path
1013+
)
9961014

997-
# Only add if this test actually tests the function we're optimizing
998-
for test_func in test_functions_by_name[scope]:
999-
if test_func.parameters is not None:
1000-
if test_framework == "pytest":
1001-
scope_test_function = (
1002-
f"{test_func.function_name}[{test_func.parameters}]"
1003-
)
1004-
else: # unittest
1005-
scope_test_function = (
1006-
f"{test_func.function_name}_{test_func.parameters}"
1007-
)
1008-
else:
1009-
scope_test_function = test_func.function_name
1010-
1011-
function_to_test_map[qualified_name_with_modules].add(
1012-
FunctionCalledInTest(
1013-
tests_in_file=TestsInFile(
1014-
test_file=test_file,
1015-
test_class=test_func.test_class,
1016-
test_function=scope_test_function,
1017-
test_type=test_func.test_type,
1018-
),
1019-
position=CodePosition(line_no=name.line, col_no=name.column),
1020-
)
1021-
)
1022-
tests_cache.insert_test(
1023-
file_path=str(test_file),
1024-
file_hash=file_hash,
1025-
qualified_name_with_modules_from_root=qualified_name_with_modules,
1026-
function_name=scope,
1027-
test_class=test_func.test_class or "",
1028-
test_function=scope_test_function,
1029-
test_type=test_func.test_type,
1030-
line_number=name.line,
1031-
col_number=name.column,
1015+
# Only add if this test actually tests the function we're optimizing
1016+
for test_func in test_functions_by_name[scope]:
1017+
if test_func.parameters is not None:
1018+
if test_framework == "pytest":
1019+
scope_test_function = f"{test_func.function_name}[{test_func.parameters}]"
1020+
else: # unittest
1021+
scope_test_function = f"{test_func.function_name}_{test_func.parameters}"
1022+
else:
1023+
scope_test_function = test_func.function_name
1024+
1025+
function_to_test_map[qualified_name_with_modules].add(
1026+
FunctionCalledInTest(
1027+
tests_in_file=TestsInFile(
1028+
test_file=test_file,
1029+
test_class=test_func.test_class,
1030+
test_function=scope_test_function,
1031+
test_type=test_func.test_type,
1032+
),
1033+
position=CodePosition(line_no=name.line, col_no=name.column),
10321034
)
1035+
)
1036+
tests_cache.insert_test(
1037+
file_path=str(test_file),
1038+
file_hash=file_hash,
1039+
qualified_name_with_modules_from_root=qualified_name_with_modules,
1040+
function_name=scope,
1041+
test_class=test_func.test_class or "",
1042+
test_function=scope_test_function,
1043+
test_type=test_func.test_type,
1044+
line_number=name.line,
1045+
col_number=name.column,
1046+
)
10331047

1034-
if test_func.test_type == TestType.REPLAY_TEST:
1035-
num_discovered_replay_tests += 1
1048+
if test_func.test_type == TestType.REPLAY_TEST:
1049+
num_discovered_replay_tests += 1
10361050

1037-
num_discovered_tests += 1
1051+
num_discovered_tests += 1
10381052
continue
10391053
definition_obj = definition[0]
10401054
definition_path = str(definition_obj.module_path)
@@ -1090,6 +1104,7 @@ def process_test_files(
10901104
logger.debug(str(e))
10911105
continue
10921106

1107+
tests_cache.flush()
10931108
progress.advance(task_id)
10941109

10951110
tests_cache.close()

0 commit comments

Comments
 (0)