Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
d5fa1ef
tests cache
KRRT7 Sep 23, 2025
28eee4a
make ty happy
KRRT7 Sep 23, 2025
bbc630f
prevent progress bar artifact & respond to code review
KRRT7 Sep 23, 2025
c4b8a4b
formatting
KRRT7 Sep 23, 2025
caac361
Merge branch 'main' into test_cache_revival
KRRT7 Oct 9, 2025
fd34f8d
Merge branch 'main' into test_cache_revival
KRRT7 Oct 9, 2025
180c479
add project_root as a index key
KRRT7 Oct 9, 2025
a891be4
add unit tests for caching
KRRT7 Oct 9, 2025
924de86
Merge branch 'test_cache_revival' of https://github.com/codeflash-ai/…
KRRT7 Oct 9, 2025
119e8ec
Merge branch 'main' into test_cache_revival
KRRT7 Oct 9, 2025
49e44ee
formatting
KRRT7 Oct 9, 2025
ea9878d
loosen E2E init
KRRT7 Oct 9, 2025
a0c333c
Update discover_unit_tests.py
KRRT7 Oct 9, 2025
bdc062a
Merge branch 'main' of github.com:codeflash-ai/codeflash into test_ca…
Oct 10, 2025
b210ba4
Merge branch 'main' into test_cache_revival
KRRT7 Oct 10, 2025
1445c38
Optimize TestsCache.compute_file_hash
codeflash-ai[bot] Oct 10, 2025
c3e2ec2
it's a pathy objectey
KRRT7 Oct 10, 2025
fd64a22
Merge pull request #810 from codeflash-ai/codeflash/optimize-pr753-20…
KRRT7 Oct 10, 2025
bb982cb
use path objects consistently
KRRT7 Oct 10, 2025
d8dd14d
formatter
KRRT7 Oct 10, 2025
b243158
add schema to testcache
KRRT7 Oct 10, 2025
2bead9c
Merge branch 'main' into test_cache_revival
KRRT7 Oct 14, 2025
1b58dd1
Merge branch 'main' into test_cache_revival
KRRT7 Oct 14, 2025
d5cf24b
Merge branch 'test_cache_revival' of github.com:codeflash-ai/codeflas…
Oct 15, 2025
e187f98
Merge branch 'main' of github.com:codeflash-ai/codeflash into test_ca…
Oct 15, 2025
882a2e0
lsp log for discovering tests loading message
Oct 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion codeflash/code_utils/code_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,9 @@ def get_all_function_names(code: str) -> tuple[bool, list[str]]:
return True, function_names


def get_run_tmp_file(file_path: Path | str) -> Path:
    """Return ``file_path`` joined onto the shared codeflash temporary directory.

    The temporary directory is created lazily on the first call and stored as
    the ``tmpdir`` attribute of this function, so every later call resolves
    against the same directory.

    Args:
        file_path: Path, or string path, to place under the temporary directory.

    Returns:
        ``Path(<tmpdir>) / file_path``.
    """
    if not hasattr(get_run_tmp_file, "tmpdir"):
        # Stored on the function object so the same TemporaryDirectory is reused by all calls.
        get_run_tmp_file.tmpdir = TemporaryDirectory(prefix="codeflash_")
    # Path() accepts both str and Path, so no isinstance() branch is needed.
    return Path(get_run_tmp_file.tmpdir.name) / Path(file_path)
Expand Down
171 changes: 119 additions & 52 deletions codeflash/discovery/discover_unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,39 @@ class TestFunction:


class TestsCache:
def __init__(self) -> None:
SCHEMA_VERSION = 1 # Increment this when schema changes

def __init__(self, project_root_path: str | Path) -> None:
self.project_root_path = Path(project_root_path).resolve().as_posix()
self.connection = sqlite3.connect(codeflash_cache_db)
self.cur = self.connection.cursor()

self.cur.execute(
"""
CREATE TABLE IF NOT EXISTS schema_version(
version INTEGER PRIMARY KEY
)
"""
)

self.cur.execute("SELECT version FROM schema_version")
result = self.cur.fetchone()
current_version = result[0] if result else None

if current_version != self.SCHEMA_VERSION:
logger.debug(
f"Schema version mismatch (current: {current_version}, expected: {self.SCHEMA_VERSION}). Recreating tables."
)
self.cur.execute("DROP TABLE IF EXISTS discovered_tests")
self.cur.execute("DROP INDEX IF EXISTS idx_discovered_tests_project_file_path_hash")
self.cur.execute("DELETE FROM schema_version")
self.cur.execute("INSERT INTO schema_version (version) VALUES (?)", (self.SCHEMA_VERSION,))
self.connection.commit()

