|
9 | 9 | import uuid |
10 | 10 | from collections import defaultdict, deque |
11 | 11 | from pathlib import Path |
12 | | -from typing import cast, TYPE_CHECKING |
| 12 | +from typing import Optional, TYPE_CHECKING |
13 | 13 |
|
14 | 14 | import isort |
15 | 15 | import libcst as cst |
@@ -302,7 +302,9 @@ def optimize_function(self) -> Result[BestOptimization, str]: |
302 | 302 | ) |
303 | 303 |
|
304 | 304 | new_code, new_helper_code = self.reformat_code_and_helpers( |
305 | | - code_context.helper_functions, explanation.file_path, self.function_to_optimize_source_code |
| 305 | + code_context, |
| 306 | + explanation.file_path, |
| 307 | + self.function_to_optimize_source_code, |
306 | 308 | ) |
307 | 309 |
|
308 | 310 | existing_tests = existing_tests_source_for( |
@@ -581,21 +583,24 @@ def write_code_and_helpers(original_code: str, original_helper_code: dict[Path, |
581 | 583 | f.write(original_helper_code[module_abspath]) |
582 | 584 |
|
583 | 585 | def reformat_code_and_helpers( |
584 | | - self, helper_functions: list[FunctionSource], fto_path: Path, original_code: str |
| 586 | + self, |
| 587 | + code_context: CodeOptimizationContext, |
| 588 | + fto_path: Path, |
| 589 | + original_code: str, |
585 | 590 | ) -> tuple[str, dict[Path, str]]: |
586 | 591 | should_sort_imports = not self.args.disable_imports_sorting |
587 | 592 | if should_sort_imports and isort.code(original_code) != original_code: |
588 | 593 | should_sort_imports = False |
589 | 594 |
|
| 595 | + helper_functions = code_context.helper_functions |
| 596 | + |
590 | 597 | paths = [fto_path] + list({hf.file_path for hf in helper_functions}) |
591 | 598 | new_target_code = None |
592 | 599 | new_helper_code: dict[Path, str] = {} |
593 | 600 | for i, path in enumerate(paths): |
594 | 601 | unformatted_code = path.read_text(encoding="utf8") |
595 | | - code_context_result = self.get_code_optimization_context() |
596 | | - if code_context_result.is_failure(): |
597 | | - raise Exception("Unable to generate code context for formatting purposes") |
598 | | - code_context = cast(CodeOptimizationContext, code_context_result.unwrap()) |
| 602 | + # TODO(zomglings): code_context.preexisting_objects doesn't read all functions in the old file. We should add that to context |
| 603 | + # separately. That's a much bigger change. |
599 | 604 | code_ranges_unformatted = get_modification_code_ranges(unformatted_code, self.function_to_optimize, code_context) |
600 | 605 |
|
601 | 606 | formatted_code = format_code(self.args.formatter_cmds, path) |
|
0 commit comments