diff --git a/codeflash/code_utils/compat.py b/codeflash/code_utils/compat.py index b62f93d4c..eb4e5b561 100644 --- a/codeflash/code_utils/compat.py +++ b/codeflash/code_utils/compat.py @@ -1,21 +1,47 @@ import os import sys +import tempfile from pathlib import Path +from typing import TYPE_CHECKING from platformdirs import user_config_dir -# os-independent newline -# important for any user-facing output or files we write -# make sure to use this in f-strings e.g. f"some string{LF}" -# you can use "[^f]\".*\{LF\}\" to find any lines in your code that use this without the f-string -LF: str = os.linesep +if TYPE_CHECKING: + codeflash_temp_dir: Path + codeflash_cache_dir: Path + codeflash_cache_db: Path -SAFE_SYS_EXECUTABLE: str = Path(sys.executable).as_posix() +class Compat: + # os-independent newline + LF: str = os.linesep -IS_POSIX = os.name != "nt" + SAFE_SYS_EXECUTABLE: str = Path(sys.executable).as_posix() + IS_POSIX: bool = os.name != "nt" -codeflash_cache_dir = Path(user_config_dir(appname="codeflash", appauthor="codeflash-ai", ensure_exists=True)) + @property + def codeflash_cache_dir(self) -> Path: + return Path(user_config_dir(appname="codeflash", appauthor="codeflash-ai", ensure_exists=True)) -codeflash_cache_db = codeflash_cache_dir / "codeflash_cache.db" + @property + def codeflash_temp_dir(self) -> Path: + temp_dir = Path(tempfile.gettempdir()) / "codeflash" + if not temp_dir.exists(): + temp_dir.mkdir(parents=True, exist_ok=True) + return temp_dir + + @property + def codeflash_cache_db(self) -> Path: + return self.codeflash_cache_dir / "codeflash_cache.db" + + +_compat = Compat() + + +codeflash_temp_dir = _compat.codeflash_temp_dir +codeflash_cache_dir = _compat.codeflash_cache_dir +codeflash_cache_db = _compat.codeflash_cache_db +LF = _compat.LF +SAFE_SYS_EXECUTABLE = _compat.SAFE_SYS_EXECUTABLE +IS_POSIX = _compat.IS_POSIX diff --git a/tests/test_function_discovery.py b/tests/test_function_discovery.py index da40121bc..38a616cd4 100644 --- a/tests/test_function_discovery.py +++ b/tests/test_function_discovery.py @@ -1,13 +1,18 @@ import tempfile from pathlib import Path +import os +import unittest.mock from codeflash.discovery.functions_to_optimize import ( filter_files_optimized, find_all_functions_in_file, get_functions_to_optimize, inspect_top_level_functions_or_methods, + filter_functions, + get_all_files_and_functions ) from codeflash.verification.verification_utils import TestConfig +from codeflash.code_utils.compat import codeflash_temp_dir def test_function_eligible_for_optimization() -> None: @@ -313,3 +318,241 @@ def test_filter_files_optimized(): assert filter_files_optimized(file_path_same_level, tests_root, ignore_paths, module_root) assert filter_files_optimized(file_path_different_level, tests_root, ignore_paths, module_root) assert not filter_files_optimized(file_path_above_level, tests_root, ignore_paths, module_root) + +def test_filter_functions(): + with tempfile.TemporaryDirectory() as temp_dir_str: + temp_dir = Path(temp_dir_str) + + # Create a test file in the temporary directory + test_file_path = temp_dir.joinpath("test_get_functions_to_optimize.py") + with test_file_path.open("w") as f: + f.write( +""" +import copy + +def propagate_attributes( + nodes: dict[str, dict], edges: list[dict], source_node_id: str, attribute: str +) -> dict[str, dict]: + modified_nodes = copy.deepcopy(nodes) + + # Build an adjacency list for faster traversal + adjacency = {} + for edge in edges: + src = edge["source"] + tgt = edge["target"] + if src not in adjacency: + adjacency[src] = [] + adjacency[src].append(tgt) + + # Track visited nodes to avoid cycles + visited = set() + + def traverse(node_id): + if node_id in visited: + return + visited.add(node_id) + + # Propagate attribute from source node + if ( + node_id != source_node_id + and source_node_id in modified_nodes + and attribute in modified_nodes[source_node_id] + ): + if node_id in modified_nodes: + modified_nodes[node_id][attribute] = modified_nodes[source_node_id][ + attribute + ] + + # Continue propagation to neighbors + for neighbor in adjacency.get(node_id, []): + traverse(neighbor) + + traverse(source_node_id) + return modified_nodes + +def vanilla_function(): + return "This is a vanilla function." + +def not_in_checkpoint_function(): + return "This function is not in the checkpoint." +""" + ) + + + discovered = find_all_functions_in_file(test_file_path) + modified_functions = {test_file_path: discovered[test_file_path]} + filtered, count = filter_functions( + modified_functions, + tests_root=Path("tests"), + ignore_paths=[], + project_root=temp_dir, + module_root=temp_dir, + ) + function_names = [fn.function_name for fn in filtered.get(test_file_path, [])] + assert "propagate_attributes" in function_names + assert count == 3 + + # Create a tests directory inside our temp directory + tests_root_dir = temp_dir.joinpath("tests") + tests_root_dir.mkdir(exist_ok=True) + + test_file_path = tests_root_dir.joinpath("test_functions.py") + with test_file_path.open("w") as f: + f.write( +""" +def test_function_in_tests_dir(): + return "This function is in a test directory and should be filtered out." +""" + ) + + discovered_test_file = find_all_functions_in_file(test_file_path) + modified_functions_test = {test_file_path: discovered_test_file.get(test_file_path, [])} + + filtered_test_file, count_test_file = filter_functions( + modified_functions_test, + tests_root=tests_root_dir, + ignore_paths=[], + project_root=temp_dir, + module_root=temp_dir, + ) + + assert not filtered_test_file + assert count_test_file == 0 + + # Test ignored directory + ignored_dir = temp_dir.joinpath("ignored_dir") + ignored_dir.mkdir(exist_ok=True) + ignored_file_path = ignored_dir.joinpath("ignored_file.py") + with ignored_file_path.open("w") as f: + f.write("def ignored_func(): return 1") + + discovered_ignored = find_all_functions_in_file(ignored_file_path) + modified_functions_ignored = {ignored_file_path: discovered_ignored.get(ignored_file_path, [])} + + filtered_ignored, count_ignored = filter_functions( + modified_functions_ignored, + tests_root=Path("tests"), + ignore_paths=[ignored_dir], + project_root=temp_dir, + module_root=temp_dir, + ) + assert not filtered_ignored + assert count_ignored == 0 + + # Test submodule paths + with unittest.mock.patch("codeflash.discovery.functions_to_optimize.ignored_submodule_paths", + return_value=[str(temp_dir.joinpath("submodule_dir"))]): + submodule_dir = temp_dir.joinpath("submodule_dir") + submodule_dir.mkdir(exist_ok=True) + submodule_file_path = submodule_dir.joinpath("submodule_file.py") + with submodule_file_path.open("w") as f: + f.write("def submodule_func(): return 1") + + discovered_submodule = find_all_functions_in_file(submodule_file_path) + modified_functions_submodule = {submodule_file_path: discovered_submodule.get(submodule_file_path, [])} + + filtered_submodule, count_submodule = filter_functions( + modified_functions_submodule, + tests_root=Path("tests"), + ignore_paths=[], + project_root=temp_dir, + module_root=temp_dir, + ) + assert not filtered_submodule + assert count_submodule == 0 + + # Test site packages + with unittest.mock.patch("codeflash.discovery.functions_to_optimize.path_belongs_to_site_packages", + return_value=True): + site_package_file_path = temp_dir.joinpath("site_package_file.py") + with site_package_file_path.open("w") as f: + f.write("def site_package_func(): return 1") + + discovered_site_package = find_all_functions_in_file(site_package_file_path) + modified_functions_site_package = {site_package_file_path: discovered_site_package.get(site_package_file_path, [])} + + filtered_site_package, count_site_package = filter_functions( + modified_functions_site_package, + tests_root=Path("tests"), + ignore_paths=[], + project_root=temp_dir, + module_root=temp_dir, + ) + assert not filtered_site_package + assert count_site_package == 0 + + # Test outside module root + parent_dir = temp_dir.parent + outside_module_root_path = parent_dir.joinpath("outside_module_root_file.py") + try: + with outside_module_root_path.open("w") as f: + f.write("def func_outside_module_root(): return 1") + + discovered_outside_module = find_all_functions_in_file(outside_module_root_path) + modified_functions_outside_module = {outside_module_root_path: discovered_outside_module.get(outside_module_root_path, [])} + + filtered_outside_module, count_outside_module = filter_functions( + modified_functions_outside_module, + tests_root=Path("tests"), + ignore_paths=[], + project_root=temp_dir, + module_root=temp_dir, + ) + assert not filtered_outside_module + assert count_outside_module == 0 + finally: + outside_module_root_path.unlink(missing_ok=True) + + # Test invalid module name + invalid_module_file_path = temp_dir.joinpath("invalid-module-name.py") + with invalid_module_file_path.open("w") as f: + f.write("def func_in_invalid_module(): return 1") + + discovered_invalid_module = find_all_functions_in_file(invalid_module_file_path) + modified_functions_invalid_module = {invalid_module_file_path: discovered_invalid_module.get(invalid_module_file_path, [])} + + filtered_invalid_module, count_invalid_module = filter_functions( + modified_functions_invalid_module, + tests_root=Path("tests"), + ignore_paths=[], + project_root=temp_dir, + module_root=temp_dir, + ) + assert not filtered_invalid_module + assert count_invalid_module == 0 + + original_file_path = temp_dir.joinpath("test_get_functions_to_optimize.py") + with unittest.mock.patch("codeflash.discovery.functions_to_optimize.get_blocklisted_functions", + return_value={original_file_path.name: {"propagate_attributes", "other_blocklisted_function"}}): + filtered_funcs, count = filter_functions( + modified_functions, + tests_root=Path("tests"), + ignore_paths=[], + project_root=temp_dir, + module_root=temp_dir, + ) + assert "propagate_attributes" not in [fn.function_name for fn in filtered_funcs.get(original_file_path, [])] + assert count == 2 + + module_name = "test_get_functions_to_optimize" + qualified_name_for_checkpoint = f"{module_name}.propagate_attributes" + other_qualified_name_for_checkpoint = f"{module_name}.vanilla_function" + + with unittest.mock.patch("codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}): + filtered_checkpoint, count_checkpoint = filter_functions( + modified_functions, + tests_root=Path("tests"), + ignore_paths=[], + project_root=temp_dir, + module_root=temp_dir, + previous_checkpoint_functions={qualified_name_for_checkpoint: {"status": "optimized"}, other_qualified_name_for_checkpoint: {}} + ) + assert filtered_checkpoint.get(original_file_path) + assert count_checkpoint == 1 + + remaining_functions = [fn.function_name for fn in filtered_checkpoint.get(original_file_path, [])] + assert "not_in_checkpoint_function" in remaining_functions + assert "propagate_attributes" not in remaining_functions + assert "vanilla_function" not in remaining_functions + files_and_funcs = get_all_files_and_functions(module_root_path=temp_dir) + assert len(files_and_funcs) == 6 \ No newline at end of file