6262from codeflash .either import Failure , Success , is_successful
6363from codeflash .models .ExperimentMetadata import ExperimentMetadata
6464from codeflash .models .models import (
65+ LINE_SPLITTER_MARKER_PREFIX ,
6566 BestOptimization ,
6667 CodeOptimizationContext ,
6768 CodeStringsMarkdown ,
@@ -216,7 +217,7 @@ def generate_and_instrument_tests(
216217 revert_to_print = bool (get_pr_number ()),
217218 ):
218219 generated_results = self .generate_tests_and_optimizations (
219- testgen_context_code = code_context .testgen_context_code ,
220+ testgen_context_code = code_context .testgen_context_code , # TODO: should we send the markdow context for the testgen instead.
220221 read_writable_code = code_context .read_writable_code ,
221222 read_only_context_code = code_context .read_only_context_code ,
222223 helper_functions = code_context .helper_functions ,
@@ -289,7 +290,7 @@ def optimize_function(self) -> Result[BestOptimization, str]:
289290
290291 should_run_experiment , code_context , original_helper_code = initialization_result .unwrap ()
291292
292- code_print (code_context .read_writable_code .flat )
293+ code_print (code_context .read_writable_code .flat ) # Should we print the markdown or the flattened code?
293294
294295 test_setup_result = self .generate_and_instrument_tests ( # also generates optimizations
295296 code_context , should_run_experiment = should_run_experiment
@@ -414,11 +415,11 @@ def determine_best_candidate(
414415 get_run_tmp_file (Path (f"test_return_values_{ candidate_index } .bin" )).unlink (missing_ok = True )
415416 get_run_tmp_file (Path (f"test_return_values_{ candidate_index } .sqlite" )).unlink (missing_ok = True )
416417 logger .info (f"Optimization candidate { candidate_index } /{ original_len } :" )
417- code_print (candidate .source_code )
418+ code_print (candidate .source_code . flat )
418419 try :
419420 did_update = self .replace_function_and_helpers_with_optimized_code (
420421 code_context = code_context ,
421- optimized_code = candidate .source_code ,
422+ optimized_code = candidate .source_code . flat ,
422423 original_helper_code = original_helper_code ,
423424 )
424425 if not did_update :
@@ -578,7 +579,7 @@ def determine_best_candidate(
578579 runtimes_list = []
579580 for valid_opt in self .valid_optimizations :
580581 diff_lens_list .append (
581- diff_length (valid_opt .candidate .source_code , code_context .read_writable_code .flat )
582+ diff_length (valid_opt .candidate .source_code . flat , code_context .read_writable_code .flat )
582583 ) # char level diff
583584 runtimes_list .append (valid_opt .runtime )
584585 diff_lens_ranking = create_rank_dictionary_compact (diff_lens_list )
@@ -613,7 +614,7 @@ def refine_optimizations(
613614 original_source_code = code_context .read_writable_code .flat ,
614615 read_only_dependency_code = code_context .read_only_context_code ,
615616 original_code_runtime = humanize_runtime (original_code_baseline .runtime ),
616- optimized_source_code = opt .candidate .source_code ,
617+ optimized_source_code = opt .candidate .source_code . flat ,
617618 optimized_explanation = opt .candidate .explanation ,
618619 optimized_code_runtime = humanize_runtime (opt .runtime ),
619620 speedup = f"{ int (performance_gain (original_runtime_ns = original_code_baseline .runtime , optimized_runtime_ns = opt .runtime ) * 100 )} %" ,
@@ -679,13 +680,13 @@ def write_code_and_helpers(original_code: str, original_helper_code: dict[Path,
679680 f .write (helper_code )
680681
681682 def reformat_code_and_helpers (
682- self , helper_functions : list [FunctionSource ], path : Path , original_code : str , optimized_function : str
683+ self , helper_functions : list [FunctionSource ], path : Path , original_code : str , optimized_code : str
683684 ) -> tuple [str , dict [Path , str ]]:
684685 should_sort_imports = not self .args .disable_imports_sorting
685686 if should_sort_imports and isort .code (original_code ) != original_code :
686687 should_sort_imports = False
687688
688- new_code = format_code (self .args .formatter_cmds , path , optimized_function = optimized_function , check_diff = True )
689+ new_code = format_code (self .args .formatter_cmds , path , optimized_code = optimized_code , check_diff = True )
689690 if should_sort_imports :
690691 new_code = sort_imports (new_code )
691692
@@ -694,7 +695,7 @@ def reformat_code_and_helpers(
694695 module_abspath = hp .file_path
695696 hp_source_code = hp .source_code
696697 formatted_helper_code = format_code (
697- self .args .formatter_cmds , module_abspath , optimized_function = hp_source_code , check_diff = True
698+ self .args .formatter_cmds , module_abspath , optimized_code = hp_source_code , check_diff = True
698699 )
699700 if should_sort_imports :
700701 formatted_helper_code = sort_imports (formatted_helper_code )
@@ -711,7 +712,8 @@ def replace_function_and_helpers_with_optimized_code(
711712 self .function_to_optimize .qualified_name
712713 )
713714
714- file_to_code_context = CodeStringsMarkdown .parse_splitter_markers (optimized_code )
715+ code_strings = CodeStringsMarkdown .parse_splitter_markers (optimized_code ).code_strings
716+ file_to_code_context = {str (code_string .file_path ): code_string .code for code_string in code_strings }
715717
716718 for helper_function in code_context .helper_functions :
717719 if helper_function .jedi_definition .type != "class" :
@@ -721,11 +723,12 @@ def replace_function_and_helpers_with_optimized_code(
721723 relative_module_path = str (module_abspath .relative_to (self .project_root ))
722724 logger .debug (f"applying optimized code to: { relative_module_path } " )
723725
724- scoped_optimized_code = file_to_code_context .get (relative_module_path , None )
726+ scoped_optimized_code = file_to_code_context .get (relative_module_path )
725727 if scoped_optimized_code is None :
726728 logger .warning (
727729 f"Optimized code not found for { relative_module_path } In the context\n -------\n { optimized_code } \n -------\n "
728730 "Existing files in the context are: {list(file_to_code_context.keys())}, re-check your 'split markers'"
731+ f"existing files are { file_to_code_context .keys ()} "
729732 )
730733 scoped_optimized_code = ""
731734
@@ -1063,7 +1066,7 @@ def find_and_process_best_optimization(
10631066
10641067 if best_optimization :
10651068 logger .info ("Best candidate:" )
1066- code_print (best_optimization .candidate .source_code )
1069+ code_print (best_optimization .candidate .source_code . flat )
10671070 console .print (
10681071 Panel (
10691072 best_optimization .candidate .explanation , title = "Best Candidate Explanation" , border_style = "blue"
@@ -1089,15 +1092,15 @@ def find_and_process_best_optimization(
10891092
10901093 self .replace_function_and_helpers_with_optimized_code (
10911094 code_context = code_context ,
1092- optimized_code = best_optimization .candidate .source_code ,
1095+ optimized_code = best_optimization .candidate .source_code . flat ,
10931096 original_helper_code = original_helper_code ,
10941097 )
10951098
10961099 new_code , new_helper_code = self .reformat_code_and_helpers (
10971100 code_context .helper_functions ,
10981101 explanation .file_path ,
10991102 self .function_to_optimize_source_code ,
1100- optimized_function = best_optimization .candidate .source_code ,
1103+ optimized_code = best_optimization .candidate .source_code . flat ,
11011104 )
11021105
11031106 original_code_combined = original_helper_code .copy ()
@@ -1169,10 +1172,14 @@ def process_review(
11691172 optimized_runtimes_all = optimized_runtime_by_test ,
11701173 )
11711174 new_explanation_raw_str = self .aiservice_client .get_new_explanation (
1172- source_code = code_context .read_writable_code ,
1175+ source_code = code_context .read_writable_code .flat .replace (
1176+ LINE_SPLITTER_MARKER_PREFIX , "# file: "
1177+ ), # for better readability to the LLM
11731178 dependency_code = code_context .read_only_context_code ,
11741179 trace_id = self .function_trace_id [:- 4 ] + exp_type if self .experiment_id else self .function_trace_id ,
1175- optimized_code = best_optimization .candidate .source_code ,
1180+ optimized_code = best_optimization .candidate .source_code .flat .replace (
1181+ LINE_SPLITTER_MARKER_PREFIX , "# file: "
1182+ ),
11761183 original_line_profiler_results = original_code_baseline .line_profile_results ["str_out" ],
11771184 optimized_line_profiler_results = best_optimization .line_profiler_test_results ["str_out" ],
11781185 original_code_runtime = humanize_runtime (original_code_baseline .runtime ),
0 commit comments