|
22 | 22 | from codeflash.benchmarking.utils import process_benchmark_data |
23 | 23 | from codeflash.cli_cmds.console import code_print, console, logger, progress_bar |
24 | 24 | from codeflash.code_utils import env_utils |
| 25 | +from codeflash.code_utils.code_extractor import find_preexisting_objects |
25 | 26 | from codeflash.code_utils.code_replacer import replace_function_definitions_in_module |
26 | 27 | from codeflash.code_utils.code_utils import ( |
27 | 28 | cleanup_paths, |
|
49 | 50 | BestOptimization, |
50 | 51 | CodeOptimizationContext, |
51 | 52 | FunctionCalledInTest, |
| 53 | + FunctionParent, |
52 | 54 | GeneratedTests, |
53 | 55 | GeneratedTestsList, |
54 | 56 | OptimizationSet, |
@@ -297,12 +299,20 @@ def optimize_function(self) -> Result[BestOptimization, str]: |
297 | 299 |
|
298 | 300 | self.log_successful_optimization(explanation, generated_tests, exp_type) |
299 | 301 |
|
| 302 | + # xylophone |
| 303 | + preexisting_functions_by_filepath: dict[Path, list[str]] = {} |
| 304 | + filepaths_to_inspect = [self.function_to_optimize.file_path] + list({helper.file_path for helper in code_context.helper_functions}) |
| 305 | + for filepath in filepaths_to_inspect: |
| 306 | + source_code = filepath.read_text(encoding="utf8") |
| 307 | + preexisting_functions_by_filepath[filepath] = find_preexisting_objects(source_code) |
| 308 | + |
300 | 309 | self.replace_function_and_helpers_with_optimized_code( |
301 | 310 | code_context=code_context, optimized_code=best_optimization.candidate.source_code |
302 | 311 | ) |
303 | 312 |
|
304 | 313 | new_code, new_helper_code = self.reformat_code_and_helpers( |
305 | | - code_context, |
| 314 | + preexisting_functions_by_filepath, |
| 315 | + code_context.helper_functions, |
306 | 316 | explanation.file_path, |
307 | 317 | self.function_to_optimize_source_code, |
308 | 318 | ) |
@@ -584,29 +594,35 @@ def write_code_and_helpers(original_code: str, original_helper_code: dict[Path, |
584 | 594 |
|
585 | 595 | def reformat_code_and_helpers( |
586 | 596 | self, |
587 | | - code_context: CodeOptimizationContext, |
| 597 | + preexisting_functions_by_filepath: dict[Path, set[tuple[str, tuple[FunctionParent,...]]]], |
| 598 | + helper_functions: list[FunctionSource], |
588 | 599 | fto_path: Path, |
589 | 600 | original_code: str, |
590 | 601 | ) -> tuple[str, dict[Path, str]]: |
591 | 602 | should_sort_imports = not self.args.disable_imports_sorting |
592 | 603 | if should_sort_imports and isort.code(original_code) != original_code: |
593 | 604 | should_sort_imports = False |
594 | 605 |
|
595 | | - helper_functions = code_context.helper_functions |
596 | | - |
597 | 606 | paths = [fto_path] + list({hf.file_path for hf in helper_functions}) |
598 | 607 | new_target_code = None |
599 | 608 | new_helper_code: dict[Path, str] = {} |
600 | 609 | for i, path in enumerate(paths): |
601 | 610 | unformatted_code = path.read_text(encoding="utf8") |
602 | | - # TODO(zomglings): code_context.preexisting_objects doesn't read all functions in the old file. We should add that to context |
603 | | - # separately. That's a much bigger change. |
604 | | - code_ranges_unformatted = get_modification_code_ranges(unformatted_code, self.function_to_optimize, code_context) |
605 | | - |
| 611 | + code_ranges_unformatted = get_modification_code_ranges( |
| 612 | + unformatted_code, |
| 613 | + self.function_to_optimize, |
| 614 | + preexisting_functions_by_filepath[path], |
| 615 | + helper_functions, |
| 616 | + ) |
606 | 617 | formatted_code = format_code(self.args.formatter_cmds, path) |
607 | 618 | # Note: We do not need to refresh the code_context because we only use it to refer to names of original |
608 | 619 | # functions (even before optimization was applied) and filepaths, none of which is changing. |
609 | | - code_ranges_formatted = get_modification_code_ranges(formatted_code, self.function_to_optimize, code_context) |
| 620 | + code_ranges_formatted = get_modification_code_ranges( |
| 621 | + formatted_code, |
| 622 | + self.function_to_optimize, |
| 623 | + preexisting_functions_by_filepath[path], |
| 624 | + helper_functions, |
| 625 | + ) |
610 | 626 |
|
611 | 627 | if len(code_ranges_formatted) != len(code_ranges_unformatted): |
612 | 628 | raise Exception("Formatting had unexpected effects on code ranges") |
|
0 commit comments