diff --git a/codeflash/main.py b/codeflash/main.py index 2ec4c614d..02b13d5aa 100644 --- a/codeflash/main.py +++ b/codeflash/main.py @@ -7,6 +7,7 @@ from codeflash.cli_cmds.cli import parse_args, process_pyproject_config from codeflash.cli_cmds.cmd_init import CODEFLASH_LOGO, ask_run_end_to_end_test from codeflash.cli_cmds.console import paneled_text +from codeflash.code_utils.checkpoint import ask_should_use_checkpoint_get_functions from codeflash.code_utils.config_parser import parse_config_file from codeflash.optimization import optimizer from codeflash.telemetry import posthog_cf @@ -35,6 +36,7 @@ def main() -> None: ask_run_end_to_end_test(args) else: args = process_pyproject_config(args) + args.previous_checkpoint_functions = ask_should_use_checkpoint_get_functions(args) init_sentry(not args.disable_telemetry, exclude_errors=True) posthog_cf.initialize_posthog(not args.disable_telemetry) optimizer.run_with_args(args) diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 10d21def5..de2cc1740 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -16,7 +16,7 @@ from codeflash.benchmarking.utils import print_benchmark_table, validate_and_format_benchmark_table from codeflash.cli_cmds.console import console, logger, progress_bar from codeflash.code_utils import env_utils -from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint, ask_should_use_checkpoint_get_functions +from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint from codeflash.code_utils.code_replacer import normalize_code, normalize_node from codeflash.code_utils.code_utils import cleanup_paths, get_run_tmp_file from codeflash.code_utils.static_analysis import analyze_imported_modules, get_first_top_level_function_or_method_ast @@ -85,7 +85,6 @@ def run(self) -> None: function_optimizer = None file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]] num_optimizable_functions: int - previous_checkpoint_functions = ask_should_use_checkpoint_get_functions(self.args) # discover functions (file_to_funcs_to_optimize, num_optimizable_functions) = get_functions_to_optimize( optimize_all=self.args.all, @@ -96,7 +95,7 @@ def run(self) -> None: ignore_paths=self.args.ignore_paths, project_root=self.args.project_root, module_root=self.args.module_root, - previous_checkpoint_functions=previous_checkpoint_functions, + previous_checkpoint_functions=self.args.previous_checkpoint_functions, ) function_benchmark_timings: dict[str, dict[BenchmarkKey, int]] = {} total_benchmark_timings: dict[BenchmarkKey, int] = {}