Skip to content

Commit 3cbd6b7

Browse files
committed
feat(optimizer): Implement targeted formatting (CF-637)
1 parent 198595c commit 3cbd6b7

File tree

1 file changed

+53
-24
lines changed

1 file changed

+53
-24
lines changed

codeflash/optimization/function_optimizer.py

Lines changed: 53 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import os
66
import shutil
77
import subprocess
8+
import tempfile
89
import time
910
import uuid
1011
from collections import defaultdict, deque
@@ -124,6 +125,7 @@ def __init__(
124125
self.function_benchmark_timings = function_benchmark_timings if function_benchmark_timings else {}
125126
self.total_benchmark_timings = total_benchmark_timings if total_benchmark_timings else {}
126127
self.replay_tests_dir = replay_tests_dir if replay_tests_dir else None
128+
self.optimizer_temp_dir = Path(tempfile.mkdtemp(prefix="codeflash_opt_fmt_"))
127129

128130
def optimize_function(self) -> Result[BestOptimization, str]:
129131
should_run_experiment = self.experiment_id is not None
@@ -301,9 +303,30 @@ def optimize_function(self) -> Result[BestOptimization, str]:
301303
code_context=code_context, optimized_code=best_optimization.candidate.source_code
302304
)
303305

304-
new_code, new_helper_code = self.reformat_code_and_helpers(
305-
code_context.helper_functions, explanation.file_path, self.function_to_optimize_source_code
306-
)
306+
if not self.args.disable_imports_sorting:
307+
main_file_path = self.function_to_optimize.file_path
308+
if main_file_path.exists():
309+
current_main_content = main_file_path.read_text(encoding="utf8")
310+
sorted_main_content = sort_imports(current_main_content)
311+
if sorted_main_content != current_main_content:
312+
main_file_path.write_text(sorted_main_content, encoding="utf8")
313+
314+
writable_helper_file_paths = {hf.file_path for hf in code_context.helper_functions}
315+
for helper_file_path in writable_helper_file_paths:
316+
if helper_file_path.exists():
317+
current_helper_content = helper_file_path.read_text(encoding="utf8")
318+
sorted_helper_content = sort_imports(current_helper_content)
319+
if sorted_helper_content != current_helper_content:
320+
helper_file_path.write_text(sorted_helper_content, encoding="utf8")
321+
322+
new_code = self.function_to_optimize.file_path.read_text(encoding="utf8")
323+
new_helper_code: dict[Path, str] = {}
324+
for helper_file_path_key in original_helper_code:
325+
if helper_file_path_key.exists():
326+
new_helper_code[helper_file_path_key] = helper_file_path_key.read_text(encoding="utf8")
327+
else:
328+
logger.warning(f"Helper file {helper_file_path_key} not found after optimization. It will not be included in new_helper_code for PR.")
329+
307330

308331
existing_tests = existing_tests_source_for(
309332
self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root),
@@ -405,6 +428,33 @@ def determine_best_candidate(
405428
future_line_profile_results = None
406429
candidate_index += 1
407430
candidate = candidates.popleft()
431+
432+
formatted_candidate_code = candidate.source_code
433+
if self.args.formatter_cmds:
434+
temp_code_file_path: Path | None = None
435+
try:
436+
with tempfile.NamedTemporaryFile(
437+
mode="w",
438+
suffix=".py",
439+
delete=False,
440+
encoding="utf8",
441+
dir=self.optimizer_temp_dir
442+
) as tmp_file:
443+
tmp_file.write(candidate.source_code)
444+
temp_code_file_path = Path(tmp_file.name)
445+
446+
formatted_candidate_code = format_code(
447+
formatter_cmds=self.args.formatter_cmds,
448+
path=temp_code_file_path
449+
)
450+
except Exception as e:
451+
logger.error(f"Error during formatting candidate code via temp file: {e}. Using original candidate code.")
452+
finally:
453+
if temp_code_file_path and temp_code_file_path.exists():
454+
temp_code_file_path.unlink(missing_ok=True)
455+
456+
candidate.source_code = formatted_candidate_code
457+
408458
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
409459
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True)
410460
logger.info(f"Optimization candidate {candidate_index}/{original_len}:")
@@ -580,27 +630,6 @@ def write_code_and_helpers(original_code: str, original_helper_code: dict[Path,
580630
with Path(module_abspath).open("w", encoding="utf8") as f:
581631
f.write(original_helper_code[module_abspath])
582632

583-
def reformat_code_and_helpers(
584-
self, helper_functions: list[FunctionSource], path: Path, original_code: str
585-
) -> tuple[str, dict[Path, str]]:
586-
should_sort_imports = not self.args.disable_imports_sorting
587-
if should_sort_imports and isort.code(original_code) != original_code:
588-
should_sort_imports = False
589-
590-
new_code = format_code(self.args.formatter_cmds, path)
591-
if should_sort_imports:
592-
new_code = sort_imports(new_code)
593-
594-
new_helper_code: dict[Path, str] = {}
595-
helper_functions_paths = {hf.file_path for hf in helper_functions}
596-
for module_abspath in helper_functions_paths:
597-
formatted_helper_code = format_code(self.args.formatter_cmds, module_abspath)
598-
if should_sort_imports:
599-
formatted_helper_code = sort_imports(formatted_helper_code)
600-
new_helper_code[module_abspath] = formatted_helper_code
601-
602-
return new_code, new_helper_code
603-
604633
def replace_function_and_helpers_with_optimized_code(
605634
self, code_context: CodeOptimizationContext, optimized_code: str
606635
) -> bool:

0 commit comments

Comments
 (0)