1313from typing import TYPE_CHECKING
1414
1515import libcst as cst
16+ import sentry_sdk
1617from rich .console import Group
1718from rich .panel import Panel
1819from rich .syntax import Syntax
6970from codeflash .lsp .lsp_message import LspCodeMessage , LspMarkdownMessage , LSPMessageId
7071from codeflash .models .ExperimentMetadata import ExperimentMetadata
7172from codeflash .models .models import (
73+ AIServiceCodeRepairRequest ,
7274 BestOptimization ,
7375 CodeOptimizationContext ,
7476 GeneratedTests ,
@@ -589,6 +591,28 @@ def determine_best_candidate(
589591 optimized_runtimes [candidate .optimization_id ] = None
590592 is_correct [candidate .optimization_id ] = False
591593 speedup_ratios [candidate .optimization_id ] = None
594+ fail_value = run_results .value
595+ if (
596+ fail_value != "Test results did not match the test results of the original code."
597+ and len (future_all_refinements ) <= 3
598+ and not candidate .optimization_id .endswith ("cdrp" )
599+ ):
600+ # # queue corresponding code repair optimization for best optimization
601+ future_all_refinements .append (
602+ self .code_repair_optimizations (
603+ original_source_code = candidate ,
604+ modified_source_code = code_context ,
605+ original_code_baseline = original_code_baseline ,
606+ test_details = "test_details" ,
607+ code_context = code_context ,
608+ trace_id = self .function_trace_id [:- 4 ] + exp_type
609+ if self .experiment_id
610+ else self .function_trace_id ,
611+ ai_service_client = ai_service_client ,
612+ executor = self .executor ,
613+ function_references = function_references ,
614+ )
615+ )
592616 else :
593617 candidate_result : OptimizedCandidateResult = run_results .unwrap ()
594618 best_test_runtime = candidate_result .best_test_runtime
@@ -672,21 +696,21 @@ def determine_best_candidate(
672696 async_throughput = candidate_result .async_throughput ,
673697 )
674698 valid_optimizations .append (best_optimization )
675- # queue corresponding refined optimization for best optimization
676- if not candidate .optimization_id .endswith ("refi" ):
677- future_all_refinements .append (
678- self .refine_optimizations (
679- valid_optimizations = [best_optimization ],
680- original_code_baseline = original_code_baseline ,
681- code_context = code_context ,
682- trace_id = self .function_trace_id [:- 4 ] + exp_type
683- if self .experiment_id
684- else self .function_trace_id ,
685- ai_service_client = ai_service_client ,
686- executor = self .executor ,
687- function_references = function_references ,
688- )
689- )
699+ # # queue corresponding refined optimization for best optimization
700+ # if not candidate.optimization_id.endswith("refi"):
701+ # future_all_refinements.append(
702+ # self.refine_optimizations(
703+ # valid_optimizations=[best_optimization],
704+ # original_code_baseline=original_code_baseline,
705+ # code_context=code_context,
706+ # trace_id=self.function_trace_id[:-4] + exp_type
707+ # if self.experiment_id
708+ # else self.function_trace_id,
709+ # ai_service_client=ai_service_client,
710+ # executor=self.executor,
711+ # function_references=function_references,
712+ # )
713+ # )
690714 else :
691715 # For async functions, prioritize throughput metrics over runtime even for slow candidates
692716 is_async = (
@@ -839,6 +863,26 @@ def refine_optimizations(
839863 ]
840864 return executor .submit (ai_service_client .optimize_python_code_refinement , request = request )
841865
def code_repair_optimizations(
    self,
    original_source_code: str,
    modified_source_code: str,
    test_details: str,
    trace_id: str,
    ai_service_client: AiServiceClient,
    executor: concurrent.futures.ThreadPoolExecutor,
    *,
    original_code_baseline: object = None,
    code_context: object = None,
    function_references: object = None,
) -> concurrent.futures.Future:
    """Schedule a code-repair request on ``executor`` and return its future.

    A single ``AIServiceCodeRepairRequest`` is built from the arguments and
    handed, wrapped in a one-element list, to
    ``ai_service_client.optimize_python_code_repair`` via ``executor.submit``.

    Args:
        original_source_code: Forwarded as the request's ``original_source_code``.
        modified_source_code: Forwarded as the request's ``modified_source_code``.
        test_details: Forwarded as the request's ``test_details``.
        trace_id: Forwarded as the request's ``trace_id``.
        ai_service_client: Client whose ``optimize_python_code_repair`` is submitted.
        executor: Pool the call is submitted to.
        original_code_baseline: Accepted so callers can use the same keyword
            set they pass to ``refine_optimizations``; not used here.
        code_context: Accepted for the same reason; not used here.
        function_references: Accepted for the same reason; not used here.

    Returns:
        The future returned by ``executor.submit``.
    """
    # The call in determine_best_candidate passes these three keywords; without
    # them in the signature that call raises TypeError before anything is queued.
    # NOTE(review): that caller also passes a candidate object, the code context
    # and the literal "test_details" for the three string fields - confirm the
    # intended values at the call site.
    del original_code_baseline, code_context, function_references

    # NOTE(review): optimization_id is sent as an empty string - confirm the
    # service is expected to assign or ignore it.
    repair_request = AIServiceCodeRepairRequest(
        optimization_id="",
        original_source_code=original_source_code,
        modified_source_code=modified_source_code,
        test_details=test_details,
        trace_id=trace_id,
    )
    return executor.submit(ai_service_client.optimize_python_code_repair, request=[repair_request])
885+
842886 def log_successful_optimization (
843887 self , explanation : Explanation , generated_tests : GeneratedTestsList , exp_type : str
844888 ) -> None :
@@ -1813,6 +1857,7 @@ def run_optimized_candidate(
18131857 )
18141858 )
18151859 console .rule ()
1860+ # print(type(code_context), type(candidate))
18161861 match , diffs = compare_test_results (baseline_results .behavior_test_results , candidate_behavior_results )
18171862 if match :
18181863 logger .info ("h3|Test results matched ✅" )
@@ -1823,15 +1868,29 @@ def run_optimized_candidate(
18231868 # if the test unmatched percentage is greater than 50%, we can't fix it
18241869 return self .get_results_not_matched_error ()
18251870
1826- print (f"should try to fix it, diffs: { diffs } " )
1827- # with the parsed test results diff ask the llm to fix the candidate to match the test results of the original code, and run again
1828- # self.run_optimized_candidate(
1829- # optimization_candidate_index=optimization_candidate_index,
1830- # baseline_results=baseline_results,
1831- # original_helper_code=original_helper_code,
1832- # file_path_to_helper_classes=file_path_to_helper_classes,
1833- # )
1834- return self .get_results_not_matched_error ()
1871+ logger .info ("running code repair..." )
1872+ # not sure if all return types will be convertible to string
1873+ diff_per_test_fn = {}
1874+ for diff in diffs :
1875+ try :
1876+ diff_per_test_fn [diff .test_src_code ] = (
1877+ diff_per_test_fn .setdefault (diff .test_src_code , "" )
1878+ + f"Expected Value: { diff .original_value !s} \n Actual Value: { diff .candidate_value !s} \n Error String:{ diff .pytest_error } \n "
1879+ )
1880+
1881+ except Exception as e :
1882+ sentry_sdk .capture_exception (e )
1883+ logger .exception (e )
1884+ return self .get_results_not_matched_error ()
1885+ try :
1886+ test_issues = "\n " .join (
1887+ f"{ test_fn_def } \n { value } " for test_fn_def , value in diff_per_test_fn .items ()
1888+ )
1889+ except Exception as e :
1890+ sentry_sdk .capture_exception (e )
1891+ logger .exception (e )
1892+ return self .get_results_not_matched_error ()
1893+ return Failure (test_issues )
18351894
18361895 logger .info (f"loading|Running performance tests for candidate { optimization_candidate_index } ..." )
18371896
0 commit comments