From 4571b2d4901c43e19311123b21989542d4480e3a Mon Sep 17 00:00:00 2001
From: Saurabh Misra
Date: Sun, 20 Apr 2025 19:20:41 -0700
Subject: [PATCH] clean previous optimizations when using --file mode
 optimizations

---
 codeflash/optimization/function_optimizer.py | 48 ++++++++++++++------
 1 file changed, 33 insertions(+), 15 deletions(-)

diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py
index 49958dc96..de95345c6 100644
--- a/codeflash/optimization/function_optimizer.py
+++ b/codeflash/optimization/function_optimizer.py
@@ -94,7 +94,7 @@ def __init__(
         function_benchmark_timings: dict[BenchmarkKey, int] | None = None,
         total_benchmark_timings: dict[BenchmarkKey, int] | None = None,
         args: Namespace | None = None,
-        replay_tests_dir: Path|None = None
+        replay_tests_dir: Path | None = None,
     ) -> None:
         self.project_root = test_cfg.project_root_path
         self.test_cfg = test_cfg
@@ -273,7 +273,7 @@ def optimize_function(self) -> Result[BestOptimization, str]:
                 processed_benchmark_info = process_benchmark_data(
                     replay_performance_gain=best_optimization.replay_performance_gain,
                     fto_benchmark_timings=self.function_benchmark_timings,
-                    total_benchmark_timings=self.total_benchmark_timings
+                    total_benchmark_timings=self.total_benchmark_timings,
                 )
             explanation = Explanation(
                 raw_explanation_message=best_optimization.candidate.explanation,
@@ -283,7 +283,7 @@ def optimize_function(self) -> Result[BestOptimization, str]:
                 best_runtime_ns=best_optimization.runtime,
                 function_name=function_to_optimize_qualified_name,
                 file_path=self.function_to_optimize.file_path,
-                benchmark_details=processed_benchmark_info.benchmark_details if processed_benchmark_info else None
+                benchmark_details=processed_benchmark_info.benchmark_details if processed_benchmark_info else None,
             )

             self.log_successful_optimization(explanation, generated_tests)
@@ -328,7 +328,7 @@ def optimize_function(self) -> Result[BestOptimization, str]:
                     coverage_message=coverage_message,
                     git_remote=self.args.git_remote,
                 )
-            if self.args.all or env_utils.get_pr_number():
+            if self.args.all or env_utils.get_pr_number() or (self.args.file and not self.args.function):
                 self.write_code_and_helpers(
                     self.function_to_optimize_source_code,
                     original_helper_code,
@@ -397,8 +397,10 @@ def determine_best_candidate(
                 if done and (future_line_profile_results is not None):
                     line_profile_results = future_line_profile_results.result()
                     candidates.extend(line_profile_results)
-                    original_len+= len(candidates)
-                    logger.info(f"Added results from line profiler to candidates, total candidates now: {original_len}")
+                    original_len += len(candidates)
+                    logger.info(
+                        f"Added results from line profiler to candidates, total candidates now: {original_len}"
+                    )
                     future_line_profile_results = None
                 candidate_index += 1
                 candidate = candidates.popleft()
@@ -425,7 +427,6 @@ def determine_best_candidate(
                     )
                     continue
-
                 run_results = self.run_optimized_candidate(
                     optimization_candidate_index=candidate_index,
                     baseline_results=original_code_baseline,
@@ -464,12 +465,19 @@ def determine_best_candidate(
                     tree.add(f"Speedup ratio: {perf_gain + 1:.1f}X")
                     replay_perf_gain = {}
                     if self.args.benchmark:
-                        test_results_by_benchmark = candidate_result.benchmarking_test_results.group_by_benchmarks(self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root)
+                        test_results_by_benchmark = (
+                            candidate_result.benchmarking_test_results.group_by_benchmarks(
+                                self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root
+                            )
+                        )
                         if len(test_results_by_benchmark) > 0:
                             benchmark_tree = Tree("Speedup percentage on benchmarks:")
                             for benchmark_key, candidate_test_results in test_results_by_benchmark.items():
-
-                                original_code_replay_runtime = original_code_baseline.replay_benchmarking_test_results[benchmark_key].total_passed_runtime()
+                                original_code_replay_runtime = (
+                                    original_code_baseline.replay_benchmarking_test_results[
+                                        benchmark_key
+                                    ].total_passed_runtime()
+                                )
                                 candidate_replay_runtime = candidate_test_results.total_passed_runtime()
                                 replay_perf_gain[benchmark_key] = performance_gain(
                                     original_runtime_ns=original_code_replay_runtime,
@@ -958,13 +966,17 @@ def establish_original_code_baseline(
         logger.debug(f"Total original code runtime (ns): {total_timing}")

         if self.args.benchmark:
-            replay_benchmarking_test_results = benchmarking_results.group_by_benchmarks(self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root)
+            replay_benchmarking_test_results = benchmarking_results.group_by_benchmarks(
+                self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root
+            )
         return Success(
             (
                 OriginalCodeBaseline(
                     behavioral_test_results=behavioral_results,
                     benchmarking_test_results=benchmarking_results,
-                    replay_benchmarking_test_results = replay_benchmarking_test_results if self.args.benchmark else None,
+                    replay_benchmarking_test_results=replay_benchmarking_test_results
+                    if self.args.benchmark
+                    else None,
                     runtime=total_timing,
                     coverage_results=coverage_results,
                     line_profile_results=line_profile_results,
@@ -1077,16 +1089,22 @@ def run_optimized_candidate(
         logger.debug(f"Total optimized code {optimization_candidate_index} runtime (ns): {total_candidate_timing}")

         if self.args.benchmark:
-            candidate_replay_benchmarking_results = candidate_benchmarking_results.group_by_benchmarks(self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root)
+            candidate_replay_benchmarking_results = candidate_benchmarking_results.group_by_benchmarks(
+                self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root
+            )
             for benchmark_name, benchmark_results in candidate_replay_benchmarking_results.items():
-                logger.debug(f"Benchmark {benchmark_name} runtime (ns): {humanize_runtime(benchmark_results.total_passed_runtime())}")
+                logger.debug(
+                    f"Benchmark {benchmark_name} runtime (ns): {humanize_runtime(benchmark_results.total_passed_runtime())}"
+                )
         return Success(
             OptimizedCandidateResult(
                 max_loop_count=loop_count,
                 best_test_runtime=total_candidate_timing,
                 behavior_test_results=candidate_behavior_results,
                 benchmarking_test_results=candidate_benchmarking_results,
-                replay_benchmarking_test_results = candidate_replay_benchmarking_results if self.args.benchmark else None,
+                replay_benchmarking_test_results=candidate_replay_benchmarking_results
+                if self.args.benchmark
+                else None,
                 optimization_candidate_index=optimization_candidate_index,
                 total_candidate_timing=total_candidate_timing,
             )
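
Note on the substantive change: most hunks above are formatter-style rewrapping; the one behavioral edit is in optimize_function, where write_code_and_helpers now also fires when the user passes --file without --function. Restoring the original source after each optimized function keeps later functions in the same file from being optimized and measured on top of earlier rewrites. Below is a minimal sketch of that revert-after-each-function pattern; optimize_file and optimize_one are hypothetical names for illustration, not the codeflash API.

from pathlib import Path
from typing import Callable

def optimize_file(file_path: Path, functions: list[str], optimize_one: Callable[[Path, str], None]) -> None:
    """Hypothetical sketch: optimize every function in a file, restoring the
    original source after each one so every candidate is built and measured
    against clean code rather than earlier optimizations."""
    original_source = file_path.read_text()
    for name in functions:
        try:
            optimize_one(file_path, name)  # may rewrite file_path in place
        finally:
            # Analogue of write_code_and_helpers() in the patch: put the
            # original code back before moving on to the next function.
            file_path.write_text(original_source)

In the patch itself the restore is gated on CLI flags rather than done unconditionally, which suggests that single-function runs (--function) intentionally leave the optimized code in the working tree.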