Skip to content

Commit 4571b2d

Browse files
committed
clean previous optimizations when using --file mode optimizations
1 parent 48967e6 commit 4571b2d

File tree

1 file changed

+33
-15
lines changed

1 file changed

+33
-15
lines changed

codeflash/optimization/function_optimizer.py

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def __init__(
9494
function_benchmark_timings: dict[BenchmarkKey, int] | None = None,
9595
total_benchmark_timings: dict[BenchmarkKey, int] | None = None,
9696
args: Namespace | None = None,
97-
replay_tests_dir: Path|None = None
97+
replay_tests_dir: Path | None = None,
9898
) -> None:
9999
self.project_root = test_cfg.project_root_path
100100
self.test_cfg = test_cfg
@@ -273,7 +273,7 @@ def optimize_function(self) -> Result[BestOptimization, str]:
273273
processed_benchmark_info = process_benchmark_data(
274274
replay_performance_gain=best_optimization.replay_performance_gain,
275275
fto_benchmark_timings=self.function_benchmark_timings,
276-
total_benchmark_timings=self.total_benchmark_timings
276+
total_benchmark_timings=self.total_benchmark_timings,
277277
)
278278
explanation = Explanation(
279279
raw_explanation_message=best_optimization.candidate.explanation,
@@ -283,7 +283,7 @@ def optimize_function(self) -> Result[BestOptimization, str]:
283283
best_runtime_ns=best_optimization.runtime,
284284
function_name=function_to_optimize_qualified_name,
285285
file_path=self.function_to_optimize.file_path,
286-
benchmark_details=processed_benchmark_info.benchmark_details if processed_benchmark_info else None
286+
benchmark_details=processed_benchmark_info.benchmark_details if processed_benchmark_info else None,
287287
)
288288

289289
self.log_successful_optimization(explanation, generated_tests)
@@ -328,7 +328,7 @@ def optimize_function(self) -> Result[BestOptimization, str]:
328328
coverage_message=coverage_message,
329329
git_remote=self.args.git_remote,
330330
)
331-
if self.args.all or env_utils.get_pr_number():
331+
if self.args.all or env_utils.get_pr_number() or (self.args.file and not self.args.function):
332332
self.write_code_and_helpers(
333333
self.function_to_optimize_source_code,
334334
original_helper_code,
@@ -397,8 +397,10 @@ def determine_best_candidate(
397397
if done and (future_line_profile_results is not None):
398398
line_profile_results = future_line_profile_results.result()
399399
candidates.extend(line_profile_results)
400-
original_len+= len(candidates)
401-
logger.info(f"Added results from line profiler to candidates, total candidates now: {original_len}")
400+
original_len += len(candidates)
401+
logger.info(
402+
f"Added results from line profiler to candidates, total candidates now: {original_len}"
403+
)
402404
future_line_profile_results = None
403405
candidate_index += 1
404406
candidate = candidates.popleft()
@@ -425,7 +427,6 @@ def determine_best_candidate(
425427
)
426428
continue
427429

428-
429430
run_results = self.run_optimized_candidate(
430431
optimization_candidate_index=candidate_index,
431432
baseline_results=original_code_baseline,
@@ -464,12 +465,19 @@ def determine_best_candidate(
464465
tree.add(f"Speedup ratio: {perf_gain + 1:.1f}X")
465466
replay_perf_gain = {}
466467
if self.args.benchmark:
467-
test_results_by_benchmark = candidate_result.benchmarking_test_results.group_by_benchmarks(self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root)
468+
test_results_by_benchmark = (
469+
candidate_result.benchmarking_test_results.group_by_benchmarks(
470+
self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root
471+
)
472+
)
468473
if len(test_results_by_benchmark) > 0:
469474
benchmark_tree = Tree("Speedup percentage on benchmarks:")
470475
for benchmark_key, candidate_test_results in test_results_by_benchmark.items():
471-
472-
original_code_replay_runtime = original_code_baseline.replay_benchmarking_test_results[benchmark_key].total_passed_runtime()
476+
original_code_replay_runtime = (
477+
original_code_baseline.replay_benchmarking_test_results[
478+
benchmark_key
479+
].total_passed_runtime()
480+
)
473481
candidate_replay_runtime = candidate_test_results.total_passed_runtime()
474482
replay_perf_gain[benchmark_key] = performance_gain(
475483
original_runtime_ns=original_code_replay_runtime,
@@ -958,13 +966,17 @@ def establish_original_code_baseline(
958966
logger.debug(f"Total original code runtime (ns): {total_timing}")
959967

960968
if self.args.benchmark:
961-
replay_benchmarking_test_results = benchmarking_results.group_by_benchmarks(self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root)
969+
replay_benchmarking_test_results = benchmarking_results.group_by_benchmarks(
970+
self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root
971+
)
962972
return Success(
963973
(
964974
OriginalCodeBaseline(
965975
behavioral_test_results=behavioral_results,
966976
benchmarking_test_results=benchmarking_results,
967-
replay_benchmarking_test_results = replay_benchmarking_test_results if self.args.benchmark else None,
977+
replay_benchmarking_test_results=replay_benchmarking_test_results
978+
if self.args.benchmark
979+
else None,
968980
runtime=total_timing,
969981
coverage_results=coverage_results,
970982
line_profile_results=line_profile_results,
@@ -1077,16 +1089,22 @@ def run_optimized_candidate(
10771089

10781090
logger.debug(f"Total optimized code {optimization_candidate_index} runtime (ns): {total_candidate_timing}")
10791091
if self.args.benchmark:
1080-
candidate_replay_benchmarking_results = candidate_benchmarking_results.group_by_benchmarks(self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root)
1092+
candidate_replay_benchmarking_results = candidate_benchmarking_results.group_by_benchmarks(
1093+
self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root
1094+
)
10811095
for benchmark_name, benchmark_results in candidate_replay_benchmarking_results.items():
1082-
logger.debug(f"Benchmark {benchmark_name} runtime (ns): {humanize_runtime(benchmark_results.total_passed_runtime())}")
1096+
logger.debug(
1097+
f"Benchmark {benchmark_name} runtime (ns): {humanize_runtime(benchmark_results.total_passed_runtime())}"
1098+
)
10831099
return Success(
10841100
OptimizedCandidateResult(
10851101
max_loop_count=loop_count,
10861102
best_test_runtime=total_candidate_timing,
10871103
behavior_test_results=candidate_behavior_results,
10881104
benchmarking_test_results=candidate_benchmarking_results,
1089-
replay_benchmarking_test_results = candidate_replay_benchmarking_results if self.args.benchmark else None,
1105+
replay_benchmarking_test_results=candidate_replay_benchmarking_results
1106+
if self.args.benchmark
1107+
else None,
10901108
optimization_candidate_index=optimization_candidate_index,
10911109
total_candidate_timing=total_candidate_timing,
10921110
)

0 commit comments

Comments
 (0)