Skip to content

Commit e3f0fde

Browse files
prefer runtime speed over diff lines count
1 parent 24fb636 commit e3f0fde

File tree

2 files changed

+69
-6
lines changed

2 files changed

+69
-6
lines changed

codeflash/code_utils/code_utils.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,63 @@ def create_rank_dictionary_compact(int_array: list[int]) -> dict[int, int]:
6464
return {original_index: rank for rank, original_index in enumerate(sorted_indices)}
6565

6666

67+
def choose_weights(**importance: float) -> list[float]:
68+
"""Choose normalized weights from relative importance values.
69+
70+
Example:
71+
choose_weights(runtime=3, diff=1)
72+
-> [0.75, 0.25]
73+
74+
Args:
75+
**importance: keyword args of metric=importance (relative numbers).
76+
77+
Returns:
78+
A list of weights in the same order as the arguments.
79+
80+
"""
81+
total = sum(importance.values())
82+
if total == 0:
83+
raise ValueError("At least one importance value must be > 0")
84+
85+
return [v / total for v in importance.values()]
86+
87+
88+
def normalize(values: list[float]) -> list[float]:
89+
mn, mx = min(values), max(values)
90+
if mx == mn:
91+
return [0.0] * len(values)
92+
return [(v - mn) / (mx - mn) for v in values]
93+
94+
95+
def create_score_dictionary_from_metrics(weights: list[float], *metrics: list[float]) -> dict[int, int]:
96+
"""Combine multiple metrics into a single weighted score dictionary.
97+
98+
Each metric is a list of values (smaller = better).
99+
The total score for each index is the weighted sum of its values
100+
across all metrics:
101+
102+
score[index] = Σ (value * weight)
103+
104+
Args:
105+
weights: A list of weights, one per metric. Larger weight = more influence.
106+
*metrics: Lists of values (one list per metric, aligned by index).
107+
108+
Returns:
109+
A dictionary mapping each index to its combined weighted score.
110+
111+
"""
112+
if len(weights) != len(metrics):
113+
raise ValueError("Number of weights must match number of metrics")
114+
115+
combined: dict[int, float] = {}
116+
117+
for weight, metric in zip(weights, metrics):
118+
for idx, value in enumerate(metric):
119+
combined[idx] = combined.get(idx, 0) + value * weight
120+
121+
return combined
122+
123+
67124
@contextmanager
68125
def custom_addopts() -> None:
69126
pyproject_file = find_pyproject_toml()

codeflash/optimization/function_optimizer.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,15 @@
3131
)
3232
from codeflash.code_utils.code_utils import (
3333
ImportErrorPattern,
34+
choose_weights,
3435
cleanup_paths,
35-
create_rank_dictionary_compact,
36+
create_score_dictionary_from_metrics,
3637
diff_length,
3738
file_name_from_test_module_name,
3839
get_run_tmp_file,
3940
has_any_async_functions,
4041
module_name_from_file_path,
42+
normalize,
4143
restore_conftest,
4244
)
4345
from codeflash.code_utils.config_consts import (
@@ -683,11 +685,15 @@ def determine_best_candidate(
683685
diff_length(new_best_opt.candidate.source_code.flat, code_context.read_writable_code.flat)
684686
) # char level diff
685687
runtimes_list.append(new_best_opt.runtime)
686-
diff_lens_ranking = create_rank_dictionary_compact(diff_lens_list)
687-
runtimes_ranking = create_rank_dictionary_compact(runtimes_list)
688-
# TODO: better way to resolve conflicts with same min ranking
689-
overall_ranking = {key: diff_lens_ranking[key] + runtimes_ranking[key] for key in diff_lens_ranking.keys()} # noqa: SIM118
690-
min_key = min(overall_ranking, key=overall_ranking.get)
688+
689+
# runtime is more important than diff by a factor of 3
690+
weights = choose_weights(runtime=3, diff=1)
691+
692+
runtime_norm = normalize(runtimes_list)
693+
diffs_norm = normalize(diff_lens_list)
694+
score_dict = create_score_dictionary_from_metrics(weights, runtime_norm, diffs_norm)
695+
696+
min_key = min(score_dict, key=score_dict.get)
691697
best_optimization = valid_candidates_with_shorter_code[min_key]
692698
# reassign code string which is the shortest
693699
ai_service_client.log_results(

0 commit comments

Comments
 (0)