codeflash/optimization/function_optimizer.py (48 changes: 33 additions & 15 deletions)
@@ -94,7 +94,7 @@ def __init__(
         function_benchmark_timings: dict[BenchmarkKey, int] | None = None,
         total_benchmark_timings: dict[BenchmarkKey, int] | None = None,
         args: Namespace | None = None,
-        replay_tests_dir: Path|None = None
+        replay_tests_dir: Path | None = None,
     ) -> None:
         self.project_root = test_cfg.project_root_path
         self.test_cfg = test_cfg
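Note: besides the trailing comma, the annotation now uses the spaced `Path | None` union style of the neighboring parameters. A hypothetical construction call showing the new keyword; the class name is assumed from the module name, and the placeholder values stand in for objects the real CLI run would build:

    from argparse import Namespace
    from pathlib import Path

    args = Namespace(all=False, file="my_module.py", function=None)  # placeholder CLI args
    optimizer = FunctionOptimizer(            # class name assumed, not confirmed by this diff
        function_to_optimize=fto,             # hypothetical FunctionToOptimize instance
        test_cfg=test_cfg,                    # hypothetical TestConfig instance
        args=args,
        function_benchmark_timings=None,      # dict[BenchmarkKey, int] | None
        total_benchmark_timings=None,         # dict[BenchmarkKey, int] | None
        replay_tests_dir=Path("tests") / "replay_tests",  # the keyword added in this PR
    )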
@@ -273,7 +273,7 @@ def optimize_function(self) -> Result[BestOptimization, str]:
                 processed_benchmark_info = process_benchmark_data(
                     replay_performance_gain=best_optimization.replay_performance_gain,
                     fto_benchmark_timings=self.function_benchmark_timings,
-                    total_benchmark_timings=self.total_benchmark_timings
+                    total_benchmark_timings=self.total_benchmark_timings,
                 )
             explanation = Explanation(
                 raw_explanation_message=best_optimization.candidate.explanation,
@@ -283,7 +283,7 @@ def optimize_function(self) -> Result[BestOptimization, str]:
                 best_runtime_ns=best_optimization.runtime,
                 function_name=function_to_optimize_qualified_name,
                 file_path=self.function_to_optimize.file_path,
-                benchmark_details=processed_benchmark_info.benchmark_details if processed_benchmark_info else None
+                benchmark_details=processed_benchmark_info.benchmark_details if processed_benchmark_info else None,
             )
 
             self.log_successful_optimization(explanation, generated_tests)
@@ -328,7 +328,7 @@ def optimize_function(self) -> Result[BestOptimization, str]:
                 coverage_message=coverage_message,
                 git_remote=self.args.git_remote,
             )
-            if self.args.all or env_utils.get_pr_number():
+            if self.args.all or env_utils.get_pr_number() or (self.args.file and not self.args.function):
                 self.write_code_and_helpers(
                     self.function_to_optimize_source_code,
                     original_helper_code,
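Note: the new third clause also restores the original code when a whole file was targeted without naming a specific function. The gate, extracted into a hypothetical predicate (not part of the PR) for readability:

    def should_write_back_original(args, pr_number: int | None) -> bool:
        # Restore the unmodified source when optimizing everything (--all),
        # when running in a PR context, or when a file was passed with no
        # single function selected -- the case this PR adds.
        return bool(args.all or pr_number or (args.file and not args.function))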
@@ -397,8 +397,10 @@ def determine_best_candidate(
                 if done and (future_line_profile_results is not None):
                     line_profile_results = future_line_profile_results.result()
                     candidates.extend(line_profile_results)
-                    original_len+= len(candidates)
-                    logger.info(f"Added results from line profiler to candidates, total candidates now: {original_len}")
+                    original_len += len(candidates)
+                    logger.info(
+                        f"Added results from line profiler to candidates, total candidates now: {original_len}"
+                    )
                     future_line_profile_results = None
                 candidate_index += 1
                 candidate = candidates.popleft()
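Note: this loop consumes a deque of candidates while a line-profiler job may still be finishing in the background; once its future completes, the extra candidates are merged in exactly once. The polling pattern in isolation, as a sketch under those assumptions rather than the PR's code:

    from collections import deque
    from concurrent.futures import Future

    def drain_candidates(candidates: deque, future_results: Future | None) -> None:
        # Hypothetical helper illustrating the pattern above: merge
        # late-arriving line-profiler candidates into the queue exactly once,
        # then keep consuming from the left.
        while candidates:
            if future_results is not None and future_results.done():
                candidates.extend(future_results.result())
                future_results = None  # merge only once
            candidate = candidates.popleft()
            print(f"evaluating {candidate}")  # stand-in for the real evaluation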
@@ -425,7 +427,6 @@ def determine_best_candidate(
                     )
                     continue
 
-
                 run_results = self.run_optimized_candidate(
                     optimization_candidate_index=candidate_index,
                     baseline_results=original_code_baseline,
@@ -464,12 +465,19 @@ def determine_best_candidate(
                     tree.add(f"Speedup ratio: {perf_gain + 1:.1f}X")
                     replay_perf_gain = {}
                     if self.args.benchmark:
-                        test_results_by_benchmark = candidate_result.benchmarking_test_results.group_by_benchmarks(self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root)
+                        test_results_by_benchmark = (
+                            candidate_result.benchmarking_test_results.group_by_benchmarks(
+                                self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root
+                            )
+                        )
                         if len(test_results_by_benchmark) > 0:
                             benchmark_tree = Tree("Speedup percentage on benchmarks:")
                             for benchmark_key, candidate_test_results in test_results_by_benchmark.items():
-
-                                original_code_replay_runtime = original_code_baseline.replay_benchmarking_test_results[benchmark_key].total_passed_runtime()
+                                original_code_replay_runtime = (
+                                    original_code_baseline.replay_benchmarking_test_results[
+                                        benchmark_key
+                                    ].total_passed_runtime()
+                                )
                                 candidate_replay_runtime = candidate_test_results.total_passed_runtime()
                                 replay_perf_gain[benchmark_key] = performance_gain(
                                     original_runtime_ns=original_code_replay_runtime,
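Note: the per-benchmark replay speedups reuse the same performance_gain helper as the overall ratio. Given the `Speedup ratio: {perf_gain + 1:.1f}X` log line above, the gain is presumably the relative improvement over the optimized runtime; a minimal sketch under that assumption (the real helper lives elsewhere in codeflash):

    def performance_gain_sketch(original_runtime_ns: int, optimized_runtime_ns: int) -> float:
        # 0.0 means no change; 1.0 means the optimized code is twice as fast.
        if optimized_runtime_ns == 0:
            return 0.0
        return (original_runtime_ns - optimized_runtime_ns) / optimized_runtime_ns

    # Example: 2_000_000 ns -> 800_000 ns gives a gain of 1.5, reported as 2.5X.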
@@ -958,13 +966,17 @@ def establish_original_code_baseline(
         logger.debug(f"Total original code runtime (ns): {total_timing}")
 
         if self.args.benchmark:
-            replay_benchmarking_test_results = benchmarking_results.group_by_benchmarks(self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root)
+            replay_benchmarking_test_results = benchmarking_results.group_by_benchmarks(
+                self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root
+            )
         return Success(
             (
                 OriginalCodeBaseline(
                     behavioral_test_results=behavioral_results,
                     benchmarking_test_results=benchmarking_results,
-                    replay_benchmarking_test_results = replay_benchmarking_test_results if self.args.benchmark else None,
+                    replay_benchmarking_test_results=replay_benchmarking_test_results
+                    if self.args.benchmark
+                    else None,
                     runtime=total_timing,
                     coverage_results=coverage_results,
                     line_profile_results=line_profile_results,
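Note: group_by_benchmarks receives the benchmark keys, the new replay_tests_dir, and the project root, so it presumably buckets each replay test's results under the benchmark that generated it. A rough sketch of that assumed behavior (not the actual implementation):

    from collections import defaultdict

    def group_by_benchmarks_sketch(results, benchmark_keys, replay_tests_dir, project_root):
        # Naive matching for illustration only: associate a result with a
        # benchmark when its test file under replay_tests_dir mentions the
        # benchmark's name; result.test_file is an assumed attribute.
        grouped = defaultdict(list)
        for result in results:
            for key in benchmark_keys:
                if str(key) in str(result.test_file):
                    grouped[key].append(result)
        return dict(grouped)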
@@ -1077,16 +1089,22 @@ def run_optimized_candidate(
 
         logger.debug(f"Total optimized code {optimization_candidate_index} runtime (ns): {total_candidate_timing}")
         if self.args.benchmark:
-            candidate_replay_benchmarking_results = candidate_benchmarking_results.group_by_benchmarks(self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root)
+            candidate_replay_benchmarking_results = candidate_benchmarking_results.group_by_benchmarks(
+                self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root
+            )
             for benchmark_name, benchmark_results in candidate_replay_benchmarking_results.items():
-                logger.debug(f"Benchmark {benchmark_name} runtime (ns): {humanize_runtime(benchmark_results.total_passed_runtime())}")
+                logger.debug(
+                    f"Benchmark {benchmark_name} runtime (ns): {humanize_runtime(benchmark_results.total_passed_runtime())}"
+                )
         return Success(
             OptimizedCandidateResult(
                 max_loop_count=loop_count,
                 best_test_runtime=total_candidate_timing,
                 behavior_test_results=candidate_behavior_results,
                 benchmarking_test_results=candidate_benchmarking_results,
-                replay_benchmarking_test_results = candidate_replay_benchmarking_results if self.args.benchmark else None,
+                replay_benchmarking_test_results=candidate_replay_benchmarking_results
+                if self.args.benchmark
+                else None,
                 optimization_candidate_index=optimization_candidate_index,
                 total_candidate_timing=total_candidate_timing,
             )
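Note: like the baseline hunk, this one fixes the keyword arguments from `kwarg = value` to PEP 8's `kwarg=value` and splits the conditional across lines. The humanize_runtime call in the debug log above formats nanosecond totals for display; a minimal stand-in with the same assumed contract (the real helper is codeflash's own):

    def humanize_runtime_sketch(ns: int) -> str:
        # Assumed contract: nanosecond count in, short human-readable string out.
        for unit, factor in (("s", 1_000_000_000), ("ms", 1_000_000), ("us", 1_000)):
            if ns >= factor:
                return f"{ns / factor:.2f}{unit}"
        return f"{ns}ns"

    # Example: humanize_runtime_sketch(1_234_000) == "1.23ms"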