diff --git a/codeflash/code_utils/code_utils.py b/codeflash/code_utils/code_utils.py index 82a5b9791..22d90903f 100644 --- a/codeflash/code_utils/code_utils.py +++ b/codeflash/code_utils/code_utils.py @@ -220,3 +220,7 @@ def exit_with_message(message: str, *, error_on_exit: bool = False) -> None: paneled_text(message, panel_args={"style": "red"}) sys.exit(1 if error_on_exit else 0) + + +async def dummy_async_function() -> None: + """Provide a dummy async function for testing purposes.""" diff --git a/codeflash/code_utils/static_analysis.py b/codeflash/code_utils/static_analysis.py index dbddb59f5..8f492bc7e 100644 --- a/codeflash/code_utils/static_analysis.py +++ b/codeflash/code_utils/static_analysis.py @@ -7,6 +7,8 @@ from pydantic import BaseModel, ConfigDict, field_validator +from codeflash.models.models import FunctionParent + if TYPE_CHECKING: from codeflash.models.models import FunctionParent @@ -139,14 +141,19 @@ def get_first_top_level_function_or_method_ast( def function_kind(node: ast.FunctionDef | ast.AsyncFunctionDef, parents: list[FunctionParent]) -> FunctionKind | None: - if not parents or parents[0].type in ["FunctionDef", "AsyncFunctionDef"]: + if not parents: + return FunctionKind.FUNCTION + parent_type = parents[0].type + if parent_type in {"FunctionDef", "AsyncFunctionDef"}: return FunctionKind.FUNCTION - if parents[0].type == "ClassDef": + if parent_type == "ClassDef": + # Search for known decorators efficiently for decorator in node.decorator_list: if isinstance(decorator, ast.Name): - if decorator.id == "classmethod": + dec_id = decorator.id + if dec_id == "classmethod": return FunctionKind.CLASS_METHOD - if decorator.id == "staticmethod": + if dec_id == "staticmethod": return FunctionKind.STATIC_METHOD return FunctionKind.INSTANCE_METHOD return None diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 792a9fcff..d0fff4aed 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -147,6 +147,15 @@ def qualified_name(self) -> str: def qualified_name_with_modules_from_root(self, project_root_path: Path) -> str: return f"{module_name_from_file_path(self.file_path, project_root_path)}.{self.qualified_name}" + @property + def server_info(self) -> dict[str, str | int]: + return { + "file_path": str(self.file_path), + "function_name": self.function_name, + "starting_line": self.starting_line, + "ending_line": self.ending_line, + } + def get_functions_to_optimize( optimize_all: str | None, diff --git a/codeflash/lsp/__init__.py b/codeflash/lsp/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/codeflash/lsp/beta.py b/codeflash/lsp/beta.py new file mode 100644 index 000000000..f19b34fca --- /dev/null +++ b/codeflash/lsp/beta.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING + +from pygls import uris + +from codeflash.either import is_successful +from codeflash.lsp.server import CodeflashLanguageServer, CodeflashLanguageServerProtocol + +if TYPE_CHECKING: + from lsprotocol import types + + +@dataclass +class OptimizableFunctionsParams: + textDocument: types.TextDocumentIdentifier # noqa: N815 + + +@dataclass +class OptimizeFunctionParams: + textDocument: types.TextDocumentIdentifier # noqa: N815 + functionName: str # noqa: N815 + + +server = CodeflashLanguageServer("codeflash-language-server", "v1.0", protocol_cls=CodeflashLanguageServerProtocol) + + +@server.feature("getOptimizableFunctions") +def get_optimizable_functions( + server: CodeflashLanguageServer, params: OptimizableFunctionsParams +) -> dict[str, list[str]]: + file_path = Path(uris.to_fs_path(params.textDocument.uri)) + server.optimizer.args.file = file_path + server.optimizer.args.previous_checkpoint_functions = False + optimizable_funcs, _ = server.optimizer.get_optimizable_functions() + path_to_qualified_names = {} + for path, functions in optimizable_funcs.items(): + path_to_qualified_names[path.as_posix()] = [func.qualified_name for func in functions] + return path_to_qualified_names + + +@server.feature("optimizeFunction") +def optimize_function(server: CodeflashLanguageServer, params: OptimizeFunctionParams) -> dict[str, str]: + file_path = Path(uris.to_fs_path(params.textDocument.uri)) + server.optimizer.args.function = params.functionName + server.optimizer.args.file = file_path + optimizable_funcs, _ = server.optimizer.get_optimizable_functions() + if not optimizable_funcs: + return {"functionName": params.functionName, "status": "not found", "args": None} + fto = optimizable_funcs.popitem()[1][0] + server.optimizer.current_function_being_optimized = fto + return {"functionName": params.functionName, "status": "success", "info": fto.server_info} + + +@server.feature("second_step_in_optimize_function") +def second_step_in_optimize_function(server: CodeflashLanguageServer, params: OptimizeFunctionParams) -> dict[str, str]: + current_function = server.optimizer.current_function_being_optimized + + optimizable_funcs = {current_function.file_path: [current_function]} + + function_to_tests, num_discovered_tests = server.optimizer.discover_tests(optimizable_funcs) + # mocking in order to get things going + return {"functionName": params.functionName, "status": "success", "generated_tests": str(num_discovered_tests)} + + +@server.feature("third_step_in_optimize_function") +def third_step_in_optimize_function(server: CodeflashLanguageServer, params: OptimizeFunctionParams) -> dict[str, str]: + current_function = server.optimizer.current_function_being_optimized + + module_prep_result = server.optimizer.prepare_module_for_optimization(current_function.file_path) + + validated_original_code, original_module_ast = module_prep_result + + function_optimizer = server.optimizer.create_function_optimizer( + current_function, + function_to_optimize_source_code=validated_original_code[current_function.file_path].source_code, + original_module_ast=original_module_ast, + original_module_path=current_function.file_path, + ) + + server.optimizer.current_function_optimizer = function_optimizer + if not function_optimizer: + return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"} + + initialization_result = function_optimizer.can_be_optimized() + if not is_successful(initialization_result): + return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()} + + should_run_experiment, code_context, original_helper_code = initialization_result.unwrap() + + return { + "functionName": params.functionName, + "status": "success", + "message": "Function can be optimized", + "extra": original_helper_code, + } + + +if __name__ == "__main__": + from codeflash.cli_cmds.console import console + + console.quiet = True + server.start_io() diff --git a/codeflash/lsp/server.py b/codeflash/lsp/server.py new file mode 100644 index 000000000..222a1318c --- /dev/null +++ b/codeflash/lsp/server.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from lsprotocol.types import INITIALIZE +from pygls import uris +from pygls.protocol import LanguageServerProtocol, lsp_method +from pygls.server import LanguageServer + +if TYPE_CHECKING: + from lsprotocol.types import InitializeParams, InitializeResult + + +class CodeflashLanguageServerProtocol(LanguageServerProtocol): + _server: CodeflashLanguageServer + + @lsp_method(INITIALIZE) + def lsp_initialize(self, params: InitializeParams) -> InitializeResult: + server = self._server + initialize_result: InitializeResult = super().lsp_initialize(params) + + workspace_uri = params.root_uri + if workspace_uri: + workspace_path = uris.to_fs_path(workspace_uri) + pyproject_toml_path = self._find_pyproject_toml(workspace_path) + if pyproject_toml_path: + server.initialize_optimizer(pyproject_toml_path) + server.show_message(f"Found pyproject.toml at: {pyproject_toml_path}") + else: + server.show_message("No pyproject.toml found in workspace.") + else: + server.show_message("No workspace URI provided.") + + return initialize_result + + def _find_pyproject_toml(self, workspace_path: str) -> Path | None: + workspace_path_obj = Path(workspace_path) + for file_path in workspace_path_obj.rglob("pyproject.toml"): + return file_path.resolve() + return None + + +class CodeflashLanguageServer(LanguageServer): + def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: ANN401 + super().__init__(*args, **kwargs) + self.optimizer = None + + def initialize_optimizer(self, config_file: Path) -> None: + from codeflash.cli_cmds.cli import parse_args, process_pyproject_config + from codeflash.optimization.optimizer import Optimizer + + args = parse_args() + args.config_file = config_file + args = process_pyproject_config(args) + self.optimizer = Optimizer(args) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index f6c7661b4..554b999fd 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -144,7 +144,7 @@ def __init__( self.total_benchmark_timings = total_benchmark_timings if total_benchmark_timings else {} self.replay_tests_dir = replay_tests_dir if replay_tests_dir else None - def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 + def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[Path, str]], str]: should_run_experiment = self.experiment_id is not None logger.debug(f"Function Trace ID: {self.function_trace_id}") ph("cli-optimize-function-start", {"function_trace_id": self.function_trace_id}) @@ -171,6 +171,15 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 ): return Failure("Function optimization previously attempted, skipping.") + return Success((should_run_experiment, code_context, original_helper_code)) + + def optimize_function(self) -> Result[BestOptimization, str]: + initialization_result = self.can_be_optimized() + if not is_successful(initialization_result): + return Failure(initialization_result.failure()) + + should_run_experiment, code_context, original_helper_code = initialization_result.unwrap() + code_print(code_context.read_writable_code) generated_test_paths = [ get_test_file_path( diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 89737912d..3c2705414 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -11,7 +11,6 @@ from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient from codeflash.cli_cmds.console import console, logger, progress_bar from codeflash.code_utils import env_utils -from codeflash.code_utils.env_utils import get_pr_number from codeflash.either import is_successful from codeflash.models.models import ValidCode from codeflash.telemetry.posthog_cf import ph @@ -44,6 +43,148 @@ def __init__(self, args: Namespace) -> None: self.local_aiservice_client = LocalAiServiceClient() if self.experiment_id else None self.replay_tests_dir = None self.functions_checkpoint: CodeflashRunCheckpoint | None = None + self.current_function_being_optimized: FunctionToOptimize | None = None # current only for the LSP + self.current_function_optimizer: FunctionOptimizer | None = None # current only for the LSP + + def run_benchmarks( + self, file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]], num_optimizable_functions: int + ) -> tuple[dict[str, dict[BenchmarkKey, float]], dict[BenchmarkKey, float]]: + """Run benchmarks for the functions to optimize and collect timing information.""" + function_benchmark_timings: dict[str, dict[BenchmarkKey, float]] = {} + total_benchmark_timings: dict[BenchmarkKey, float] = {} + + if not (hasattr(self.args, "benchmark") and self.args.benchmark and num_optimizable_functions > 0): + return function_benchmark_timings, total_benchmark_timings + + from codeflash.benchmarking.instrument_codeflash_trace import instrument_codeflash_trace_decorator + from codeflash.benchmarking.plugin.plugin import CodeFlashBenchmarkPlugin + from codeflash.benchmarking.replay_test import generate_replay_test + from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest + from codeflash.benchmarking.utils import print_benchmark_table, validate_and_format_benchmark_table + from codeflash.code_utils.env_utils import get_pr_number + + with progress_bar( + f"Running benchmarks in {self.args.benchmarks_root}", transient=True, revert_to_print=bool(get_pr_number()) + ): + # Insert decorator + file_path_to_source_code = defaultdict(str) + for file in file_to_funcs_to_optimize: + with file.open("r", encoding="utf8") as f: + file_path_to_source_code[file] = f.read() + try: + instrument_codeflash_trace_decorator(file_to_funcs_to_optimize) + trace_file = Path(self.args.benchmarks_root) / "benchmarks.trace" + if trace_file.exists(): + trace_file.unlink() + + self.replay_tests_dir = Path( + tempfile.mkdtemp(prefix="codeflash_replay_tests_", dir=self.args.benchmarks_root) + ) + trace_benchmarks_pytest( + self.args.benchmarks_root, self.args.tests_root, self.args.project_root, trace_file + ) # Run all tests that use pytest-benchmark + replay_count = generate_replay_test(trace_file, self.replay_tests_dir) + if replay_count == 0: + logger.info( + f"No valid benchmarks found in {self.args.benchmarks_root} for functions to optimize, continuing optimization" + ) + else: + function_benchmark_timings = CodeFlashBenchmarkPlugin.get_function_benchmark_timings(trace_file) + total_benchmark_timings = CodeFlashBenchmarkPlugin.get_benchmark_timings(trace_file) + function_to_results = validate_and_format_benchmark_table( + function_benchmark_timings, total_benchmark_timings + ) + print_benchmark_table(function_to_results) + except Exception as e: + logger.info(f"Error while tracing existing benchmarks: {e}") + logger.info("Information on existing benchmarks will not be available for this run.") + finally: + # Restore original source code + for file in file_path_to_source_code: + with file.open("w", encoding="utf8") as f: + f.write(file_path_to_source_code[file]) + + return function_benchmark_timings, total_benchmark_timings + + def prepare_module_for_optimization( + self, original_module_path: Path + ) -> tuple[dict[Path, ValidCode], ast.Module] | None: + from codeflash.code_utils.code_replacer import normalize_code, normalize_node + from codeflash.code_utils.static_analysis import analyze_imported_modules + + logger.info(f"Examining file {original_module_path!s}…") + console.rule() + + original_module_code: str = original_module_path.read_text(encoding="utf8") + try: + original_module_ast = ast.parse(original_module_code) + except SyntaxError as e: + logger.warning(f"Syntax error parsing code in {original_module_path}: {e}") + logger.info("Skipping optimization due to file error.") + return None + normalized_original_module_code = ast.unparse(normalize_node(original_module_ast)) + validated_original_code: dict[Path, ValidCode] = { + original_module_path: ValidCode( + source_code=original_module_code, normalized_code=normalized_original_module_code + ) + } + + imported_module_analyses = analyze_imported_modules( + original_module_code, original_module_path, self.args.project_root + ) + + has_syntax_error = False + for analysis in imported_module_analyses: + callee_original_code = analysis.file_path.read_text(encoding="utf8") + try: + normalized_callee_original_code = normalize_code(callee_original_code) + except SyntaxError as e: + logger.warning(f"Syntax error parsing code in callee module {analysis.file_path}: {e}") + logger.info("Skipping optimization due to helper file error.") + has_syntax_error = True + break + validated_original_code[analysis.file_path] = ValidCode( + source_code=callee_original_code, normalized_code=normalized_callee_original_code + ) + + if has_syntax_error: + return None + + return validated_original_code, original_module_ast + + def discover_tests( + self, file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]] + ) -> tuple[dict[str, set[FunctionCalledInTest]], int]: + from codeflash.discovery.discover_unit_tests import discover_unit_tests + + console.rule() + start_time = time.time() + function_to_tests, num_discovered_tests = discover_unit_tests( + self.test_cfg, file_to_funcs_to_optimize=file_to_funcs_to_optimize + ) + console.rule() + logger.info( + f"Discovered {num_discovered_tests} existing unit tests in {(time.time() - start_time):.1f}s at {self.test_cfg.tests_root}" + ) + console.rule() + ph("cli-optimize-discovered-tests", {"num_tests": num_discovered_tests}) + return function_to_tests, num_discovered_tests + + def get_optimizable_functions(self) -> tuple[dict[Path, list[FunctionToOptimize]], int]: + """Discover functions to optimize.""" + from codeflash.discovery.functions_to_optimize import get_functions_to_optimize + + return get_functions_to_optimize( + optimize_all=self.args.all, + replay_test=self.args.replay_test, + file=self.args.file, + only_get_this_function=self.args.function, + test_cfg=self.test_cfg, + ignore_paths=self.args.ignore_paths, + project_root=self.args.project_root, + module_root=self.args.module_root, + previous_checkpoint_functions=self.args.previous_checkpoint_functions, + ) def create_function_optimizer( self, @@ -53,9 +194,35 @@ def create_function_optimizer( function_to_optimize_source_code: str | None = "", function_benchmark_timings: dict[str, dict[BenchmarkKey, float]] | None = None, total_benchmark_timings: dict[BenchmarkKey, float] | None = None, - ) -> FunctionOptimizer: + original_module_ast: ast.Module | None = None, + original_module_path: Path | None = None, + ) -> FunctionOptimizer | None: + from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast from codeflash.optimization.function_optimizer import FunctionOptimizer + if function_to_optimize_ast is None and original_module_ast is not None: + function_to_optimize_ast = get_first_top_level_function_or_method_ast( + function_to_optimize.function_name, function_to_optimize.parents, original_module_ast + ) + if function_to_optimize_ast is None: + logger.info( + f"Function {function_to_optimize.qualified_name} not found in {original_module_path}.\n" + f"Skipping optimization." + ) + return None + + qualified_name_w_module = function_to_optimize.qualified_name_with_modules_from_root(self.args.project_root) + + function_specific_timings = None + if ( + hasattr(self.args, "benchmark") + and self.args.benchmark + and function_benchmark_timings + and qualified_name_w_module in function_benchmark_timings + and total_benchmark_timings + ): + function_specific_timings = function_benchmark_timings[qualified_name_w_module] + return FunctionOptimizer( function_to_optimize=function_to_optimize, test_cfg=self.test_cfg, @@ -64,20 +231,14 @@ def create_function_optimizer( function_to_optimize_ast=function_to_optimize_ast, aiservice_client=self.aiservice_client, args=self.args, - function_benchmark_timings=function_benchmark_timings if function_benchmark_timings else None, - total_benchmark_timings=total_benchmark_timings if total_benchmark_timings else None, + function_benchmark_timings=function_specific_timings, + total_benchmark_timings=total_benchmark_timings if function_specific_timings else None, replay_tests_dir=self.replay_tests_dir, ) def run(self) -> None: from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint - from codeflash.code_utils.code_replacer import normalize_code, normalize_node - from codeflash.code_utils.static_analysis import ( - analyze_imported_modules, - get_first_top_level_function_or_method_ast, - ) - from codeflash.discovery.discover_unit_tests import discover_unit_tests - from codeflash.discovery.functions_to_optimize import get_functions_to_optimize + from codeflash.code_utils.code_utils import cleanup_paths ph("cli-optimize-run-start") logger.info("Running optimizer.") @@ -87,72 +248,12 @@ def run(self) -> None: if not env_utils.check_formatter_installed(self.args.formatter_cmds): return function_optimizer = None - file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]] - num_optimizable_functions: int # discover functions - (file_to_funcs_to_optimize, num_optimizable_functions) = get_functions_to_optimize( - optimize_all=self.args.all, - replay_test=self.args.replay_test, - file=self.args.file, - only_get_this_function=self.args.function, - test_cfg=self.test_cfg, - ignore_paths=self.args.ignore_paths, - project_root=self.args.project_root, - module_root=self.args.module_root, - previous_checkpoint_functions=self.args.previous_checkpoint_functions, + file_to_funcs_to_optimize, num_optimizable_functions = self.get_optimizable_functions() + function_benchmark_timings, total_benchmark_timings = self.run_benchmarks( + file_to_funcs_to_optimize, num_optimizable_functions ) - function_benchmark_timings: dict[str, dict[BenchmarkKey, int]] = {} - total_benchmark_timings: dict[BenchmarkKey, int] = {} - if self.args.benchmark and num_optimizable_functions > 0: - from codeflash.benchmarking.instrument_codeflash_trace import instrument_codeflash_trace_decorator - from codeflash.benchmarking.plugin.plugin import CodeFlashBenchmarkPlugin - from codeflash.benchmarking.replay_test import generate_replay_test - from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest - from codeflash.benchmarking.utils import print_benchmark_table, validate_and_format_benchmark_table - - console.rule() - with progress_bar( - f"Running benchmarks in {self.args.benchmarks_root}", - transient=True, - revert_to_print=bool(get_pr_number()), - ): - # Insert decorator - file_path_to_source_code = defaultdict(str) - for file in file_to_funcs_to_optimize: - with file.open("r", encoding="utf8") as f: - file_path_to_source_code[file] = f.read() - try: - instrument_codeflash_trace_decorator(file_to_funcs_to_optimize) - trace_file = Path(self.args.benchmarks_root) / "benchmarks.trace" - if trace_file.exists(): - trace_file.unlink() - - self.replay_tests_dir = Path( - tempfile.mkdtemp(prefix="codeflash_replay_tests_", dir=self.args.benchmarks_root) - ) - trace_benchmarks_pytest( - self.args.benchmarks_root, self.args.tests_root, self.args.project_root, trace_file - ) # Run all tests that use pytest-benchmark - replay_count = generate_replay_test(trace_file, self.replay_tests_dir) - if replay_count == 0: - logger.info( - f"No valid benchmarks found in {self.args.benchmarks_root} for functions to optimize, continuing optimization" - ) - else: - function_benchmark_timings = CodeFlashBenchmarkPlugin.get_function_benchmark_timings(trace_file) - total_benchmark_timings = CodeFlashBenchmarkPlugin.get_benchmark_timings(trace_file) - function_to_results = validate_and_format_benchmark_table( - function_benchmark_timings, total_benchmark_timings - ) - print_benchmark_table(function_to_results) - except Exception as e: - logger.info(f"Error while tracing existing benchmarks: {e}") - logger.info("Information on existing benchmarks will not be available for this run.") - finally: - # Restore original source code - for file in file_path_to_source_code: - with file.open("w", encoding="utf8") as f: - f.write(file_path_to_source_code[file]) + optimizations_found: int = 0 function_iterator_count: int = 0 if self.args.test_framework == "pytest": @@ -165,58 +266,16 @@ def run(self) -> None: logger.info("No functions found to optimize. Exiting…") return - console.rule() - start_time = time.time() - function_to_tests, num_discovered_tests = discover_unit_tests( - self.test_cfg, file_to_funcs_to_optimize=file_to_funcs_to_optimize - ) - console.rule() - logger.info( - f"Discovered {num_discovered_tests} existing unit tests in {(time.time() - start_time):.1f}s at {self.test_cfg.tests_root}" - ) - console.rule() - ph("cli-optimize-discovered-tests", {"num_tests": num_discovered_tests}) + function_to_tests, _ = self.discover_tests(file_to_funcs_to_optimize) if self.args.all: self.functions_checkpoint = CodeflashRunCheckpoint(self.args.module_root) for original_module_path in file_to_funcs_to_optimize: - logger.info(f"Examining file {original_module_path!s}…") - console.rule() - - original_module_code: str = original_module_path.read_text(encoding="utf8") - try: - original_module_ast = ast.parse(original_module_code) - except SyntaxError as e: - logger.warning(f"Syntax error parsing code in {original_module_path}: {e}") - logger.info("Skipping optimization due to file error.") + module_prep_result = self.prepare_module_for_optimization(original_module_path) + if module_prep_result is None: continue - normalized_original_module_code = ast.unparse(normalize_node(original_module_ast)) - validated_original_code: dict[Path, ValidCode] = { - original_module_path: ValidCode( - source_code=original_module_code, normalized_code=normalized_original_module_code - ) - } - imported_module_analyses = analyze_imported_modules( - original_module_code, original_module_path, self.args.project_root - ) - - has_syntax_error = False - for analysis in imported_module_analyses: - callee_original_code = analysis.file_path.read_text(encoding="utf8") - try: - normalized_callee_original_code = normalize_code(callee_original_code) - except SyntaxError as e: - logger.warning(f"Syntax error parsing code in callee module {analysis.file_path}: {e}") - logger.info("Skipping optimization due to helper file error.") - has_syntax_error = True - break - validated_original_code[analysis.file_path] = ValidCode( - source_code=callee_original_code, normalized_code=normalized_callee_original_code - ) - - if has_syntax_error: - continue + validated_original_code, original_module_ast = module_prep_result for function_to_optimize in file_to_funcs_to_optimize[original_module_path]: function_iterator_count += 1 @@ -225,40 +284,19 @@ def run(self) -> None: f"{function_to_optimize.qualified_name}" ) console.rule() - if not ( - function_to_optimize_ast := get_first_top_level_function_or_method_ast( - function_to_optimize.function_name, function_to_optimize.parents, original_module_ast - ) - ): - logger.info( - f"Function {function_to_optimize.qualified_name} not found in {original_module_path}.\n" - f"Skipping optimization." - ) - continue - qualified_name_w_module = function_to_optimize.qualified_name_with_modules_from_root( - self.args.project_root + + function_optimizer = self.create_function_optimizer( + function_to_optimize, + function_to_tests=function_to_tests, + function_to_optimize_source_code=validated_original_code[original_module_path].source_code, + function_benchmark_timings=function_benchmark_timings, + total_benchmark_timings=total_benchmark_timings, + original_module_ast=original_module_ast, + original_module_path=original_module_path, ) - if ( - self.args.benchmark - and function_benchmark_timings - and qualified_name_w_module in function_benchmark_timings - and total_benchmark_timings - ): - function_optimizer = self.create_function_optimizer( - function_to_optimize, - function_to_optimize_ast, - function_to_tests, - validated_original_code[original_module_path].source_code, - function_benchmark_timings[qualified_name_w_module], - total_benchmark_timings, - ) - else: - function_optimizer = self.create_function_optimizer( - function_to_optimize, - function_to_optimize_ast, - function_to_tests, - validated_original_code[original_module_path].source_code, - ) + + if function_optimizer is None: + continue best_optimization = function_optimizer.optimize_function() if self.functions_checkpoint: @@ -282,22 +320,14 @@ def run(self) -> None: if function_optimizer: function_optimizer.cleanup_generated_files() - self.cleanup_temporary_paths() - - def cleanup_temporary_paths(self) -> None: - from codeflash.code_utils.code_utils import cleanup_paths - - cleanup_paths([self.test_cfg.concolic_test_root_dir, self.replay_tests_dir]) + if self.test_cfg.concolic_test_root_dir: + cleanup_paths([self.test_cfg.concolic_test_root_dir]) def run_with_args(args: Namespace) -> None: - optimizer = None try: optimizer = Optimizer(args) optimizer.run() except KeyboardInterrupt: - logger.warning("Keyboard interrupt received. Cleaning up and exiting, please wait…") - if optimizer: - optimizer.cleanup_temporary_paths() - + logger.warning("Keyboard interrupt received. Exiting, please wait…") raise SystemExit from None