Skip to content

Commit b93fd34

Browse files
enhancements and cleanups
1 parent 6a9390c commit b93fd34

File tree

3 files changed

+72
-128
lines changed

3 files changed

+72
-128
lines changed

codeflash/api/aiservice.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -308,10 +308,10 @@ def optimize_python_code_repair(self, request: AIServiceCodeRepairRequest) -> Op
308308
console.rule()
309309
try:
310310
response = self.make_ai_service_request("/code_repair", payload=request, timeout=120)
311-
except requests.exceptions.RequestException as e:
311+
except (requests.exceptions.RequestException, TypeError) as e:
312312
logger.exception(f"Error generating optimization repair: {e}")
313313
ph("cli-optimize-error-caught", {"error": str(e)})
314-
return []
314+
return None
315315

316316
if response.status_code == 200:
317317
refined_optimization = response.json()

codeflash/models/models.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,6 @@ class CodeContextType(str, Enum):
302302

303303

304304
class OptimizedCandidateResult(BaseModel):
305-
optimized_candidate: OptimizedCandidate
306305
max_loop_count: int
307306
best_test_runtime: int
308307
behavior_test_results: TestResults

codeflash/optimization/function_optimizer.py

Lines changed: 70 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import os
66
import queue
77
import random
8-
import sqlite3
98
import subprocess
109
import time
1110
import uuid
@@ -14,7 +13,6 @@
1413
from typing import TYPE_CHECKING
1514

1615
import libcst as cst
17-
import sentry_sdk
1816
from rich.console import Group
1917
from rich.panel import Panel
2018
from rich.syntax import Syntax
@@ -121,35 +119,6 @@
121119
from codeflash.verification.verification_utils import TestConfig
122120

123121

124-
def log_code_repair_to_db(
    code_repair_log_db: Path, optimization_id: str, trace_id: str, passed: str, faster: str
) -> None:
    """Record the outcome of one code-repair candidate in a SQLite database.

    Creates the ``code_repair_logs_cf`` table if it does not exist yet and
    inserts a single row keyed by ``optimization_id``. Logging is best-effort:
    any exception is reported to Sentry and to the module logger and is then
    swallowed, so a logging failure never propagates to the caller.

    Args:
        code_repair_log_db: Path of the SQLite database file to write to.
        optimization_id: Candidate identifier; primary key of the table.
        trace_id: Trace identifier stored alongside the row.
        passed: Stored verbatim in the ``passed`` column.
        faster: Stored verbatim in the ``faster`` column.
    """
    # Function-scope import so the block does not depend on the file's import section.
    from contextlib import closing

    try:
        # A sqlite3 connection used as a context manager only commits or rolls
        # back the transaction; it does NOT close the connection. closing()
        # guarantees the handle is released, while the inner `conn` context
        # still commits on success and rolls back on error.
        with closing(sqlite3.connect(code_repair_log_db)) as conn, conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS code_repair_logs_cf (
                    optimization_id TEXT PRIMARY KEY,
                    trace_id TEXT,
                    passed TEXT,
                    faster TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)
            conn.execute(
                """
                INSERT INTO code_repair_logs_cf (optimization_id, trace_id, passed, faster)
                VALUES (?, ?, ?, ?)
                """,
                (optimization_id, trace_id, passed, faster),
            )
    except Exception as e:
        # Deliberate best-effort: report and continue (a duplicate
        # optimization_id, for example, violates the primary key and lands here).
        sentry_sdk.capture_exception(e)
        logger.exception("Error logging code repair to db")
151-
152-
153122
class CandidateProcessor:
154123
"""Handles candidate processing using a queue-based approach."""
155124

@@ -281,8 +250,6 @@ def __init__(
281250
)
282251
self.optimization_review = ""
283252
self.ast_code_to_id = {}
284-
# SQLite database setup for logging
285-
self.code_repair_log_db = Path(__file__).parent / "code_repair_logs_cf.db"
286253

287254
def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[Path, str]], str]:
288255
should_run_experiment = self.experiment_id is not None
@@ -494,6 +461,41 @@ def optimize_function(self) -> Result[BestOptimization, str]:
494461
return Failure(f"No best optimizations found for function {self.function_to_optimize.qualified_name}")
495462
return Success(best_optimization)
496463

464+
def was_candidate_tested_before(self, normalized_code: str) -> bool:
    """Tell whether a candidate with this AST-normalized code was already registered.

    Args:
        normalized_code: The normalized code string used as the lookup key.

    Returns:
        True when ``normalized_code`` is a key of ``self.ast_code_to_id``,
        False otherwise.
    """
    seen_candidates = self.ast_code_to_id
    return normalized_code in seen_candidates
