Skip to content

Commit 8c36180

Browse files
committed
todo logging message, db logging
1 parent b205516 commit 8c36180

File tree

2 files changed

+47
-13
lines changed

2 files changed

+47
-13
lines changed

codeflash/code_utils/code_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,23 @@
2020
ImportErrorPattern = re.compile(r"ModuleNotFoundError.*$", re.MULTILINE)
2121

2222

23+
def unified_diff_strings(code1: str, code2: str, fromfile: str = "original", tofile: str = "modified") -> str:
    """Return the unified diff between two code strings as a single string.

    Every line of the returned diff (file headers, hunk headers and content
    lines) is terminated by a newline, so the result can be printed, logged
    or sent to another service as a well-formed unified diff.

    :param code1: First code string (original).
    :param code2: Second code string (modified).
    :param fromfile: Label for the first code string.
    :param tofile: Label for the second code string.
    :return: Unified diff as a string; empty string when the inputs are identical.
    """
    code1_lines = code1.splitlines(keepends=True)
    code2_lines = code2.splitlines(keepends=True)

    # Keep the default lineterm ("\n"): the inputs retain their line endings,
    # so the control lines ("---", "+++", "@@") must be newline-terminated too,
    # otherwise "".join() fuses them with whatever line follows.
    diff = difflib.unified_diff(code1_lines, code2_lines, fromfile=fromfile, tofile=tofile)

    # A final source line without a trailing newline yields a diff line without
    # one as well; terminate it so it does not run into the next diff line.
    return "".join(line if line.endswith("\n") else line + "\n" for line in diff)
38+
39+
2340
def diff_length(a: str, b: str) -> int:
2441
"""Compute the length (in characters) of the unified diff between two strings.
2542

codeflash/optimization/function_optimizer.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
has_any_async_functions,
4040
module_name_from_file_path,
4141
restore_conftest,
42+
unified_diff_strings,
4243
)
4344
from codeflash.code_utils.config_consts import (
4445
INDIVIDUAL_TESTCASE_TIMEOUT,
@@ -656,17 +657,12 @@ def determine_best_candidate(
656657
if not valid_optimizations:
657658
return None
658659
# need to figure out the best candidate here before we return best_optimization
659-
ranking = self.executor.submit(
660-
ai_service_client.generate_ranking,
661-
diffs=[],
662-
optimization_ids=[],
663-
speedups=[],
664-
trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id,
665-
)
666-
print(ranking)
667660
# reassign the shorter code here
668661
valid_candidates_with_shorter_code = []
669662
diff_lens_list = [] # character level diff
663+
speedups_list = []
664+
optimization_ids = []
665+
diff_strs = []
670666
runtimes_list = []
671667
for valid_opt in valid_optimizations:
672668
valid_opt_normalized_code = ast.unparse(ast.parse(valid_opt.candidate.source_code.flat.strip()))
@@ -690,12 +686,33 @@ def determine_best_candidate(
690686
diff_lens_list.append(
691687
diff_length(new_best_opt.candidate.source_code.flat, code_context.read_writable_code.flat)
692688
) # char level diff
689+
diff_strs.append(
690+
unified_diff_strings(code_context.read_writable_code.flat, new_best_opt.candidate.source_code.flat)
691+
)
692+
speedups_list.append(
693+
1
694+
+ performance_gain(
695+
original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=new_best_opt.runtime
696+
)
697+
)
698+
optimization_ids.append(new_best_opt.candidate.optimization_id)
693699
runtimes_list.append(new_best_opt.runtime)
694-
diff_lens_ranking = create_rank_dictionary_compact(diff_lens_list)
695-
runtimes_ranking = create_rank_dictionary_compact(runtimes_list)
696-
# TODO: better way to resolve conflicts with same min ranking
697-
overall_ranking = {key: diff_lens_ranking[key] + runtimes_ranking[key] for key in diff_lens_ranking.keys()} # noqa: SIM118
698-
min_key = min(overall_ranking, key=overall_ranking.get)
700+
ranking = self.executor.submit(
701+
ai_service_client.generate_ranking,
702+
diffs=diff_strs,
703+
optimization_ids=optimization_ids,
704+
speedups=speedups_list,
705+
trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id,
706+
)
707+
ranking = [x - 1 for x in ranking]
708+
if ranking:
709+
min_key = ranking[0]
710+
else:
711+
diff_lens_ranking = create_rank_dictionary_compact(diff_lens_list)
712+
runtimes_ranking = create_rank_dictionary_compact(runtimes_list)
713+
# TODO: better way to resolve conflicts with same min ranking
714+
overall_ranking = {key: diff_lens_ranking[key] + runtimes_ranking[key] for key in diff_lens_ranking.keys()} # noqa: SIM118
715+
min_key = min(overall_ranking, key=overall_ranking.get)
699716
best_optimization = valid_candidates_with_shorter_code[min_key]
700717
# reassign code string which is the shortest
701718
ai_service_client.log_results(

0 commit comments

Comments
 (0)