Skip to content

Commit 1ddc87c

Browse files
Merge branch 'feat/feedback-loop-for-unmatched-test-results' of github.com:codeflash-ai/codeflash into feat/feedback-loop-for-unmatched-test-results
2 parents 9f7ed90 + 0325444 commit 1ddc87c

File tree

2 files changed

+124
-11
lines changed

2 files changed

+124
-11
lines changed

codeflash/api/aiservice.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -310,6 +310,7 @@ def optimize_python_code_repair(self, request: list[AIServiceCodeRepairRequest])
310310
"optimization_id": opt.optimization_id,
311311
"original_source_code": opt.original_source_code,
312312
"modified_source_code": opt.modified_source_code,
313+
"test_details": opt.test_details,
313314
"trace_id": opt.trace_id,
314315
}
315316
for opt in request
@@ -325,7 +326,7 @@ def optimize_python_code_repair(self, request: list[AIServiceCodeRepairRequest])
325326

326327
if response.status_code == 200:
327328
refined_optimizations = response.json()["code_repairs"]
328-
logger.debug(f"Generated {len(refined_optimizations)} candidate refinements.")
329+
# logger.debug(f"Generated {len(refined_optimizations)} candidate refinements.")
329330
console.rule()
330331

331332
refinements = self._get_valid_candidates(refined_optimizations)

codeflash/optimization/function_optimizer.py

Lines changed: 122 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
import os
66
import queue
77
import random
8+
import sqlite3
89
import subprocess
910
import time
1011
import uuid
@@ -119,6 +120,61 @@
119120
from codeflash.verification.verification_utils import TestConfig
120121

121122

123+
def log_code_repair_to_db(
124+
code_repair_log_db: Path,
125+
optimization_id: str,
126+
trace_id: str | None = None,
127+
passed: str | None = None,
128+
faster: str | None = None,
129+
) -> None:
130+
"""Log code repair data to SQLite database.
131+
132+
Uses upsert pattern to allow incremental logging with different columns at different places.
133+
Only non-None values will be updated; existing values are preserved.
134+
"""
135+
try:
136+
conn = sqlite3.connect(code_repair_log_db)
137+
cursor = conn.cursor()
138+
139+
# Build dynamic upsert query based on provided columns
140+
columns = ["optimization_id"]
141+
values = [optimization_id]
142+
update_parts = ["updated_at = CURRENT_TIMESTAMP"]
143+
144+
if trace_id is not None:
145+
columns.append("trace_id")
146+
values.append(trace_id)
147+
update_parts.append("trace_id = excluded.trace_id")
148+
149+
if passed is not None:
150+
columns.append("passed")
151+
values.append(passed)
152+
update_parts.append("passed = excluded.passed")
153+
154+
if faster is not None:
155+
columns.append("faster")
156+
values.append(faster)
157+
update_parts.append("faster = excluded.faster")
158+
159+
placeholders = ", ".join(["?"] * len(values))
160+
columns_str = ", ".join(columns)
161+
update_str = ", ".join(update_parts)
162+
163+
cursor.execute(
164+
f"""
165+
INSERT INTO code_repair_logs_cf ({columns_str})
166+
VALUES ({placeholders})
167+
ON CONFLICT(optimization_id) DO UPDATE SET {update_str}
168+
""", # noqa: S608
169+
values,
170+
)
171+
conn.commit()
172+
conn.close()
173+
except Exception as e:
174+
sentry_sdk.capture_exception(e)
175+
logger.exception(e)
176+
177+
122178
class CandidateProcessor:
123179
"""Handles candidate processing using a queue-based approach."""
124180

@@ -249,6 +305,8 @@ def __init__(
249305
max_workers=n_tests + 3 if self.experiment_id is None else n_tests + 4
250306
)
251307
self.optimization_review = ""
308+
# SQLite database setup for logging
309+
self.code_repair_log_db = Path(__file__).parent / "code_repair_logs_cf.db"
252310

253311
def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[Path, str]], str]:
254312
should_run_experiment = self.experiment_id is not None
@@ -389,7 +447,20 @@ def optimize_function(self) -> Result[BestOptimization, str]:
389447
initialization_result = self.can_be_optimized()
390448
if not is_successful(initialization_result):
391449
return Failure(initialization_result.failure())
392-
450+
conn = sqlite3.connect(self.code_repair_log_db)
451+
cursor = conn.cursor()
452+
cursor.execute("""
453+
CREATE TABLE IF NOT EXISTS code_repair_logs_cf (
454+
optimization_id TEXT PRIMARY KEY,
455+
trace_id TEXT,
456+
passed TEXT,
457+
faster TEXT,
458+
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
459+
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
460+
)
461+
""")
462+
conn.commit()
463+
conn.close()
393464
should_run_experiment, code_context, original_helper_code = initialization_result.unwrap()
394465