467+
468+
def update_results_for_duplicate_candidate(
    self,
    candidate: OptimizedCandidate,
    code_context: CodeOptimizationContext,
    normalized_code: str,
    speedup_ratios: dict,
    is_correct: dict,
    optimized_runtimes: dict,
    optimized_line_profiler_results: dict,
    optimizations_post: dict,
) -> None:
    """Copy the recorded results of an earlier, equivalent candidate onto ``candidate``.

    Looks up the entry stored in ``self.ast_code_to_id`` for ``normalized_code``
    and mirrors that earlier candidate's results into the per-candidate result
    dicts under ``candidate.optimization_id``. If the new candidate has a
    shorter diff against the original code, it replaces the stored
    "shorter" source for this normalized code.

    Args:
        candidate: The duplicate candidate currently being processed.
        code_context: Provides ``read_writable_code.flat``, the original code
            the diff length is measured against.
        normalized_code: Key into ``self.ast_code_to_id``; must already be present.
        speedup_ratios: Mutated in place; maps optimization id to speedup ratio.
        is_correct: Mutated in place; maps optimization id to correctness flag.
        optimized_runtimes: Mutated in place; maps optimization id to runtime.
        optimized_line_profiler_results: Mutated in place; only holds entries
            for successful runs.
        optimizations_post: Mutated in place; maps optimization id to the
            markdown of the shortest known source for this normalized code.
    """
    logger.info("Current candidate has been encountered before in testing, Skipping optimization candidate.")
    seen_entry = self.ast_code_to_id[normalized_code]
    past_opt_id = seen_entry["optimization_id"]
    current_id = candidate.optimization_id

    # The id registered in ast_code_to_id is not guaranteed to be a key of the
    # result dicts: the caller registers the id before running the candidate and
    # may then store the results under a different id when the run returns a
    # replaced candidate. Fall back to the values the caller records for an
    # unsuccessful candidate (None / False / None) instead of raising KeyError.
    speedup_ratios[current_id] = speedup_ratios.get(past_opt_id)
    is_correct[current_id] = is_correct.get(past_opt_id, False)
    optimized_runtimes[current_id] = optimized_runtimes.get(past_opt_id)

    # line profiler results only available for successful runs
    if past_opt_id in optimized_line_profiler_results:
        optimized_line_profiler_results[current_id] = optimized_line_profiler_results[past_opt_id]

    # NOTE(review): the markdown is taken before the shorter-diff check below, so a
    # shorter new candidate is only reflected on a later duplicate — confirm intended.
    shorter_markdown = seen_entry["shorter_source_code"].markdown
    optimizations_post[current_id] = shorter_markdown
    optimizations_post[past_opt_id] = shorter_markdown

    new_diff_len = diff_length(candidate.source_code.flat, code_context.read_writable_code.flat)
    if new_diff_len < seen_entry["diff_len"]:
        # new candidate has a shorter diff than the previously encountered one
        seen_entry["shorter_source_code"] = candidate.source_code
        seen_entry["diff_len"] = new_diff_len
