Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
d5fa1ef
tests cache
KRRT7 Sep 23, 2025
28eee4a
make ty happy
KRRT7 Sep 23, 2025
bbc630f
prevent progress bar artifact & respond to code review
KRRT7 Sep 23, 2025
c4b8a4b
formatting
KRRT7 Sep 23, 2025
caac361
Merge branch 'main' into test_cache_revival
KRRT7 Oct 9, 2025
fd34f8d
Merge branch 'main' into test_cache_revival
KRRT7 Oct 9, 2025
180c479
add project_root as a index key
KRRT7 Oct 9, 2025
a891be4
add unit tests for caching
KRRT7 Oct 9, 2025
924de86
Merge branch 'test_cache_revival' of https://github.com/codeflash-ai/…
KRRT7 Oct 9, 2025
119e8ec
Merge branch 'main' into test_cache_revival
KRRT7 Oct 9, 2025
49e44ee
formatting
KRRT7 Oct 9, 2025
ea9878d
loosen E2E init
KRRT7 Oct 9, 2025
a0c333c
Update discover_unit_tests.py
KRRT7 Oct 9, 2025
bdc062a
Merge branch 'main' of github.com:codeflash-ai/codeflash into test_ca…
Oct 10, 2025
b210ba4
Merge branch 'main' into test_cache_revival
KRRT7 Oct 10, 2025
1445c38
Optimize TestsCache.compute_file_hash
codeflash-ai[bot] Oct 10, 2025
c3e2ec2
it's a pathy objectey
KRRT7 Oct 10, 2025
fd64a22
Merge pull request #810 from codeflash-ai/codeflash/optimize-pr753-20…
KRRT7 Oct 10, 2025
bb982cb
use path objects consistently
KRRT7 Oct 10, 2025
d8dd14d
formatter
KRRT7 Oct 10, 2025
b243158
add schema to testcache
KRRT7 Oct 10, 2025
2bead9c
Merge branch 'main' into test_cache_revival
KRRT7 Oct 14, 2025
1b58dd1
Merge branch 'main' into test_cache_revival
KRRT7 Oct 14, 2025
d5cf24b
Merge branch 'test_cache_revival' of github.com:codeflash-ai/codeflas…
Oct 15, 2025
e187f98
Merge branch 'main' of github.com:codeflash-ai/codeflash into test_ca…
Oct 15, 2025
882a2e0
lsp log for discovering tests loading message
Oct 15, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/e2e-init-optimization.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
COLUMNS: 110
MAX_RETRIES: 3
RETRY_DELAY: 5
EXPECTED_IMPROVEMENT_PCT: 30
EXPECTED_IMPROVEMENT_PCT: 10
CODEFLASH_END_TO_END: 1
steps:
- name: 🛎️ Checkout
Expand Down
4 changes: 3 additions & 1 deletion codeflash/code_utils/code_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,9 @@ def get_all_function_names(code: str) -> tuple[bool, list[str]]:
return True, function_names


def get_run_tmp_file(file_path: Path) -> Path:
def get_run_tmp_file(file_path: Path | str) -> Path:
if isinstance(file_path, str):
file_path = Path(file_path)
if not hasattr(get_run_tmp_file, "tmpdir"):
get_run_tmp_file.tmpdir = TemporaryDirectory(prefix="codeflash_")
return Path(get_run_tmp_file.tmpdir.name) / file_path
Expand Down
123 changes: 83 additions & 40 deletions codeflash/discovery/discover_unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,14 @@ class TestFunction:


class TestsCache:
def __init__(self) -> None:
def __init__(self, project_root_path: str | Path) -> None:
self.project_root_path = Path(project_root_path).resolve().as_posix()
self.connection = sqlite3.connect(codeflash_cache_db)
self.cur = self.connection.cursor()

