Skip to content

Commit 6dd72cf

Browse files
committed
Correct calculation of all preexisting "function" symbols for formatting purposes
This is done on a per-path basis.
1 parent c3b8063 commit 6dd72cf

File tree

2 files changed

+30
-13
lines changed

2 files changed

+30
-13
lines changed

codeflash/code_utils/formatter.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from codeflash.cli_cmds.console import console, logger
1212
from codeflash.code_utils.code_replacer import OptimFunctionCollector
1313
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
14-
from codeflash.models.models import CodeOptimizationContext
14+
from codeflash.models.models import FunctionParent, FunctionSource
1515

1616
if TYPE_CHECKING:
1717
from pathlib import Path
@@ -64,14 +64,15 @@ def sort_imports(code: str) -> str:
6464
def get_modification_code_ranges(
6565
modified_code: str,
6666
fto: FunctionToOptimize,
67-
code_context: CodeOptimizationContext,
67+
preexisting_functions: set[tuple[str, tuple[FunctionParent,...]]],
68+
helper_functions: list[FunctionSource],
6869
) -> list[tuple[int, int]]:
6970
"""
7071
Returns the line number of modified and new functions in a string containing containing the code in a fully modified file.
7172
"""
7273
modified_functions = set()
7374
modified_functions.add(fto.qualified_name)
74-
for helper_function in code_context.helper_functions:
75+
for helper_function in helper_functions:
7576
if helper_function.jedi_definition.type != "class":
7677
modified_functions.add(helper_function.qualified_name)
7778

@@ -88,6 +89,6 @@ def get_modification_code_ranges(
8889
parsed_function_names.add((class_name, function_name))
8990

9091
module = cst.metadata.MetadataWrapper(cst.parse_module(modified_code))
91-
visitor = OptimFunctionCollector(code_context.preexisting_objects, parsed_function_names)
92+
visitor = OptimFunctionCollector(preexisting_functions, parsed_function_names)
9293
module.visit(visitor)
9394
return visitor.modification_code_range_lines

codeflash/optimization/function_optimizer.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from codeflash.benchmarking.utils import process_benchmark_data
2323
from codeflash.cli_cmds.console import code_print, console, logger, progress_bar
2424
from codeflash.code_utils import env_utils
25+
from codeflash.code_utils.code_extractor import find_preexisting_objects
2526
from codeflash.code_utils.code_replacer import replace_function_definitions_in_module
2627
from codeflash.code_utils.code_utils import (
2728
cleanup_paths,
@@ -49,6 +50,7 @@
4950
BestOptimization,
5051
CodeOptimizationContext,
5152
FunctionCalledInTest,
53+
FunctionParent,
5254
GeneratedTests,
5355
GeneratedTestsList,
5456
OptimizationSet,
@@ -297,12 +299,20 @@ def optimize_function(self) -> Result[BestOptimization, str]:
297299

298300
self.log_successful_optimization(explanation, generated_tests, exp_type)
299301

302+
# xylophone
303+
preexisting_functions_by_filepath: dict[Path, list[str]] = {}
304+
filepaths_to_inspect = [self.function_to_optimize.file_path] + list({helper.file_path for helper in code_context.helper_functions})
305+
for filepath in filepaths_to_inspect:
306+
source_code = filepath.read_text(encoding="utf8")
307+
preexisting_functions_by_filepath[filepath] = find_preexisting_objects(source_code)
308+
300309
self.replace_function_and_helpers_with_optimized_code(
301310
code_context=code_context, optimized_code=best_optimization.candidate.source_code
302311
)
303312

304313
new_code, new_helper_code = self.reformat_code_and_helpers(
305-
code_context,
314+
preexisting_functions_by_filepath,
315+
code_context.helper_functions,
306316
explanation.file_path,
307317
self.function_to_optimize_source_code,
308318
)
@@ -584,29 +594,35 @@ def write_code_and_helpers(original_code: str, original_helper_code: dict[Path,
584594

585595
def reformat_code_and_helpers(
586596
self,
587-
code_context: CodeOptimizationContext,
597+
preexisting_functions_by_filepath: dict[Path, set[tuple[str, tuple[FunctionParent,...]]]],
598+
helper_functions: list[FunctionSource],
588599
fto_path: Path,
589600
original_code: str,
590601
) -> tuple[str, dict[Path, str]]:
591602
should_sort_imports = not self.args.disable_imports_sorting
592603
if should_sort_imports and isort.code(original_code) != original_code:
593604
should_sort_imports = False
594605

595-
helper_functions = code_context.helper_functions
596-
597606
paths = [fto_path] + list({hf.file_path for hf in helper_functions})
598607
new_target_code = None
599608
new_helper_code: dict[Path, str] = {}
600609
for i, path in enumerate(paths):
601610
unformatted_code = path.read_text(encoding="utf8")
602-
# TODO(zomglings): code_context.preexisting_objects doesn't read all functions in the old file. We should add that to context
603-
# separately. That's a much bigger change.
604-
code_ranges_unformatted = get_modification_code_ranges(unformatted_code, self.function_to_optimize, code_context)
605-
611+
code_ranges_unformatted = get_modification_code_ranges(
612+
unformatted_code,
613+
self.function_to_optimize,
614+
preexisting_functions_by_filepath[path],
615+
helper_functions,
616+
)
606617
formatted_code = format_code(self.args.formatter_cmds, path)
607618
# Note: We do not need to refresh the code_context because we only use it to refer to names of original
608619
# functions (even before optimization was applied) and filepaths, none of which is changing.
609-
code_ranges_formatted = get_modification_code_ranges(formatted_code, self.function_to_optimize, code_context)
620+
code_ranges_formatted = get_modification_code_ranges(
621+
formatted_code,
622+
self.function_to_optimize,
623+
preexisting_functions_by_filepath[path],
624+
helper_functions,
625+
)
610626

611627
if len(code_ranges_formatted) != len(code_ranges_unformatted):
612628
raise Exception("Formatting had unexpected effects on code ranges")

0 commit comments

Comments
 (0)