|
116 | 116 | CoverageData, |
117 | 117 | FunctionCalledInTest, |
118 | 118 | FunctionSource, |
| 119 | + TestDiff, |
119 | 120 | ) |
120 | 121 | from codeflash.verification.verification_utils import TestConfig |
121 | 122 |
|
@@ -685,32 +686,15 @@ def determine_best_candidate( |
685 | 686 | baseline_results=original_code_baseline, |
686 | 687 | original_helper_code=original_helper_code, |
687 | 688 | file_path_to_helper_classes=file_path_to_helper_classes, |
| 689 | + code_context=code_context, |
| 690 | + candidate=candidate, |
| 691 | + exp_type=exp_type, |
688 | 692 | ) |
689 | 693 | console.rule() |
690 | 694 | if not is_successful(run_results): |
691 | 695 | optimized_runtimes[candidate.optimization_id] = None |
692 | 696 | is_correct[candidate.optimization_id] = False |
693 | 697 | speedup_ratios[candidate.optimization_id] = None |
694 | | - fail_value = run_results.value |
695 | | - if ( |
696 | | - fail_value.strip() != "Test results did not match the test results of the original code." |
697 | | - and len(future_all_refinements) <= 3 |
698 | | - and not candidate.optimization_id.endswith("cdrp") |
699 | | - ): |
700 | | - # # queue corresponding code repair optimization for best optimization |
701 | | - future_all_refinements.append( |
702 | | - self.code_repair_optimizations( |
703 | | - original_source_code=code_context.read_writable_code.markdown, |
704 | | - modified_source_code=candidate.source_code.markdown, |
705 | | - test_details=fail_value, |
706 | | - trace_id=self.function_trace_id[:-4] + exp_type |
707 | | - if self.experiment_id |
708 | | - else self.function_trace_id, |
709 | | - ai_service_client=ai_service_client, |
710 | | - executor=self.executor, |
711 | | - optimization_id=candidate.optimization_id, |
712 | | - ) |
713 | | - ) |
714 | 698 | else: |
715 | 699 | candidate_result: OptimizedCandidateResult = run_results.unwrap() |
716 | 700 | best_test_runtime = candidate_result.best_test_runtime |
@@ -978,22 +962,19 @@ def code_repair_optimizations( |
978 | 962 | self, |
979 | 963 | original_source_code: str, |
980 | 964 | modified_source_code: str, |
981 | | - test_details: str, |
| 965 | + test_diffs: list[TestDiff], |
982 | 966 | trace_id: str, |
983 | 967 | optimization_id: str, |
984 | 968 | ai_service_client: AiServiceClient, |
985 | | - executor: concurrent.futures.ThreadPoolExecutor, |
986 | | - ) -> concurrent.futures.Future: |
987 | | - request = [ |
988 | | - AIServiceCodeRepairRequest( |
989 | | - optimization_id=optimization_id, |
990 | | - original_source_code=original_source_code, |
991 | | - modified_source_code=modified_source_code, |
992 | | - test_details=test_details, |
993 | | - trace_id=trace_id, |
994 | | - ) |
995 | | - ] |
996 | | - return executor.submit(ai_service_client.optimize_python_code_repair, request=request) |
| 969 | + ) -> OptimizedCandidate | None: |
| 970 | + request = AIServiceCodeRepairRequest( |
| 971 | + optimization_id=optimization_id, |
| 972 | + original_source_code=original_source_code, |
| 973 | + modified_source_code=modified_source_code, |
| 974 | + test_diffs=test_diffs, |
| 975 | + trace_id=trace_id, |
| 976 | + ) |
| 977 | + return ai_service_client.optimize_python_code_repair(request=request) |
997 | 978 |
|
998 | 979 | def log_successful_optimization( |
999 | 980 | self, explanation: Explanation, generated_tests: GeneratedTestsList, exp_type: str |
@@ -1920,6 +1901,9 @@ def run_optimized_candidate( |
1920 | 1901 | baseline_results: OriginalCodeBaseline, |
1921 | 1902 | original_helper_code: dict[Path, str], |
1922 | 1903 | file_path_to_helper_classes: dict[Path, set[str]], |
| 1904 | + code_context: CodeOptimizationContext, |
| 1905 | + candidate: OptimizedCandidate, |
| 1906 | + exp_type: str, |
1923 | 1907 | ) -> Result[OptimizedCandidateResult, str]: |
1924 | 1908 | assert (test_framework := self.args.test_framework) in {"pytest", "unittest"} # noqa: RUF018 |
1925 | 1909 |
|
@@ -1980,29 +1964,50 @@ def run_optimized_candidate( |
1980 | 1964 | # if the test unmatched percentage is greater than 50%, we can't fix it |
1981 | 1965 | return self.get_results_not_matched_error() |
1982 | 1966 |
|
1983 | | - logger.info("running code repair...") |
1984 | | - # not sure if all return types will be convertible to string |
1985 | | - diff_per_test_fn = {} |
1986 | | - for diff in diffs: |
1987 | | - try: |
1988 | | - diff_per_test_fn[diff.test_src_code] = ( |
1989 | | - diff_per_test_fn.setdefault(diff.test_src_code, "") |
1990 | | - + f"Expected Value: {diff.original_value!s}\nActual Value: {diff.candidate_value!s}\nError String:{diff.candidate_pytest_error}\n" |
1991 | | - ) |
| 1967 | + if candidate.optimization_id.endswith("cdrp"): |
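| 1968 | + # a candidate whose id ends in "cdrp" is never sent to code repair again (presumably it is already a repair result), so the recursive re-run below cannot loop |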
| 1969 | + return self.get_results_not_matched_error() |
| 1970 | + |
| 1971 | + ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client |
| 1972 | + |
| 1973 | + with progress_bar("The test results are not matching, let me see if I can fix this"): |
| 1974 | + new_candidate = self.code_repair_optimizations( |
| 1975 | + original_source_code=code_context.read_writable_code.markdown, |
| 1976 | + modified_source_code=candidate.source_code.markdown, |
| 1977 | + test_diffs=diffs, |
| 1978 | + trace_id=self.function_trace_id[:-4] + exp_type |
| 1979 | + if self.experiment_id |
| 1980 | + else self.function_trace_id, |
| 1981 | + ai_service_client=ai_service_client, |
| 1982 | + optimization_id=candidate.optimization_id, |
| 1983 | + ) |
| 1984 | + if not new_candidate: |
| 1985 | + return Failure("Code repair failed to generate a valid candidate.") |
| 1986 | + |
| 1987 | + code_print(new_candidate.source_code.flat) |
1992 | 1988 |
|
1993 | | - except Exception as e: |
1994 | | - sentry_sdk.capture_exception(e) |
1995 | | - logger.exception(e) |
1996 | | - return self.get_results_not_matched_error() |
1997 | 1989 | try: |
1998 | | - test_issues = "\n".join( |
1999 | | - f"{test_fn_def}\n{value}" for test_fn_def, value in diff_per_test_fn.items() |
| 1990 | + did_update = self.replace_function_and_helpers_with_optimized_code( |
| 1991 | + code_context=code_context, |
| 1992 | + optimized_code=new_candidate.source_code, |
| 1993 | + original_helper_code=original_helper_code, |
2000 | 1994 | ) |
2001 | | - except Exception as e: |
2002 | | - sentry_sdk.capture_exception(e) |
2003 | | - logger.exception(e) |
2004 | | - return self.get_results_not_matched_error() |
2005 | | - return Failure(test_issues) |
| 1995 | + if did_update: |
| 1996 | + return self.run_optimized_candidate( |
| 1997 | + optimization_candidate_index=optimization_candidate_index, |
| 1998 | + baseline_results=baseline_results, |
| 1999 | + original_helper_code=original_helper_code, |
| 2000 | + file_path_to_helper_classes=file_path_to_helper_classes, |
| 2001 | + code_context=code_context, |
| 2002 | + candidate=new_candidate, |
| 2003 | + exp_type=exp_type, |
| 2004 | + ) |
| 2005 | + except (ValueError, SyntaxError, cst.ParserSyntaxError, AttributeError) as e: |
| 2006 | + logger.error(e) |
| 2007 | + self.write_code_and_helpers( |
| 2008 | + self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path |
| 2009 | + ) |
| 2010 | + return Failure("Code repair failed to generate a valid candidate.") |
2006 | 2011 |
|
2007 | 2012 | logger.info(f"loading|Running performance tests for candidate {optimization_candidate_index}...") |
2008 | 2013 |
|
|
0 commit comments