Skip to content

Commit e964ca6

Browse files
committed
works now, todo tiebreaking for same ranks
1 parent 42f0ada commit e964ca6

File tree

4 files changed

+75
-72
lines changed

4 files changed

+75
-72
lines changed

codeflash/api/aiservice.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def optimize_python_code_line_profiler( # noqa: D417
233233
console.rule()
234234
return []
235235

236-
def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]) -> dict[str, str]:
236+
def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]) -> list[OptimizedCandidate]:
237237
git_repo_owner, git_repo_name = safe_get_repo_owner_and_name()
238238
payload = [
239239
{
@@ -282,21 +282,28 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]
282282
except requests.exceptions.RequestException as e:
283283
logger.exception(f"Error generating optimization refinements: {e}")
284284
ph("cli-optimize-error-caught", {"error": str(e)})
285-
return {}
285+
return []
286286

287287
if response.status_code == 200:
288-
refined_optimizations = response.json()["result"]
288+
refined_optimizations = response.json()["refinements"]
289289
logger.info(f"Generated {len(refined_optimizations)} candidate refinements.")
290290
console.rule()
291-
return refined_optimizations
291+
return [
292+
OptimizedCandidate(
293+
source_code=opt["source_code"],
294+
explanation=opt["explanation"],
295+
optimization_id=opt["optimization_id"][:-4]+"refi",
296+
)
297+
for opt in refined_optimizations
298+
]
292299
try:
293300
error = response.json()["error"]
294301
except Exception:
295302
error = response.text
296303
logger.error(f"Error generating optimized candidates: {response.status_code} - {error}")
297304
ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
298305
console.rule()
299-
return {}
306+
return []
300307

301308
def log_results( # noqa: D417
302309
self,
@@ -306,6 +313,7 @@ def log_results( # noqa: D417
306313
optimized_runtime: dict[str, float | None] | None,
307314
is_correct: dict[str, bool] | None,
308315
best_optimization_id: str | None,
316+
optimized_line_profiler_results: dict[str, str] | None,
309317
) -> None:
310318
"""Log features to the database.
311319
@@ -317,6 +325,7 @@ def log_results( # noqa: D417
317325
- optimized_runtime (Optional[Dict[str, float]]): The optimized runtime.
318326
- is_correct (Optional[Dict[str, bool]]): Whether the optimized code is correct.
319327
- best_optimization_id (Optional[str]): The best optimization id.
328+
- optimized_line_profiler_results (Optional[Dict[str, str]]): The line_profiler results for every candidate, keyed by optimization_id.
320329
321330
"""
322331
payload = {
@@ -327,6 +336,7 @@ def log_results( # noqa: D417
327336
"is_correct": is_correct,
328337
"codeflash_version": codeflash_version,
329338
"best_optimization_id": best_optimization_id,
339+
"optimized_line_profiler_results": optimized_line_profiler_results
330340
}
331341
try:
332342
self.make_ai_service_request("/log_features", payload=payload, timeout=5)

codeflash/code_utils/code_utils.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import ast
4+
import difflib
45
import os
56
import re
67
import shutil
@@ -18,6 +19,48 @@
1819

1920
ImportErrorPattern = re.compile(r"ModuleNotFoundError.*$", re.MULTILINE)
2021

22+
def diff_length(a: str, b: str) -> int:
    """Compute the length (in characters) of the unified diff between two strings.

    Both inputs are split into lines with their line endings kept, and the
    diff is generated with difflib's default line terminator, so the total is
    the size of the diff text exactly as it would be written out, with no
    extra separators counted.

    Args:
        a (str): Original string.
        b (str): Modified string.

    Returns:
        int: Total number of characters in the diff; 0 when the inputs are identical.

    """
    # Split input strings into lines for a line-by-line diff, keeping line endings
    a_lines = a.splitlines(keepends=True)
    b_lines = b.splitlines(keepends=True)

    # Content lines already carry their own newline (keepends=True), so the
    # control lines ("---", "+++", "@@") must get one as well (default lineterm)
    # and nothing may be inserted between lines: joining newline-terminated
    # lines with "\n" would count every line break twice.
    return sum(len(line) for line in difflib.unified_diff(a_lines, b_lines))
44+
45+
46+
def create_rank_dictionary_compact(int_array: list[int]) -> dict[int, int]:
    """Map every index of ``int_array`` to the ascending rank of its value.

    Rank 0 goes to the smallest value. The sort is stable, so equal values
    receive consecutive ranks in order of their original index. Entries are
    inserted into the result in rank order.

    Args:
        int_array: A list of integers.

    Returns:
        A dictionary whose keys are the original indices and whose values are
        the positions those elements take once the list is sorted ascending.

    """
    # Pair each value with its index, then order the pairs by value only so
    # that ties keep their original relative order.
    indexed_values = list(enumerate(int_array))
    indexed_values.sort(key=lambda pair: pair[1])

    ranks: dict[int, int] = {}
    for position, (source_index, _value) in enumerate(indexed_values):
        ranks[source_index] = position
    return ranks
2164

2265
@contextmanager
2366
def custom_addopts() -> None:

codeflash/optimization/function_optimizer.py

Lines changed: 16 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
has_any_async_functions,
3838
module_name_from_file_path,
3939
restore_conftest,
40+
diff_length,
41+
create_rank_dictionary_compact,
4042
)
4143
from codeflash.code_utils.config_consts import (
4244
INDIVIDUAL_TESTCASE_TIMEOUT,
@@ -369,6 +371,7 @@ def determine_best_candidate(
369371
speedup_ratios: dict[str, float | None] = {}
370372
optimized_runtimes: dict[str, float | None] = {}
371373
is_correct = {}
374+
optimized_line_profiler_results: dict[str, str] = {}
372375

373376
logger.info(
374377
f"Determining best optimization candidate (out of {len(candidates)}) for "
@@ -464,7 +467,7 @@ def determine_best_candidate(
464467
candidate_result, original_code_baseline.runtime, best_runtime_until_now=None
465468
) and quantity_of_tests_critic(candidate_result):
466469
tree.add(
467-
"This candidate is faster than the previous best candidate. 🚀"
470+
"This candidate is faster than the original code. 🚀"
468471
) # TODO: Change this description
469472
tree.add(f"Original summed runtime: {humanize_runtime(original_code_baseline.runtime)}")
470473
tree.add(
@@ -479,6 +482,7 @@ def determine_best_candidate(
479482
original_helper_code=original_helper_code,
480483
candidate_index=candidate_index,
481484
)
485+
optimized_line_profiler_results[candidate.optimization_id]=line_profile_test_results['str_out']
482486
replay_perf_gain = {}
483487
if self.args.benchmark:
484488
test_results_by_benchmark = (
@@ -547,8 +551,8 @@ def determine_best_candidate(
547551
trace_id = self.function_trace_id
548552
if trace_id.endswith(("EXP0", "EXP1")):
549553
trace_id = trace_id[:-4] + exp_type
550-
# refinement_dict is a dictionary with optimization_id as a key and the refined code as a value
551-
refinement_dict = self.refine_optimizations(
554+
# refinement_response is a list of OptimizedCandidate objects (source_code, explanation, optimization_id)
555+
refinement_response = self.refine_optimizations(
552556
valid_optimizations=self.valid_optimizations,
553557
original_code_baseline=original_code_baseline,
554558
code_context=code_context,
@@ -562,23 +566,9 @@ def determine_best_candidate(
562566
executor=executor,
563567
fto_name=self.function_to_optimize.qualified_name,
564568
)
565-
566-
more_opt_candidates = [
567-
OptimizedCandidate(
568-
source_code=code,
569-
explanation=self.valid_optimizations[
570-
i
571-
].candidate.explanation, # TODO: handle the new explanation after the refinement
572-
optimization_id=opt_id,
573-
)
574-
for i, (opt_id, code) in enumerate(refinement_dict.items())
575-
# filter out empty strings of code
576-
if code != ""
577-
]
578-
# we no longer need to apply diffs since we are generating the entire code again
579-
candidates.extend(more_opt_candidates)
580-
print("added candidates from refinement")
581-
original_len += len(more_opt_candidates)
569+
candidates.extend(refinement_response)
570+
print("Added candidates from refinement")
571+
original_len += len(refinement_response)
582572
refinement_done = True
583573
except KeyboardInterrupt as e:
584574
self.write_code_and_helpers(
@@ -587,58 +577,17 @@ def determine_best_candidate(
587577
logger.exception(f"Optimization interrupted: {e}")
588578
raise
589579

590-
def diff_length(a: str, b: str) -> int:
591-
"""Compute the length (in characters) of the unified diff between two strings.
592-
593-
Args:
594-
a (str): Original string.
595-
b (str): Modified string.
596-
597-
Returns:
598-
int: Total number of characters in the diff.
599-
600-
"""
601-
# Split input strings into lines for line-by-line diff
602-
a_lines = a.splitlines(keepends=True)
603-
b_lines = b.splitlines(keepends=True)
604-
605-
# Compute unified diff
606-
diff_lines = list(difflib.unified_diff(a_lines, b_lines, lineterm=""))
607-
608-
# Join all lines with newline to calculate total diff length
609-
diff_text = "\n".join(diff_lines)
610-
611-
return len(diff_text)
612-
613-
def create_rank_dictionary_compact(int_array: list[int]) -> dict[int, int]:
614-
"""Create a dictionary from a list of ints, mapping the original index to its rank.
615-
616-
This version uses a more compact, "Pythonic" implementation.
617-
618-
Args:
619-
int_array: A list of integers.
620-
621-
Returns:
622-
A dictionary where keys are original indices and values are the
623-
rank of the element in ascending order.
624-
625-
"""
626-
# Sort the indices of the array based on their corresponding values
627-
sorted_indices = sorted(range(len(int_array)), key=lambda i: int_array[i])
628-
629-
# Create a dictionary mapping the original index to its rank (its position in the sorted list)
630-
return {original_index: rank for rank, original_index in enumerate(sorted_indices)}
631-
632580
if not len(self.valid_optimizations):
633581
return None
634582
# need to figure out the best candidate here before we return best_optimization
635-
diff_lens_list = []
583+
diff_lens_list = [] # character level diff
636584
runtimes_list = []
637585
for valid_opt in self.valid_optimizations:
638-
diff_lens_list.append(diff_length(valid_opt.candidate.source_code, code_context.read_writable_code))
586+
diff_lens_list.append(diff_length(valid_opt.candidate.source_code, code_context.read_writable_code)) #char level diff
639587
runtimes_list.append(valid_opt.runtime)
640588
diff_lens_ranking = create_rank_dictionary_compact(diff_lens_list)
641589
runtimes_ranking = create_rank_dictionary_compact(runtimes_list)
590+
# TODO: better way to resolve conflicts with same min ranking
642591
overall_ranking = {key: diff_lens_ranking[key] + runtimes_ranking[key] for key in diff_lens_ranking.keys()} # noqa: SIM118
643592
min_key = min(overall_ranking, key=overall_ranking.get)
644593
best_optimization = self.valid_optimizations[min_key]
@@ -649,6 +598,7 @@ def create_rank_dictionary_compact(int_array: list[int]) -> dict[int, int]:
649598
optimized_runtime=optimized_runtimes,
650599
is_correct=is_correct,
651600
best_optimization_id=best_optimization.candidate.optimization_id,
601+
optimized_line_profiler_results= optimized_line_profiler_results
652602
)
653603
return best_optimization
654604

@@ -662,7 +612,7 @@ def refine_optimizations(
662612
ai_service_client: AiServiceClient,
663613
executor: concurrent.futures.ThreadPoolExecutor,
664614
fto_name: str,
665-
) -> dict[str, str]:
615+
) -> list[OptimizedCandidate]:
666616
request = [
667617
AIServiceRefinerRequest(
668618
optimization_id=opt.candidate.optimization_id,
@@ -680,7 +630,7 @@ def refine_optimizations(
680630
fto_name=fto_name,
681631
)
682632
for opt in valid_optimizations
683-
]
633+
] # TODO: multiple workers for this?
684634
future_refinement_results = executor.submit(ai_service_client.optimize_python_code_refinement, request=request)
685635
concurrent.futures.wait([future_refinement_results])
686636
return future_refinement_results.result()

codeflash/result/critic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def speedup_critic(
4848
original_runtime_ns=original_code_runtime, optimized_runtime_ns=candidate_result.best_test_runtime
4949
)
5050
if best_runtime_until_now is None:
51-
# collect all optimizations with thi
51+
# no best runtime to beat yet: accept every candidate whose gain exceeds the noise floor
5252
return bool(perf_gain > noise_floor)
5353
return bool(perf_gain > noise_floor and candidate_result.best_test_runtime < best_runtime_until_now)
5454

0 commit comments

Comments
 (0)