Skip to content

Commit 82b9d41

Browse files
committed
Cleaned up the import sorting code in FunctionOptimizer
1 parent 2c40018 commit 82b9d41

File tree

2 files changed

+12
-15
lines changed

2 files changed

+12
-15
lines changed

codeflash/code_utils/formatter.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,12 @@ def sort_imports(code: str) -> str:
5555
return code # Fall back to original code if isort fails
5656

5757
return sorted_code
58+
59+
60+
def sort_imports_in_place(paths: list[Path]) -> None:
61+
for path in paths:
62+
if path.exists():
63+
content = path.read_text(encoding="utf8")
64+
sorted_content = sort_imports(content)
65+
if sorted_content != content:
66+
path.write_text(sorted_content, encoding="utf8")

codeflash/optimization/function_optimizer.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
N_TESTS_TO_GENERATE,
3939
TOTAL_LOOPING_TIME,
4040
)
41-
from codeflash.code_utils.formatter import format_code, sort_imports
41+
from codeflash.code_utils.formatter import format_code, sort_imports_in_place
4242
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
4343
from codeflash.code_utils.line_profile_utils import add_decorator_imports
4444
from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests
@@ -305,20 +305,8 @@ def optimize_function(self) -> Result[BestOptimization, str]:
305305
)
306306

307307
if not self.args.disable_imports_sorting:
308-
main_file_path = self.function_to_optimize.file_path
309-
if main_file_path.exists():
310-
current_main_content = main_file_path.read_text(encoding="utf8")
311-
sorted_main_content = sort_imports(current_main_content)
312-
if sorted_main_content != current_main_content:
313-
main_file_path.write_text(sorted_main_content, encoding="utf8")
314-
315-
writable_helper_file_paths = {hf.file_path for hf in code_context.helper_functions}
316-
for helper_file_path in writable_helper_file_paths:
317-
if helper_file_path.exists():
318-
current_helper_content = helper_file_path.read_text(encoding="utf8")
319-
sorted_helper_content = sort_imports(current_helper_content)
320-
if sorted_helper_content != current_helper_content:
321-
helper_file_path.write_text(sorted_helper_content, encoding="utf8")
308+
path_to_sort_imports_for = [self.function_to_optimize.file_path] + [hf.file_path for hf in code_context.helper_functions]
309+
sort_imports_in_place(path_to_sort_imports_for)
322310

323311
new_code = self.function_to_optimize.file_path.read_text(encoding="utf8")
324312
new_helper_code: dict[Path, str] = {}

0 commit comments

Comments
 (0)