395466
code_print(
@@ -540,13 +611,29 @@ def determine_best_candidate(
540611
logger.warning(
541612
"force_lsp|No functions were replaced in the optimized code. Skipping optimization candidate."
542613
)
614+
if candidate.optimization_id.endswith("cdrp"):
615+
log_code_repair_to_db(
616+
code_repair_log_db=self.code_repair_log_db,
617+
trace_id=self.function_trace_id[:-4] + exp_type,
618+
optimization_id=candidate.optimization_id,
619+
passed="no",
620+
faster="no",
621+
)
543622
console.rule()
544623
continue
545624
except (ValueError, SyntaxError, cst.ParserSyntaxError, AttributeError) as e:
546625
logger.error(e)
547626
self.write_code_and_helpers(
548627
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
549628
)
629+
if candidate.optimization_id.endswith("cdrp"):
630+
log_code_repair_to_db(
631+
code_repair_log_db=self.code_repair_log_db,
632+
trace_id=self.function_trace_id[:-4] + exp_type,
633+
optimization_id=candidate.optimization_id,
634+
passed="no",
635+
faster="no",
636+
)
550637
continue
551638
# check if this code has been evaluated before by checking the ast normalized code string
552639
normalized_code = normalize_code(candidate.source_code.flat.strip())
@@ -574,6 +661,19 @@ def determine_best_candidate(
574661
): # new candidate has a shorter diff than the previously encountered one
575662
ast_code_to_id[normalized_code]["shorter_source_code"] = candidate.source_code
576663
ast_code_to_id[normalized_code]["diff_len"] = new_diff_len
664+
if candidate.optimization_id.endswith("cdrp"):
665+
log_code_repair_to_db(
666+
code_repair_log_db=self.code_repair_log_db,
667+
trace_id=self.function_trace_id[:-4] + exp_type,
668+
optimization_id=candidate.optimization_id,
669+
passed="yes" if is_correct[candidate.optimization_id] else "no",
670+
faster="yes"
671+
if (
672+
speedup_ratios[candidate.optimization_id] is not None
673+
and speedup_ratios[candidate.optimization_id] > 0
674+
)
675+
else "no",
676+
)
577677
continue
578678
ast_code_to_id[normalized_code] = {
579679
"optimization_id": candidate.optimization_id,
@@ -593,24 +693,22 @@ def determine_best_candidate(
593693
speedup_ratios[candidate.optimization_id] = None
594694
fail_value = run_results.value
595695
if (
596-
fail_value != "Test results did not match the test results of the original code."
696+
fail_value.strip() != "Test results did not match the test results of the original code."
597697
and len(future_all_refinements) <= 3
598698
and not candidate.optimization_id.endswith("cdrp")
599699
):
600700
# # queue corresponding code repair optimization for best optimization
601701
future_all_refinements.append(
602702
self.code_repair_optimizations(
603-
original_source_code=candidate,
604-
modified_source_code=code_context,
605-
original_code_baseline=original_code_baseline,
606-
test_details="test_details",
607-
code_context=code_context,
703+
original_source_code=code_context.read_writable_code.markdown,
704+
modified_source_code=candidate.source_code.markdown,
705+
test_details=fail_value,
608706
trace_id=self.function_trace_id[:-4] + exp_type
609707
if self.experiment_id
610708
else self.function_trace_id,
611709
ai_service_client=ai_service_client,
612710
executor=self.executor,
613-
function_references=function_references,
711+
optimization_id=candidate.optimization_id,
614712
)
615713
)
616714
else:
@@ -745,6 +843,19 @@ def determine_best_candidate(
745843
if self.args.benchmark and benchmark_tree:
746844
console.print(benchmark_tree)
747845
console.rule()
846+
if candidate.optimization_id.endswith("cdrp"):
847+
log_code_repair_to_db(
848+
code_repair_log_db=self.code_repair_log_db,
849+
trace_id=self.function_trace_id[:-4] + exp_type,
850+
optimization_id=candidate.optimization_id,
851+
passed="yes" if is_correct[candidate.optimization_id] else "no",
852+
faster="yes"
853+
if (
854+
speedup_ratios[candidate.optimization_id] is not None
855+
and speedup_ratios[candidate.optimization_id] > 0
856+
)
857+
else "no",
858+
)
748859
except KeyboardInterrupt as e:
749860
logger.exception(f"Optimization interrupted: {e}")
750861
raise
@@ -869,12 +980,13 @@ def code_repair_optimizations(
869980
modified_source_code: str,
870981
test_details: str,
871982
trace_id: str,
983+
optimization_id: str,
872984
ai_service_client: AiServiceClient,
873985
executor: concurrent.futures.ThreadPoolExecutor,
874986
) -> concurrent.futures.Future:
875987
request = [
876988
AIServiceCodeRepairRequest(
877-
optimization_id="",
989+
optimization_id=optimization_id,
878990
original_source_code=original_source_code,
879991
modified_source_code=modified_source_code,
880992
test_details=test_details,
@@ -1875,7 +1987,7 @@ def run_optimized_candidate(
18751987
try:
18761988
diff_per_test_fn[diff.test_src_code] = (
18771989
diff_per_test_fn.setdefault(diff.test_src_code, "")
1878-
+ f"Expected Value: {diff.original_value!s}\nActual Value: {diff.candidate_value!s}\nError String:{diff.pytest_error}\n"
1990+
+ f"Expected Value: {diff.original_value!s}\nActual Value: {diff.candidate_value!s}\nError String:{diff.candidate_pytest_error}\n"
18791991
)
18801992

18811993
except Exception as e:

0 commit comments

Comments
 (0)