498+
497499
def determine_best_candidate(
498500
self,
499501
*,
@@ -573,78 +575,34 @@ def determine_best_candidate(
573575
logger.warning(
574576
"force_lsp|No functions were replaced in the optimized code. Skipping optimization candidate."
575577
)
576-
if candidate.optimization_id.endswith("cdrp"):
577-
log_code_repair_to_db(
578-
code_repair_log_db=self.code_repair_log_db,
579-
trace_id=self.function_trace_id[:-4] + exp_type,
580-
optimization_id=candidate.optimization_id,
581-
passed="no",
582-
faster="no",
583-
)
584578
console.rule()
585579
continue
586580
except (ValueError, SyntaxError, cst.ParserSyntaxError, AttributeError) as e:
587581
logger.error(e)
588582
self.write_code_and_helpers(
589583
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
590584
)
591-
if candidate.optimization_id.endswith("cdrp"):
592-
log_code_repair_to_db(
593-
code_repair_log_db=self.code_repair_log_db,
594-
trace_id=self.function_trace_id[:-4] + exp_type,
595-
optimization_id=candidate.optimization_id,
596-
passed="no",
597-
faster="no",
598-
)
599585
continue
600586
# check if this code has been evaluated before by checking the ast normalized code string
601587
normalized_code = normalize_code(candidate.source_code.flat.strip())
602-
if normalized_code in self.ast_code_to_id:
603-
logger.info(
604-
"Current candidate has been encountered before in testing, Skipping optimization candidate."
588+
if self.was_candidate_tested_before(normalized_code):
589+
self.update_results_for_duplicate_candidate(
590+
candidate=candidate,
591+
code_context=code_context,
592+
normalized_code=normalized_code,
593+
speedup_ratios=speedup_ratios,
594+
is_correct=is_correct,
595+
optimized_runtimes=optimized_runtimes,
596+
optimized_line_profiler_results=optimized_line_profiler_results,
597+
optimizations_post=optimizations_post,
605598
)
606-
past_opt_id = self.ast_code_to_id[normalized_code]["optimization_id"]
607-
# update speedup ratio, is_correct, optimizations_post, optimized_line_profiler_results, optimized_runtimes
608-
speedup_ratios[candidate.optimization_id] = speedup_ratios[past_opt_id]
609-
is_correct[candidate.optimization_id] = is_correct[past_opt_id]
610-
optimized_runtimes[candidate.optimization_id] = optimized_runtimes[past_opt_id]
611-
# line profiler results only available for successful runs
612-
if past_opt_id in optimized_line_profiler_results:
613-
optimized_line_profiler_results[candidate.optimization_id] = optimized_line_profiler_results[
614-
past_opt_id
615-
]
616-
optimizations_post[candidate.optimization_id] = self.ast_code_to_id[normalized_code][
617-
"shorter_source_code"
618-
].markdown
619-
optimizations_post[past_opt_id] = self.ast_code_to_id[normalized_code][
620-
"shorter_source_code"
621-
].markdown
622-
new_diff_len = diff_length(candidate.source_code.flat, code_context.read_writable_code.flat)
623-
if (
624-
new_diff_len < self.ast_code_to_id[normalized_code]["diff_len"]
625-
): # new candidate has a shorter diff than the previously encountered one
626-
self.ast_code_to_id[normalized_code]["shorter_source_code"] = candidate.source_code
627-
self.ast_code_to_id[normalized_code]["diff_len"] = new_diff_len
628-
if candidate.optimization_id.endswith("cdrp"):
629-
log_code_repair_to_db(
630-
code_repair_log_db=self.code_repair_log_db,
631-
trace_id=self.function_trace_id[:-4] + exp_type,
632-
optimization_id=candidate.optimization_id,
633-
passed="yes" if is_correct[candidate.optimization_id] else "no",
634-
faster="yes"
635-
if (
636-
speedup_ratios[candidate.optimization_id] is not None
637-
and speedup_ratios[candidate.optimization_id] > 0
638-
)
639-
else "no",
640-
)
641599
continue
642600
self.ast_code_to_id[normalized_code] = {
643601
"optimization_id": candidate.optimization_id,
644602
"shorter_source_code": candidate.source_code,
645603
"diff_len": diff_length(candidate.source_code.flat, code_context.read_writable_code.flat),
646604
}
647-
run_results = self.run_optimized_candidate(
605+
run_results, new_candidate = self.run_optimized_candidate(
648606
optimization_candidate_index=candidate_index,
649607
baseline_results=original_code_baseline,
650608
original_helper_code=original_helper_code,
@@ -653,16 +611,17 @@ def determine_best_candidate(
653611
candidate=candidate,
654612
exp_type=exp_type,
655613
)
614+
if candidate.optimization_id != new_candidate.optimization_id:
615+
# override the candidate if the optimization_id has changed, this may happen if the candidate was modified by the code-repair
616+
candidate = new_candidate
617+
656618
console.rule()
657619
if not is_successful(run_results):
658620
optimized_runtimes[candidate.optimization_id] = None
659621
is_correct[candidate.optimization_id] = False
660622
speedup_ratios[candidate.optimization_id] = None
661623
else:
662624
candidate_result: OptimizedCandidateResult = run_results.unwrap()
663-
# override the candidate if the optimization_id has changed, this may happen if the candidate was modified by the code-repair
664-
if candidate.optimization_id != candidate_result.optimized_candidate.optimization_id:
665-
candidate = candidate_result.optimized_candidate
666625
best_test_runtime = candidate_result.best_test_runtime
667626
optimized_runtimes[candidate.optimization_id] = best_test_runtime
668627
is_correct[candidate.optimization_id] = True
@@ -745,20 +704,20 @@ def determine_best_candidate(
745704
)
746705
valid_optimizations.append(best_optimization)
747706
# # queue corresponding refined optimization for best optimization
748-
# if not candidate.optimization_id.endswith("refi"):
749-
# future_all_refinements.append(
750-
# self.refine_optimizations(
751-
# valid_optimizations=[best_optimization],
752-
# original_code_baseline=original_code_baseline,
753-
# code_context=code_context,
754-
# trace_id=self.function_trace_id[:-4] + exp_type
755-
# if self.experiment_id
756-
# else self.function_trace_id,
757-
# ai_service_client=ai_service_client,
758-
# executor=self.executor,
759-
# function_references=function_references,
760-
# )
761-
# )
707+
if not candidate.optimization_id.endswith("refi"):
708+
future_all_refinements.append(
709+
self.refine_optimizations(
710+
valid_optimizations=[best_optimization],
711+
original_code_baseline=original_code_baseline,
712+
code_context=code_context,
713+
trace_id=self.function_trace_id[:-4] + exp_type
714+
if self.experiment_id
715+
else self.function_trace_id,
716+
ai_service_client=ai_service_client,
717+
executor=self.executor,
718+
function_references=function_references,
719+
)
720+
)
762721
else:
763722
# For async functions, prioritize throughput metrics over runtime even for slow candidates
764723
is_async = (
@@ -793,19 +752,6 @@ def determine_best_candidate(
793752
if self.args.benchmark and benchmark_tree:
794753
console.print(benchmark_tree)
795754
console.rule()
796-
if candidate.optimization_id.endswith("cdrp"):
797-
log_code_repair_to_db(
798-
code_repair_log_db=self.code_repair_log_db,
799-
trace_id=self.function_trace_id[:-4] + exp_type,
800-
optimization_id=candidate.optimization_id,
801-
passed="yes" if is_correct[candidate.optimization_id] else "no",
802-
faster="yes"
803-
if (
804-
speedup_ratios[candidate.optimization_id] is not None
805-
and speedup_ratios[candidate.optimization_id] > 0
806-
)
807-
else "no",
808-
)
809755
except KeyboardInterrupt as e:
810756
logger.exception(f"Optimization interrupted: {e}")
811757
raise
@@ -1870,7 +1816,7 @@ def run_optimized_candidate(
18701816
code_context: CodeOptimizationContext,
18711817
candidate: OptimizedCandidate,
18721818
exp_type: str,
1873-
) -> Result[OptimizedCandidateResult, str]:
1819+
) -> tuple[Result[OptimizedCandidateResult, str], OptimizedCandidate]:
18741820
assert (test_framework := self.args.test_framework) in {"pytest", "unittest"} # noqa: RUF018
18751821

18761822
with progress_bar("Testing optimization candidate"):
@@ -1928,11 +1874,11 @@ def run_optimized_candidate(
19281874
result_unmatched_perc = len(diffs) / len(candidate_behavior_results)
19291875
if result_unmatched_perc > 0.5:
19301876
# if the test unmatched percentage is greater than 50%, we can't fix it
1931-
return self.get_results_not_matched_error()
1877+
return self.get_results_not_matched_error(), candidate
19321878

19331879
if candidate.optimization_id.endswith("cdrp"):
19341880
# prevent looping for now
1935-
return self.get_results_not_matched_error()
1881+
return self.get_results_not_matched_error(), candidate
19361882

19371883
ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client
19381884

@@ -1948,7 +1894,7 @@ def run_optimized_candidate(
19481894
optimization_id=candidate.optimization_id,
19491895
)
19501896
if not new_candidate:
1951-
return Failure("Code repair failed to generate a valid candidate.")
1897+
return Failure("Code repair failed to generate a valid candidate."), candidate
19521898

19531899
code_print(new_candidate.source_code.flat)
19541900
normalized_code = normalize_code(new_candidate.source_code.flat.strip())
@@ -1983,7 +1929,7 @@ def run_optimized_candidate(
19831929
self.write_code_and_helpers(
19841930
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
19851931
)
1986-
return Failure("Code repair failed to generate a valid candidate.")
1932+
return Failure("Code repair failed to generate a valid candidate."), candidate
19871933

19881934
logger.info(f"loading|Running performance tests for candidate {optimization_candidate_index}...")
19891935

@@ -2064,7 +2010,6 @@ def run_optimized_candidate(
20642010
)
20652011
return Success(
20662012
OptimizedCandidateResult(
2067-
optimized_candidate=candidate,
20682013
max_loop_count=loop_count,
20692014
best_test_runtime=total_candidate_timing,
20702015
behavior_test_results=candidate_behavior_results,
@@ -2076,7 +2021,7 @@ def run_optimized_candidate(
20762021
total_candidate_timing=total_candidate_timing,
20772022
async_throughput=candidate_async_throughput,
20782023
)
2079-
)
2024+
), candidate
20802025

20812026
def run_and_parse_tests(
20822027
self,

0 commit comments

Comments
 (0)