2121from codeflash .api .aiservice import AiServiceClient , LocalAiServiceClient
2222from codeflash .cli_cmds .console import code_print , console , logger , progress_bar
2323from codeflash .code_utils import env_utils
24- from codeflash .code_utils .code_extractor import add_needed_imports_from_module , extract_code
2524from codeflash .code_utils .code_replacer import replace_function_definitions_in_module , add_decorator_imports
2625from codeflash .code_utils .code_utils import (
2726 cleanup_paths ,
@@ -208,7 +207,7 @@ def optimize_function(self) -> Result[BestOptimization, str]:
208207 and "." in function_source .qualified_name
209208 ):
210209 file_path_to_helper_classes [function_source .file_path ].add (function_source .qualified_name .split ("." )[0 ])
211-
210+ pass
212211 baseline_result = self .establish_original_code_baseline ( # this needs better typing
213212 code_context = code_context ,
214213 original_helper_code = original_helper_code ,
@@ -232,7 +231,29 @@ def optimize_function(self) -> Result[BestOptimization, str]:
232231 return Failure ("The threshold for test coverage was not met." )
233232
234233 best_optimization = None
235-
234+ logger .info (f"Adding more candidates based on lineprof info, calling ai service" )
235+ with progress_bar (
236+ f"Generating new optimizations for function { self .function_to_optimize .function_name } with line profiler information" ,
237+ transient = True ,
238+ ):
239+ pass
240+ lprof_generated_results = self .aiservice_client .optimize_python_code_line_profiler (
241+
242+ source_code = code_context .read_writable_code ,
243+ dependency_code = code_context .read_only_context_code ,
244+ trace_id = self .function_trace_id ,
245+ line_profiler_results = original_code_baseline .lprof_results ,
246+ num_candidates = 10 ,
247+ experiment_metadata = None )
248+
249+ if len (lprof_generated_results )== 0 :
250+ logger .info (f"Generated tests with line profiler failed." )
251+ else :
252+ logger .info (f"Generated tests with line profiler succeeded. Appending to optimization candidates." )
253+ print ("initial optimization candidates" ,len (optimizations_set .control ))
254+ optimizations_set .control .extend (lprof_generated_results )
255+ print ("after adding optimization candidates" ,len (optimizations_set .control ))
256+ #append to optimization candidates
236257 for _u , candidates in enumerate ([optimizations_set .control , optimizations_set .experiment ]):
237258 if candidates is None :
238259 continue
@@ -813,7 +834,7 @@ def establish_original_code_baseline(
813834 files_to_instrument .append (helper_obj .file_path )
814835 fns_to_instrument .append (helper_obj .qualified_name )
815836 add_decorator_imports (files_to_instrument ,fns_to_instrument , lprofiler_database_file )
816- behavioral_results , coverage_results = self .run_and_parse_tests (
837+ lprof_results , _ = self .run_and_parse_tests (
817838 testing_type = TestingMode .BEHAVIOR ,
818839 test_env = test_env ,
819840 test_files = self .test_files ,
@@ -822,7 +843,9 @@ def establish_original_code_baseline(
822843 enable_coverage = False ,
823844 enable_lprofiler = test_framework == "pytest" ,
824845 code_context = code_context ,
825- )
846+ lprofiler_database_file = lprofiler_database_file ,
847+ )
848+ pass
826849 except Exception as e :
827850 logger .warning (f"Failed to run lprof for { self .function_to_optimize .function_name } . SKIPPING OPTIMIZING THIS FUNCTION." )
828851 console .rule ()
@@ -905,6 +928,7 @@ def establish_original_code_baseline(
905928 benchmarking_test_results = benchmarking_results ,
906929 runtime = total_timing ,
907930 coverage_results = coverage_results ,
931+ lprof_results = lprof_results ,
908932 ),
909933 functions_to_remove ,
910934 )
@@ -1041,6 +1065,7 @@ def run_and_parse_tests(
10411065 pytest_max_loops : int = 100_000 ,
10421066 code_context : CodeOptimizationContext | None = None ,
10431067 unittest_loop_index : int | None = None ,
1068+ lprofiler_database_file : str | None = None ,
10441069 ) -> tuple [TestResults , CoverageData | None ]:
10451070 coverage_database_file = None
10461071 coverage_config_file = None
@@ -1083,20 +1108,24 @@ def run_and_parse_tests(
10831108 f"stdout: { run_result .stdout } \n "
10841109 f"stderr: { run_result .stderr } \n "
10851110 )
1086-
1087- results , coverage_results = parse_test_results (
1088- test_xml_path = result_file_path ,
1089- test_files = test_files ,
1090- test_config = self .test_cfg ,
1091- optimization_iteration = optimization_iteration ,
1092- run_result = run_result ,
1093- unittest_loop_index = unittest_loop_index ,
1094- function_name = self .function_to_optimize .function_name ,
1095- source_file = self .function_to_optimize .file_path ,
1096- code_context = code_context ,
1097- coverage_database_file = coverage_database_file ,
1098- coverage_config_file = coverage_config_file ,
1099- )
1111+ if not enable_lprofiler :
1112+ results , coverage_results = parse_test_results (
1113+ test_xml_path = result_file_path ,
1114+ test_files = test_files ,
1115+ test_config = self .test_cfg ,
1116+ optimization_iteration = optimization_iteration ,
1117+ run_result = run_result ,
1118+ unittest_loop_index = unittest_loop_index ,
1119+ function_name = self .function_to_optimize .function_name ,
1120+ source_file = self .function_to_optimize .file_path ,
1121+ code_context = code_context ,
1122+ coverage_database_file = coverage_database_file ,
1123+ coverage_config_file = coverage_config_file ,
1124+ )
1125+ else :
1126+ pass
1127+ file_contents = Path (str (lprofiler_database_file )+ ".txt" ).read_text ("utf-8" )
1128+ return file_contents , None
11001129 return results , coverage_results
11011130
11021131 def generate_and_instrument_tests (
0 commit comments