self.cur.execute(
"""
CREATE TABLE IF NOT EXISTS discovered_tests(
project_root_path TEXT,
Copy link
Contributor

@mohammedahmed18 mohammedahmed18 Oct 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This won't create the project_root_path column if the table was already created earlier; you should have a separate SQL statement for adding the new column.

2025-10-10_18-06

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We will actually need some kind of migration engine later for these types of changes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Huh, that database should be empty; if it isn't, I think only our team members would have it. But yes, let me come up with a fix.

file_path TEXT,
file_hash TEXT,
qualified_name_with_modules_from_root TEXT,
Expand All @@ -88,11 +89,12 @@ def __init__(self) -> None:
)
self.cur.execute(
"""
CREATE INDEX IF NOT EXISTS idx_discovered_tests_file_path_hash
ON discovered_tests (file_path, file_hash)
CREATE INDEX IF NOT EXISTS idx_discovered_tests_project_file_path_hash
ON discovered_tests (project_root_path, file_path, file_hash)
"""
)
self._memory_cache = {}

self.memory_cache = {}

def insert_test(
self,
Expand All @@ -108,8 +110,9 @@ def insert_test(
) -> None:
test_type_value = test_type.value if hasattr(test_type, "value") else test_type
self.cur.execute(
"INSERT INTO discovered_tests VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
"INSERT INTO discovered_tests VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
(
self.project_root_path,
file_path,
file_hash,
qualified_name_with_modules_from_root,
Expand All @@ -123,25 +126,39 @@ def insert_test(
)
self.connection.commit()

def get_tests_for_file(self, file_path: str, file_hash: str) -> list[FunctionCalledInTest]:
cache_key = (file_path, file_hash)
if cache_key in self._memory_cache:
return self._memory_cache[cache_key]
self.cur.execute("SELECT * FROM discovered_tests WHERE file_path = ? AND file_hash = ?", (file_path, file_hash))
result = [
FunctionCalledInTest(
def get_function_to_test_map_for_file(
self, file_path: str, file_hash: str
) -> dict[str, set[FunctionCalledInTest]] | None:
cache_key = (self.project_root_path, file_path, file_hash)
if cache_key in self.memory_cache:
return self.memory_cache[cache_key]

self.cur.execute(
"SELECT * FROM discovered_tests WHERE project_root_path = ? AND file_path = ? AND file_hash = ?",
(self.project_root_path, file_path, file_hash),
)
rows = self.cur.fetchall()
if not rows:
return None

function_to_test_map = defaultdict(set)

for row in rows:
qualified_name_with_modules_from_root = row[3]
function_called_in_test = FunctionCalledInTest(
tests_in_file=TestsInFile(
test_file=Path(row[0]), test_class=row[4], test_function=row[5], test_type=TestType(int(row[6]))
test_file=Path(row[1]), test_class=row[5], test_function=row[6], test_type=TestType(int(row[7]))
),
position=CodePosition(line_no=row[7], col_no=row[8]),
position=CodePosition(line_no=row[8], col_no=row[9]),
)
for row in self.cur.fetchall()
]
self._memory_cache[cache_key] = result
function_to_test_map[qualified_name_with_modules_from_root].add(function_called_in_test)

result = dict(function_to_test_map)
self.memory_cache[cache_key] = result
return result

@staticmethod
def compute_file_hash(path: str) -> str:
def compute_file_hash(path: str | Path) -> str:
h = hashlib.sha256(usedforsecurity=False)
with Path(path).open("rb") as f:
while True:
Expand Down Expand Up @@ -394,7 +411,7 @@ def discover_tests_pytest(
cfg: TestConfig,
discover_only_these_tests: list[Path] | None = None,
functions_to_optimize: list[FunctionToOptimize] | None = None,
) -> tuple[dict[str, set[FunctionCalledInTest]], int]:
) -> tuple[dict[str, set[FunctionCalledInTest]], int, int]:
tests_root = cfg.tests_root
project_root = cfg.project_root_path

Expand Down Expand Up @@ -432,9 +449,11 @@ def discover_tests_pytest(
f"Failed to collect tests. Pytest Exit code: {exitcode}={PytestExitCode(exitcode).name}\n {error_section}"
)
if "ModuleNotFoundError" in result.stdout:
match = ImportErrorPattern.search(result.stdout).group()
panel = Panel(Text.from_markup(f"⚠️ {match} ", style="bold red"), expand=False)
console.print(panel)
match = ImportErrorPattern.search(result.stdout)
if match:
error_message = match.group()
panel = Panel(Text.from_markup(f"⚠️ {error_message} ", style="bold red"), expand=False)
console.print(panel)

elif 0 <= exitcode <= 5:
logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}={PytestExitCode(exitcode).name}")
Expand Down Expand Up @@ -471,7 +490,7 @@ def discover_tests_unittest(
cfg: TestConfig,
discover_only_these_tests: list[str] | None = None,
functions_to_optimize: list[FunctionToOptimize] | None = None,
) -> tuple[dict[str, set[FunctionCalledInTest]], int]:
) -> tuple[dict[str, set[FunctionCalledInTest]], int, int]:
tests_root: Path = cfg.tests_root
loader: unittest.TestLoader = unittest.TestLoader()
tests: unittest.TestSuite = loader.discover(str(tests_root))
Expand Down Expand Up @@ -527,9 +546,9 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None:


def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | None]:
function_name = function_name.split("_")
if len(function_name) > 1 and function_name[-1].isdigit():
return True, "_".join(function_name[:-1]), function_name[-1]
function_parts = function_name.split("_")
if len(function_parts) > 1 and function_parts[-1].isdigit():
return True, "_".join(function_parts[:-1]), function_parts[-1]

return False, function_name, None

Expand All @@ -538,7 +557,7 @@ def process_test_files(
file_to_test_map: dict[Path, list[TestsInFile]],
cfg: TestConfig,
functions_to_optimize: list[FunctionToOptimize] | None = None,
) -> tuple[dict[str, set[FunctionCalledInTest]], int]:
) -> tuple[dict[str, set[FunctionCalledInTest]], int, int]:
import jedi

project_root_path = cfg.project_root_path
Expand All @@ -553,29 +572,39 @@ def process_test_files(
num_discovered_replay_tests = 0
jedi_project = jedi.Project(path=project_root_path)

tests_cache = TestsCache(project_root_path)
logger.info("!lsp|Discovering tests and processing unit tests")
with test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as (
progress,
task_id,
):
for test_file, functions in file_to_test_map.items():
file_hash = TestsCache.compute_file_hash(test_file)

cached_function_to_test_map = tests_cache.get_function_to_test_map_for_file(str(test_file), file_hash)

if cfg.use_cache and cached_function_to_test_map:
for qualified_name, test_set in cached_function_to_test_map.items():
function_to_test_map[qualified_name].update(test_set)

for function_called_in_test in test_set:
if function_called_in_test.tests_in_file.test_type == TestType.REPLAY_TEST:
num_discovered_replay_tests += 1
num_discovered_tests += 1

progress.advance(task_id)
continue
try:
script = jedi.Script(path=test_file, project=jedi_project)
test_functions = set()

# Single call to get all names with references and definitions
all_names = script.get_names(all_scopes=True, references=True, definitions=True)
all_names = script.get_names(all_scopes=True, references=True)
all_defs = script.get_names(all_scopes=True, definitions=True)
all_names_top = script.get_names(all_scopes=True)

# Filter once and create lookup dictionaries
top_level_functions = {}
top_level_classes = {}
all_defs = []
top_level_functions = {name.name: name for name in all_names_top if name.type == "function"}
top_level_classes = {name.name: name for name in all_names_top if name.type == "class"}

for name in all_names:
if name.type == "function":
top_level_functions[name.name] = name
all_defs.append(name)
elif name.type == "class":
top_level_classes[name.name] = name
except Exception as e:
logger.debug(f"Failed to get jedi script for {test_file}: {e}")
progress.advance(task_id)
Expand Down Expand Up @@ -697,6 +726,18 @@ def process_test_files(
position=CodePosition(line_no=name.line, col_no=name.column),
)
)
tests_cache.insert_test(
file_path=str(test_file),
file_hash=file_hash,
qualified_name_with_modules_from_root=qualified_name_with_modules_from_root,
function_name=scope,
test_class=test_func.test_class or "",
test_function=scope_test_function,
test_type=test_func.test_type,
line_number=name.line,
col_number=name.column,
)

if test_func.test_type == TestType.REPLAY_TEST:
num_discovered_replay_tests += 1

Expand All @@ -707,4 +748,6 @@ def process_test_files(

progress.advance(task_id)

tests_cache.close()

return dict(function_to_test_map), num_discovered_tests, num_discovered_replay_tests
17 changes: 8 additions & 9 deletions codeflash/optimization/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,15 +239,14 @@ def discover_tests(
from codeflash.discovery.discover_unit_tests import discover_unit_tests

console.rule()
with progress_bar("Discovering existing function tests..."):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The progress_bar will also show the loading-text animation in the extension, so I think we should keep it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Screenshot 2025-10-09 at 3 41 12 PM — there is still a progress bar; I just moved it elsewhere, since I was actually seeing two progress bar artifacts.

the new progress bar is implemented in codeflash/discovery/discover_unit_tests.py::process_test_files:577

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Screen.Recording.2025-10-09.at.3.43.36.PM.mov

this is the artifact

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Screen.Recording.2025-10-09.at.3.45.08.PM.mov

no artifact

start_time = time.time()
function_to_tests, num_discovered_tests, num_discovered_replay_tests = discover_unit_tests(
self.test_cfg, file_to_funcs_to_optimize=file_to_funcs_to_optimize
)
console.rule()
logger.info(
f"Discovered {num_discovered_tests} existing unit tests and {num_discovered_replay_tests} replay tests in {(time.time() - start_time):.1f}s at {self.test_cfg.tests_root}"
)
start_time = time.time()
function_to_tests, num_discovered_tests, num_discovered_replay_tests = discover_unit_tests(
self.test_cfg, file_to_funcs_to_optimize=file_to_funcs_to_optimize
)
console.rule()
logger.info(
f"Discovered {num_discovered_tests} existing unit tests and {num_discovered_replay_tests} replay tests in {(time.time() - start_time):.1f}s at {self.test_cfg.tests_root}"
)
console.rule()
ph("cli-optimize-discovered-tests", {"num_tests": num_discovered_tests})
return function_to_tests, num_discovered_tests
Expand Down
1 change: 1 addition & 0 deletions codeflash/verification/verification_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,4 @@ class TestConfig:
concolic_test_root_dir: Optional[Path] = None
pytest_cmd: str = "pytest"
benchmark_tests_root: Optional[Path] = None
use_cache: bool = True
4 changes: 2 additions & 2 deletions tests/scripts/end_to_end_test_init_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def run_test(expected_improvement_pct: int) -> bool:
file_path="remove_control_chars.py",
function_name="CharacterRemover.remove_control_characters",
test_framework="pytest",
min_improvement_x=0.3,
min_improvement_x=0.1,
coverage_expectations=[
CoverageExpectation(
function_name="CharacterRemover.remove_control_characters", expected_coverage=100.0, expected_lines=[14]
Expand All @@ -21,4 +21,4 @@ def run_test(expected_improvement_pct: int) -> bool:


if __name__ == "__main__":
exit(run_with_retries(run_test, int(os.getenv("EXPECTED_IMPROVEMENT_PCT", 5))))
exit(run_with_retries(run_test, int(os.getenv("EXPECTED_IMPROVEMENT_PCT", 10))))
35 changes: 35 additions & 0 deletions tests/test_unit_test_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
from codeflash.verification.verification_utils import TestConfig


from pathlib import Path
from codeflash.discovery.discover_unit_tests import discover_unit_tests

def test_unit_test_discovery_pytest():
project_path = Path(__file__).parent.parent.resolve() / "code_to_optimize"
tests_path = project_path / "tests" / "pytest"
Expand Down Expand Up @@ -1327,3 +1330,35 @@ def test_target():
should_process = analyze_imports_in_test_file(test_file, target_functions)

assert should_process is False



def test_discover_unit_tests_caching():
    """Discovery results must be identical with ``use_cache`` off and on."""
    tests_root = Path(__file__).parent.resolve() / "tests"
    project_root_path = tests_root.parent.resolve()

    # Everything except ``use_cache`` is shared between the two runs.
    shared_settings = {
        "tests_root": tests_root,
        "project_root_path": project_root_path,
        "test_framework": "pytest",
        "tests_project_rootdir": project_root_path,
    }

    # First pass: run discovery with use_cache=False.
    uncached_config = TestConfig(use_cache=False, **shared_settings)
    uncached_map, uncached_total, uncached_replay = discover_unit_tests(uncached_config)

    # Second pass: same project, this time with use_cache=True.
    cached_config = TestConfig(use_cache=True, **shared_settings)
    cached_map, cached_total, cached_replay = discover_unit_tests(cached_config)

    assert uncached_total == cached_total
    assert uncached_map == cached_map
    assert uncached_replay == cached_replay
Loading