diff --git a/codeflash/code_utils/code_utils.py b/codeflash/code_utils/code_utils.py index 1bc4d66bb..82e122d0e 100644 --- a/codeflash/code_utils/code_utils.py +++ b/codeflash/code_utils/code_utils.py @@ -304,7 +304,9 @@ def get_all_function_names(code: str) -> tuple[bool, list[str]]: return True, function_names -def get_run_tmp_file(file_path: Path) -> Path: +def get_run_tmp_file(file_path: Path | str) -> Path: + if isinstance(file_path, str): + file_path = Path(file_path) if not hasattr(get_run_tmp_file, "tmpdir"): get_run_tmp_file.tmpdir = TemporaryDirectory(prefix="codeflash_") return Path(get_run_tmp_file.tmpdir.name) / file_path diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index 81ffdfa44..398efe461 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -67,13 +67,39 @@ class TestFunction: class TestsCache: - def __init__(self) -> None: + SCHEMA_VERSION = 1 # Increment this when schema changes + + def __init__(self, project_root_path: str | Path) -> None: + self.project_root_path = Path(project_root_path).resolve().as_posix() self.connection = sqlite3.connect(codeflash_cache_db) self.cur = self.connection.cursor() + self.cur.execute( + """ + CREATE TABLE IF NOT EXISTS schema_version( + version INTEGER PRIMARY KEY + ) + """ + ) + + self.cur.execute("SELECT version FROM schema_version") + result = self.cur.fetchone() + current_version = result[0] if result else None + + if current_version != self.SCHEMA_VERSION: + logger.debug( + f"Schema version mismatch (current: {current_version}, expected: {self.SCHEMA_VERSION}). Recreating tables." + ) + self.cur.execute("DROP TABLE IF EXISTS discovered_tests") + self.cur.execute("DROP INDEX IF EXISTS idx_discovered_tests_project_file_path_hash") + self.cur.execute("DELETE FROM schema_version") + self.cur.execute("INSERT INTO schema_version (version) VALUES (?)", (self.SCHEMA_VERSION,)) + self.connection.commit() + self.cur.execute( """ CREATE TABLE IF NOT EXISTS discovered_tests( + project_root_path TEXT, file_path TEXT, file_hash TEXT, qualified_name_with_modules_from_root TEXT, @@ -88,11 +114,12 @@ def __init__(self) -> None: ) self.cur.execute( """ - CREATE INDEX IF NOT EXISTS idx_discovered_tests_file_path_hash - ON discovered_tests (file_path, file_hash) + CREATE INDEX IF NOT EXISTS idx_discovered_tests_project_file_path_hash + ON discovered_tests (project_root_path, file_path, file_hash) """ ) - self._memory_cache = {} + + self.memory_cache = {} def insert_test( self, @@ -108,8 +135,9 @@ def insert_test( ) -> None: test_type_value = test_type.value if hasattr(test_type, "value") else test_type self.cur.execute( - "INSERT INTO discovered_tests VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", + "INSERT INTO discovered_tests VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", ( + self.project_root_path, file_path, file_hash, qualified_name_with_modules_from_root, @@ -123,32 +151,48 @@ def insert_test( ) self.connection.commit() - def get_tests_for_file(self, file_path: str, file_hash: str) -> list[FunctionCalledInTest]: - cache_key = (file_path, file_hash) - if cache_key in self._memory_cache: - return self._memory_cache[cache_key] - self.cur.execute("SELECT * FROM discovered_tests WHERE file_path = ? AND file_hash = ?", (file_path, file_hash)) - result = [ - FunctionCalledInTest( + def get_function_to_test_map_for_file( + self, file_path: str, file_hash: str + ) -> dict[str, set[FunctionCalledInTest]] | None: + cache_key = (self.project_root_path, file_path, file_hash) + if cache_key in self.memory_cache: + return self.memory_cache[cache_key] + + self.cur.execute( + "SELECT * FROM discovered_tests WHERE project_root_path = ? AND file_path = ? AND file_hash = ?", + (self.project_root_path, file_path, file_hash), + ) + rows = self.cur.fetchall() + if not rows: + return None + + function_to_test_map = defaultdict(set) + + for row in rows: + qualified_name_with_modules_from_root = row[3] + function_called_in_test = FunctionCalledInTest( tests_in_file=TestsInFile( - test_file=Path(row[0]), test_class=row[4], test_function=row[5], test_type=TestType(int(row[6])) + test_file=Path(row[1]), test_class=row[5], test_function=row[6], test_type=TestType(int(row[7])) ), - position=CodePosition(line_no=row[7], col_no=row[8]), + position=CodePosition(line_no=row[8], col_no=row[9]), ) - for row in self.cur.fetchall() - ] - self._memory_cache[cache_key] = result + function_to_test_map[qualified_name_with_modules_from_root].add(function_called_in_test) + + result = dict(function_to_test_map) + self.memory_cache[cache_key] = result return result @staticmethod - def compute_file_hash(path: str) -> str: + def compute_file_hash(path: Path) -> str: h = hashlib.sha256(usedforsecurity=False) - with Path(path).open("rb") as f: + with path.open("rb", buffering=0) as f: + buf = bytearray(8192) + mv = memoryview(buf) while True: - chunk = f.read(8192) - if not chunk: + n = f.readinto(mv) + if n == 0: break - h.update(chunk) + h.update(mv[:n]) return h.hexdigest() def close(self) -> None: @@ -394,7 +438,7 @@ def discover_tests_pytest( cfg: TestConfig, discover_only_these_tests: list[Path] | None = None, functions_to_optimize: list[FunctionToOptimize] | None = None, -) -> tuple[dict[str, set[FunctionCalledInTest]], int]: +) -> tuple[dict[str, set[FunctionCalledInTest]], int, int]: tests_root = cfg.tests_root project_root = cfg.project_root_path @@ -432,9 +476,11 @@ def discover_tests_pytest( f"Failed to collect tests. Pytest Exit code: {exitcode}={PytestExitCode(exitcode).name}\n {error_section}" ) if "ModuleNotFoundError" in result.stdout: - match = ImportErrorPattern.search(result.stdout).group() - panel = Panel(Text.from_markup(f"⚠️ {match} ", style="bold red"), expand=False) - console.print(panel) + match = ImportErrorPattern.search(result.stdout) + if match: + error_message = match.group() + panel = Panel(Text.from_markup(f"⚠️ {error_message} ", style="bold red"), expand=False) + console.print(panel) elif 0 <= exitcode <= 5: logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}={PytestExitCode(exitcode).name}") @@ -469,13 +515,13 @@ def discover_tests_pytest( def discover_tests_unittest( cfg: TestConfig, - discover_only_these_tests: list[str] | None = None, + discover_only_these_tests: list[Path] | None = None, functions_to_optimize: list[FunctionToOptimize] | None = None, -) -> tuple[dict[str, set[FunctionCalledInTest]], int]: +) -> tuple[dict[str, set[FunctionCalledInTest]], int, int]: tests_root: Path = cfg.tests_root loader: unittest.TestLoader = unittest.TestLoader() tests: unittest.TestSuite = loader.discover(str(tests_root)) - file_to_test_map: defaultdict[str, list[TestsInFile]] = defaultdict(list) + file_to_test_map: defaultdict[Path, list[TestsInFile]] = defaultdict(list) def get_test_details(_test: unittest.TestCase) -> TestsInFile | None: _test_function, _test_module, _test_suite_name = ( @@ -487,7 +533,7 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None: _test_module_path = Path(_test_module.replace(".", os.sep)).with_suffix(".py") _test_module_path = tests_root / _test_module_path if not _test_module_path.exists() or ( - discover_only_these_tests and str(_test_module_path) not in discover_only_these_tests + discover_only_these_tests and _test_module_path not in discover_only_these_tests ): return None if "__replay_test" in str(_test_module_path): @@ -497,10 +543,7 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None: else: test_type = TestType.EXISTING_UNIT_TEST return TestsInFile( - test_file=str(_test_module_path), - test_function=_test_function, - test_type=test_type, - test_class=_test_suite_name, + test_file=_test_module_path, test_function=_test_function, test_type=test_type, test_class=_test_suite_name ) for _test_suite in tests._tests: @@ -518,18 +561,18 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None: continue details = get_test_details(test_2) if details is not None: - file_to_test_map[str(details.test_file)].append(details) + file_to_test_map[details.test_file].append(details) else: details = get_test_details(test) if details is not None: - file_to_test_map[str(details.test_file)].append(details) + file_to_test_map[details.test_file].append(details) return process_test_files(file_to_test_map, cfg, functions_to_optimize) def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | None]: - function_name = function_name.split("_") - if len(function_name) > 1 and function_name[-1].isdigit(): - return True, "_".join(function_name[:-1]), function_name[-1] + function_parts = function_name.split("_") + if len(function_parts) > 1 and function_parts[-1].isdigit(): + return True, "_".join(function_parts[:-1]), function_parts[-1] return False, function_name, None @@ -538,7 +581,7 @@ def process_test_files( file_to_test_map: dict[Path, list[TestsInFile]], cfg: TestConfig, functions_to_optimize: list[FunctionToOptimize] | None = None, -) -> tuple[dict[str, set[FunctionCalledInTest]], int]: +) -> tuple[dict[str, set[FunctionCalledInTest]], int, int]: import jedi project_root_path = cfg.project_root_path @@ -553,29 +596,39 @@ def process_test_files( num_discovered_replay_tests = 0 jedi_project = jedi.Project(path=project_root_path) + tests_cache = TestsCache(project_root_path) + logger.info("!lsp|Discovering tests and processing unit tests") with test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as ( progress, task_id, ): for test_file, functions in file_to_test_map.items(): + file_hash = TestsCache.compute_file_hash(test_file) + + cached_function_to_test_map = tests_cache.get_function_to_test_map_for_file(str(test_file), file_hash) + + if cfg.use_cache and cached_function_to_test_map: + for qualified_name, test_set in cached_function_to_test_map.items(): + function_to_test_map[qualified_name].update(test_set) + + for function_called_in_test in test_set: + if function_called_in_test.tests_in_file.test_type == TestType.REPLAY_TEST: + num_discovered_replay_tests += 1 + num_discovered_tests += 1 + + progress.advance(task_id) + continue try: script = jedi.Script(path=test_file, project=jedi_project) test_functions = set() - # Single call to get all names with references and definitions - all_names = script.get_names(all_scopes=True, references=True, definitions=True) + all_names = script.get_names(all_scopes=True, references=True) + all_defs = script.get_names(all_scopes=True, definitions=True) + all_names_top = script.get_names(all_scopes=True) - # Filter once and create lookup dictionaries - top_level_functions = {} - top_level_classes = {} - all_defs = [] + top_level_functions = {name.name: name for name in all_names_top if name.type == "function"} + top_level_classes = {name.name: name for name in all_names_top if name.type == "class"} - for name in all_names: - if name.type == "function": - top_level_functions[name.name] = name - all_defs.append(name) - elif name.type == "class": - top_level_classes[name.name] = name except Exception as e: logger.debug(f"Failed to get jedi script for {test_file}: {e}") progress.advance(task_id) @@ -697,6 +750,18 @@ def process_test_files( position=CodePosition(line_no=name.line, col_no=name.column), ) ) + tests_cache.insert_test( + file_path=str(test_file), + file_hash=file_hash, + qualified_name_with_modules_from_root=qualified_name_with_modules_from_root, + function_name=scope, + test_class=test_func.test_class or "", + test_function=scope_test_function, + test_type=test_func.test_type, + line_number=name.line, + col_number=name.column, + ) + if test_func.test_type == TestType.REPLAY_TEST: num_discovered_replay_tests += 1 @@ -707,4 +772,6 @@ def process_test_files( progress.advance(task_id) + tests_cache.close() + return dict(function_to_test_map), num_discovered_tests, num_discovered_replay_tests diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 9d23b9abb..c0e0b014b 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -239,15 +239,15 @@ def discover_tests( from codeflash.discovery.discover_unit_tests import discover_unit_tests console.rule() - with progress_bar("Discovering existing function tests..."): - start_time = time.time() - function_to_tests, num_discovered_tests, num_discovered_replay_tests = discover_unit_tests( - self.test_cfg, file_to_funcs_to_optimize=file_to_funcs_to_optimize - ) - console.rule() - logger.info( - f"Discovered {num_discovered_tests} existing unit tests and {num_discovered_replay_tests} replay tests in {(time.time() - start_time):.1f}s at {self.test_cfg.tests_root}" - ) + start_time = time.time() + logger.info("lsp,loading|Discovering existing function tests...") + function_to_tests, num_discovered_tests, num_discovered_replay_tests = discover_unit_tests( + self.test_cfg, file_to_funcs_to_optimize=file_to_funcs_to_optimize + ) + console.rule() + logger.info( + f"Discovered {num_discovered_tests} existing unit tests and {num_discovered_replay_tests} replay tests in {(time.time() - start_time):.1f}s at {self.test_cfg.tests_root}" + ) console.rule() ph("cli-optimize-discovered-tests", {"num_tests": num_discovered_tests}) return function_to_tests, num_discovered_tests diff --git a/codeflash/verification/verification_utils.py b/codeflash/verification/verification_utils.py index 43cb78770..3641de340 100644 --- a/codeflash/verification/verification_utils.py +++ b/codeflash/verification/verification_utils.py @@ -76,3 +76,4 @@ class TestConfig: concolic_test_root_dir: Optional[Path] = None pytest_cmd: str = "pytest" benchmark_tests_root: Optional[Path] = None + use_cache: bool = True diff --git a/tests/scripts/end_to_end_test_init_optimization.py b/tests/scripts/end_to_end_test_init_optimization.py index 30fc930c5..ef6afb659 100644 --- a/tests/scripts/end_to_end_test_init_optimization.py +++ b/tests/scripts/end_to_end_test_init_optimization.py @@ -21,4 +21,4 @@ def run_test(expected_improvement_pct: int) -> bool: if __name__ == "__main__": - exit(run_with_retries(run_test, int(os.getenv("EXPECTED_IMPROVEMENT_PCT", 5)))) + exit(run_with_retries(run_test, int(os.getenv("EXPECTED_IMPROVEMENT_PCT", 10)))) diff --git a/tests/test_unit_test_discovery.py b/tests/test_unit_test_discovery.py index 673b08ab4..5af66ebc4 100644 --- a/tests/test_unit_test_discovery.py +++ b/tests/test_unit_test_discovery.py @@ -12,6 +12,9 @@ from codeflash.verification.verification_utils import TestConfig +from pathlib import Path +from codeflash.discovery.discover_unit_tests import discover_unit_tests + def test_unit_test_discovery_pytest(): project_path = Path(__file__).parent.parent.resolve() / "code_to_optimize" tests_path = project_path / "tests" / "pytest" @@ -1327,3 +1330,35 @@ def test_target(): should_process = analyze_imports_in_test_file(test_file, target_functions) assert should_process is False + + + +def test_discover_unit_tests_caching(): + tests_root = Path(__file__).parent.resolve() / "tests" + project_root_path = tests_root.parent.resolve() + + test_config = TestConfig( + tests_root=tests_root, + project_root_path=project_root_path, + test_framework="pytest", + tests_project_rootdir=project_root_path, + use_cache=False, + ) + + + + non_cached_function_to_tests, non_cached_num_discovered_tests, non_cached_num_discovered_replay_tests = ( + discover_unit_tests(test_config) + ) + cache_config = TestConfig( + tests_root=tests_root, + project_root_path=project_root_path, + test_framework="pytest", + tests_project_rootdir=project_root_path, + use_cache=True, + ) + tests, num_discovered_tests, num_discovered_replay_tests = discover_unit_tests(cache_config) + + assert non_cached_num_discovered_tests == num_discovered_tests + assert non_cached_function_to_tests == tests + assert non_cached_num_discovered_replay_tests == num_discovered_replay_tests \ No newline at end of file