self.cur.execute(
"""
CREATE TABLE IF NOT EXISTS discovered_tests(
project_root_path TEXT,
Copy link
Contributor

@mohammedahmed18 mohammedahmed18 Oct 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This won't create the project_root_path column if the table was already created before; you should have a separate SQL statement for adding the new column.

2025-10-10_18-06

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We will actually need some kind of migration engine later for these types of changes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Huh, that database should be empty. If not, I think only our team members would have it, but yes, let me come up with a fix.

file_path TEXT,
file_hash TEXT,
qualified_name_with_modules_from_root TEXT,
Expand All @@ -88,11 +114,12 @@ def __init__(self) -> None:
)
self.cur.execute(
"""
CREATE INDEX IF NOT EXISTS idx_discovered_tests_file_path_hash
ON discovered_tests (file_path, file_hash)
CREATE INDEX IF NOT EXISTS idx_discovered_tests_project_file_path_hash
ON discovered_tests (project_root_path, file_path, file_hash)
"""
)
self._memory_cache = {}

self.memory_cache = {}

def insert_test(
self,
Expand All @@ -108,8 +135,9 @@ def insert_test(
) -> None:
test_type_value = test_type.value if hasattr(test_type, "value") else test_type
self.cur.execute(
"INSERT INTO discovered_tests VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
"INSERT INTO discovered_tests VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
(
self.project_root_path,
file_path,
file_hash,
qualified_name_with_modules_from_root,
Expand All @@ -123,32 +151,48 @@ def insert_test(
)
self.connection.commit()

def get_function_to_test_map_for_file(
    self, file_path: str, file_hash: str
) -> dict[str, set[FunctionCalledInTest]] | None:
    """Look up the cached function-to-tests mapping for one test file.

    Results are memoised in ``self.memory_cache`` under the key
    ``(project_root_path, file_path, file_hash)``; only non-empty results are
    memoised.

    Args:
        file_path: Test file path as stored in the ``discovered_tests`` table.
        file_hash: Content hash of the test file; rows with a different hash do not match.

    Returns:
        A dict mapping each qualified function name to the set of
        ``FunctionCalledInTest`` entries that reference it, or ``None`` when no
        row matches this project root, file path and file hash.
    """
    cache_key = (self.project_root_path, file_path, file_hash)
    if cache_key in self.memory_cache:
        return self.memory_cache[cache_key]

    self.cur.execute(
        "SELECT * FROM discovered_tests WHERE project_root_path = ? AND file_path = ? AND file_hash = ?",
        (self.project_root_path, file_path, file_hash),
    )
    rows = self.cur.fetchall()
    if not rows:
        return None

    function_to_test_map = defaultdict(set)
    for row in rows:
        # Positional columns read here: 1=file path, 3=qualified function name,
        # 5=test class, 6=test function, 7=test type, 8=line, 9=column.
        # Column 0 is project_root_path, so every index is one higher than before
        # that column was added.
        function_called_in_test = FunctionCalledInTest(
            tests_in_file=TestsInFile(
                test_file=Path(row[1]), test_class=row[5], test_function=row[6], test_type=TestType(int(row[7]))
            ),
            position=CodePosition(line_no=row[8], col_no=row[9]),
        )
        function_to_test_map[row[3]].add(function_called_in_test)

    # Store a plain dict so later lookups of missing keys do not create entries.
    result = dict(function_to_test_map)
    self.memory_cache[cache_key] = result
    return result

@staticmethod
def compute_file_hash(path: str | Path) -> str:
    """Return the SHA-256 hex digest of the file at ``path``.

    The file is read unbuffered in 8 KiB chunks into a single reusable buffer,
    so no per-chunk ``bytes`` object is allocated.

    Args:
        path: Location of the file. ``str`` is still accepted, as it was by the
            earlier ``path: str`` signature.

    Returns:
        The hexadecimal SHA-256 digest of the file contents.
    """
    h = hashlib.sha256(usedforsecurity=False)
    # Path(path) keeps string callers working; it is a no-op cost for Path inputs.
    with Path(path).open("rb", buffering=0) as f:
        buf = bytearray(8192)
        mv = memoryview(buf)
        while True:
            n = f.readinto(mv)
            # 0 means end of file; a raw stream may also return None, which ends the loop too.
            if not n:
                break
            # Only the first n bytes of the buffer were filled by this read.
            h.update(mv[:n])
    return h.hexdigest()

def close(self) -> None:
Expand Down Expand Up @@ -394,7 +438,7 @@ def discover_tests_pytest(
cfg: TestConfig,
discover_only_these_tests: list[Path] | None = None,
functions_to_optimize: list[FunctionToOptimize] | None = None,
) -> tuple[dict[str, set[FunctionCalledInTest]], int]:
) -> tuple[dict[str, set[FunctionCalledInTest]], int, int]:
tests_root = cfg.tests_root
project_root = cfg.project_root_path

Expand Down Expand Up @@ -432,9 +476,11 @@ def discover_tests_pytest(
f"Failed to collect tests. Pytest Exit code: {exitcode}={PytestExitCode(exitcode).name}\n {error_section}"
)
if "ModuleNotFoundError" in result.stdout:
match = ImportErrorPattern.search(result.stdout).group()
panel = Panel(Text.from_markup(f"⚠️ {match} ", style="bold red"), expand=False)
console.print(panel)
match = ImportErrorPattern.search(result.stdout)
if match:
error_message = match.group()
panel = Panel(Text.from_markup(f"⚠️ {error_message} ", style="bold red"), expand=False)
console.print(panel)

elif 0 <= exitcode <= 5:
logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}={PytestExitCode(exitcode).name}")
Expand Down Expand Up @@ -469,13 +515,13 @@ def discover_tests_pytest(

def discover_tests_unittest(
cfg: TestConfig,
discover_only_these_tests: list[str] | None = None,
discover_only_these_tests: list[Path] | None = None,
functions_to_optimize: list[FunctionToOptimize] | None = None,
) -> tuple[dict[str, set[FunctionCalledInTest]], int]:
) -> tuple[dict[str, set[FunctionCalledInTest]], int, int]:
tests_root: Path = cfg.tests_root
loader: unittest.TestLoader = unittest.TestLoader()
tests: unittest.TestSuite = loader.discover(str(tests_root))
file_to_test_map: defaultdict[str, list[TestsInFile]] = defaultdict(list)
file_to_test_map: defaultdict[Path, list[TestsInFile]] = defaultdict(list)

def get_test_details(_test: unittest.TestCase) -> TestsInFile | None:
_test_function, _test_module, _test_suite_name = (
Expand All @@ -487,7 +533,7 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None:
_test_module_path = Path(_test_module.replace(".", os.sep)).with_suffix(".py")
_test_module_path = tests_root / _test_module_path
if not _test_module_path.exists() or (
discover_only_these_tests and str(_test_module_path) not in discover_only_these_tests
discover_only_these_tests and _test_module_path not in discover_only_these_tests
):
return None
if "__replay_test" in str(_test_module_path):
Expand All @@ -497,10 +543,7 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None:
else:
test_type = TestType.EXISTING_UNIT_TEST
return TestsInFile(
test_file=str(_test_module_path),
test_function=_test_function,
test_type=test_type,
test_class=_test_suite_name,
test_file=_test_module_path, test_function=_test_function, test_type=test_type, test_class=_test_suite_name
)

for _test_suite in tests._tests:
Expand All @@ -518,18 +561,18 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None:
continue
details = get_test_details(test_2)
if details is not None:
file_to_test_map[str(details.test_file)].append(details)
file_to_test_map[details.test_file].append(details)
else:
details = get_test_details(test)
if details is not None:
file_to_test_map[str(details.test_file)].append(details)
file_to_test_map[details.test_file].append(details)
return process_test_files(file_to_test_map, cfg, functions_to_optimize)


def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | None]:
    """Split a trailing ``_<digits>`` suffix off a unittest function name.

    Args:
        function_name: Name of the test function to inspect.

    Returns:
        ``(True, base_name, suffix)`` when the name contains an underscore and
        the text after the last underscore is all digits; otherwise
        ``(False, function_name, None)`` with the name unchanged.
    """
    # NOTE(review): the numeric suffix is presumably the index added to parameterized tests; confirm.
    base_name, separator, suffix = function_name.rpartition("_")
    if separator and suffix.isdigit():
        return True, base_name, suffix

    # Return the original string; the input is never rebound to a list of parts.
    return False, function_name, None

Expand All @@ -538,7 +581,7 @@ def process_test_files(
file_to_test_map: dict[Path, list[TestsInFile]],
cfg: TestConfig,
functions_to_optimize: list[FunctionToOptimize] | None = None,
) -> tuple[dict[str, set[FunctionCalledInTest]], int]:
) -> tuple[dict[str, set[FunctionCalledInTest]], int, int]:
import jedi

project_root_path = cfg.project_root_path
Expand All @@ -553,29 +596,39 @@ def process_test_files(
num_discovered_replay_tests = 0
jedi_project = jedi.Project(path=project_root_path)

tests_cache = TestsCache(project_root_path)
logger.info("!lsp|Discovering tests and processing unit tests")
with test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as (
progress,
task_id,
):
for test_file, functions in file_to_test_map.items():
file_hash = TestsCache.compute_file_hash(test_file)

cached_function_to_test_map = tests_cache.get_function_to_test_map_for_file(str(test_file), file_hash)

if cfg.use_cache and cached_function_to_test_map:
for qualified_name, test_set in cached_function_to_test_map.items():
function_to_test_map[qualified_name].update(test_set)

for function_called_in_test in test_set:
if function_called_in_test.tests_in_file.test_type == TestType.REPLAY_TEST:
num_discovered_replay_tests += 1
num_discovered_tests += 1

progress.advance(task_id)
continue
try:
script = jedi.Script(path=test_file, project=jedi_project)
test_functions = set()

# Single call to get all names with references and definitions
all_names = script.get_names(all_scopes=True, references=True, definitions=True)
all_names = script.get_names(all_scopes=True, references=True)
all_defs = script.get_names(all_scopes=True, definitions=True)
all_names_top = script.get_names(all_scopes=True)

# Filter once and create lookup dictionaries
top_level_functions = {}
top_level_classes = {}
all_defs = []
top_level_functions = {name.name: name for name in all_names_top if name.type == "function"}
top_level_classes = {name.name: name for name in all_names_top if name.type == "class"}

for name in all_names:
if name.type == "function":
top_level_functions[name.name] = name
all_defs.append(name)
elif name.type == "class":
top_level_classes[name.name] = name
except Exception as e:
logger.debug(f"Failed to get jedi script for {test_file}: {e}")
progress.advance(task_id)
Expand Down Expand Up @@ -697,6 +750,18 @@ def process_test_files(
position=CodePosition(line_no=name.line, col_no=name.column),
)
)
tests_cache.insert_test(
file_path=str(test_file),
file_hash=file_hash,
qualified_name_with_modules_from_root=qualified_name_with_modules_from_root,
function_name=scope,
test_class=test_func.test_class or "",
test_function=scope_test_function,
test_type=test_func.test_type,
line_number=name.line,
col_number=name.column,
)

if test_func.test_type == TestType.REPLAY_TEST:
num_discovered_replay_tests += 1

Expand All @@ -707,4 +772,6 @@ def process_test_files(

progress.advance(task_id)

tests_cache.close()

return dict(function_to_test_map), num_discovered_tests, num_discovered_replay_tests
18 changes: 9 additions & 9 deletions codeflash/optimization/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,15 +239,15 @@ def discover_tests(
from codeflash.discovery.discover_unit_tests import discover_unit_tests

console.rule()
with progress_bar("Discovering existing function tests..."):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The progress_bar will also show the loading text animation in the extension, so I think we should keep it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Screenshot 2025-10-09 at 3 41 12 PM] There is still a progress bar; I just moved it elsewhere, since I was actually seeing 2 progress bar artifacts.

the new progress bar is implemented in codeflash/discovery/discover_unit_tests.py::process_test_files:577

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Screen.Recording.2025-10-09.at.3.43.36.PM.mov

this is the artifact

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Screen.Recording.2025-10-09.at.3.45.08.PM.mov

no artifact

start_time = time.time()
function_to_tests, num_discovered_tests, num_discovered_replay_tests = discover_unit_tests(
self.test_cfg, file_to_funcs_to_optimize=file_to_funcs_to_optimize
)
console.rule()
logger.info(
f"Discovered {num_discovered_tests} existing unit tests and {num_discovered_replay_tests} replay tests in {(time.time() - start_time):.1f}s at {self.test_cfg.tests_root}"
)
start_time = time.time()
logger.info("lsp,loading|Discovering existing function tests...")
function_to_tests, num_discovered_tests, num_discovered_replay_tests = discover_unit_tests(
self.test_cfg, file_to_funcs_to_optimize=file_to_funcs_to_optimize
)
console.rule()
logger.info(
f"Discovered {num_discovered_tests} existing unit tests and {num_discovered_replay_tests} replay tests in {(time.time() - start_time):.1f}s at {self.test_cfg.tests_root}"
)
console.rule()
ph("cli-optimize-discovered-tests", {"num_tests": num_discovered_tests})
return function_to_tests, num_discovered_tests
Expand Down
1 change: 1 addition & 0 deletions codeflash/verification/verification_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,4 @@ class TestConfig:
concolic_test_root_dir: Optional[Path] = None
pytest_cmd: str = "pytest"
benchmark_tests_root: Optional[Path] = None
use_cache: bool = True
2 changes: 1 addition & 1 deletion tests/scripts/end_to_end_test_init_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@ def run_test(expected_improvement_pct: int) -> bool:


if __name__ == "__main__":
exit(run_with_retries(run_test, int(os.getenv("EXPECTED_IMPROVEMENT_PCT", 5))))
exit(run_with_retries(run_test, int(os.getenv("EXPECTED_IMPROVEMENT_PCT", 10))))
Loading
Loading