Skip to content

Commit a85ed90

Browse files
authored
Merge pull request #194 from codeflash-ai/fix-benchmark-prompt
Fix Prompt for checkpoint being displayed correctly in DEBUG mode
2 parents e266864 + 1cf73c3 commit a85ed90

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

codeflash/main.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from codeflash.cli_cmds.cli import parse_args, process_pyproject_config
88
from codeflash.cli_cmds.cmd_init import CODEFLASH_LOGO, ask_run_end_to_end_test
99
from codeflash.cli_cmds.console import paneled_text
10+
from codeflash.code_utils.checkpoint import ask_should_use_checkpoint_get_functions
1011
from codeflash.code_utils.config_parser import parse_config_file
1112
from codeflash.optimization import optimizer
1213
from codeflash.telemetry import posthog_cf
@@ -35,6 +36,7 @@ def main() -> None:
3536
ask_run_end_to_end_test(args)
3637
else:
3738
args = process_pyproject_config(args)
39+
args.previous_checkpoint_functions = ask_should_use_checkpoint_get_functions(args)
3840
init_sentry(not args.disable_telemetry, exclude_errors=True)
3941
posthog_cf.initialize_posthog(not args.disable_telemetry)
4042
optimizer.run_with_args(args)

codeflash/optimization/optimizer.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from codeflash.benchmarking.utils import print_benchmark_table, validate_and_format_benchmark_table
1717
from codeflash.cli_cmds.console import console, logger, progress_bar
1818
from codeflash.code_utils import env_utils
19-
from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint, ask_should_use_checkpoint_get_functions
19+
from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint
2020
from codeflash.code_utils.code_replacer import normalize_code, normalize_node
2121
from codeflash.code_utils.code_utils import cleanup_paths, get_run_tmp_file
2222
from codeflash.code_utils.static_analysis import analyze_imported_modules, get_first_top_level_function_or_method_ast
@@ -85,7 +85,6 @@ def run(self) -> None:
8585
function_optimizer = None
8686
file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]]
8787
num_optimizable_functions: int
88-
previous_checkpoint_functions = ask_should_use_checkpoint_get_functions(self.args)
8988
# discover functions
9089
(file_to_funcs_to_optimize, num_optimizable_functions) = get_functions_to_optimize(
9190
optimize_all=self.args.all,
@@ -96,7 +95,7 @@ def run(self) -> None:
9695
ignore_paths=self.args.ignore_paths,
9796
project_root=self.args.project_root,
9897
module_root=self.args.module_root,
99-
previous_checkpoint_functions=previous_checkpoint_functions,
98+
previous_checkpoint_functions=self.args.previous_checkpoint_functions,
10099
)
101100
function_benchmark_timings: dict[str, dict[BenchmarkKey, int]] = {}
102101
total_benchmark_timings: dict[BenchmarkKey, int] = {}

0 commit comments

Comments
 (0)