|
5 | 5 | import os |
6 | 6 | import shutil |
7 | 7 | import subprocess |
| 8 | +import tempfile |
8 | 9 | import time |
9 | 10 | import uuid |
10 | 11 | from collections import defaultdict, deque |
@@ -124,6 +125,7 @@ def __init__( |
124 | 125 | self.function_benchmark_timings = function_benchmark_timings if function_benchmark_timings else {} |
125 | 126 | self.total_benchmark_timings = total_benchmark_timings if total_benchmark_timings else {} |
126 | 127 | self.replay_tests_dir = replay_tests_dir if replay_tests_dir else None |
| 128 | + self.optimizer_temp_dir = Path(tempfile.mkdtemp(prefix="codeflash_opt_fmt_")) |
127 | 129 |
|
128 | 130 | def optimize_function(self) -> Result[BestOptimization, str]: |
129 | 131 | should_run_experiment = self.experiment_id is not None |
@@ -301,9 +303,30 @@ def optimize_function(self) -> Result[BestOptimization, str]: |
301 | 303 | code_context=code_context, optimized_code=best_optimization.candidate.source_code |
302 | 304 | ) |
303 | 305 |
|
304 | | - new_code, new_helper_code = self.reformat_code_and_helpers( |
305 | | - code_context.helper_functions, explanation.file_path, self.function_to_optimize_source_code |
306 | | - ) |
| 306 | + if not self.args.disable_imports_sorting: |
| 307 | + main_file_path = self.function_to_optimize.file_path |
| 308 | + if main_file_path.exists(): |
| 309 | + current_main_content = main_file_path.read_text(encoding="utf8") |
| 310 | + sorted_main_content = sort_imports(current_main_content) |
| 311 | + if sorted_main_content != current_main_content: |
| 312 | + main_file_path.write_text(sorted_main_content, encoding="utf8") |
| 313 | + |
| 314 | + writable_helper_file_paths = {hf.file_path for hf in code_context.helper_functions} |
| 315 | + for helper_file_path in writable_helper_file_paths: |
| 316 | + if helper_file_path.exists(): |
| 317 | + current_helper_content = helper_file_path.read_text(encoding="utf8") |
| 318 | + sorted_helper_content = sort_imports(current_helper_content) |
| 319 | + if sorted_helper_content != current_helper_content: |
| 320 | + helper_file_path.write_text(sorted_helper_content, encoding="utf8") |
| 321 | + |
| 322 | + new_code = self.function_to_optimize.file_path.read_text(encoding="utf8") |
| 323 | + new_helper_code: dict[Path, str] = {} |
| 324 | + for helper_file_path_key in original_helper_code: |
| 325 | + if helper_file_path_key.exists(): |
| 326 | + new_helper_code[helper_file_path_key] = helper_file_path_key.read_text(encoding="utf8") |
| 327 | + else: |
| 328 | + logger.warning(f"Helper file {helper_file_path_key} not found after optimization. It will not be included in new_helper_code for PR.") |
| 329 | + |
307 | 330 |
|
308 | 331 | existing_tests = existing_tests_source_for( |
309 | 332 | self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root), |
@@ -405,6 +428,33 @@ def determine_best_candidate( |
405 | 428 | future_line_profile_results = None |
406 | 429 | candidate_index += 1 |
407 | 430 | candidate = candidates.popleft() |
| 431 | + |
| 432 | + formatted_candidate_code = candidate.source_code |
| 433 | + if self.args.formatter_cmds: |
| 434 | + temp_code_file_path: Path | None = None |
| 435 | + try: |
| 436 | + with tempfile.NamedTemporaryFile( |
| 437 | + mode="w", |
| 438 | + suffix=".py", |
| 439 | + delete=False, |
| 440 | + encoding="utf8", |
| 441 | + dir=self.optimizer_temp_dir |
| 442 | + ) as tmp_file: |
| 443 | + tmp_file.write(candidate.source_code) |
| 444 | + temp_code_file_path = Path(tmp_file.name) |
| 445 | + |
| 446 | + formatted_candidate_code = format_code( |
| 447 | + formatter_cmds=self.args.formatter_cmds, |
| 448 | + path=temp_code_file_path |
| 449 | + ) |
| 450 | + except Exception as e: |
| 451 | + logger.error(f"Error during formatting candidate code via temp file: {e}. Using original candidate code.") |
| 452 | + finally: |
| 453 | + if temp_code_file_path and temp_code_file_path.exists(): |
| 454 | + temp_code_file_path.unlink(missing_ok=True) |
| 455 | + |
| 456 | + candidate.source_code = formatted_candidate_code |
| 457 | + |
408 | 458 | get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True) |
409 | 459 | get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True) |
410 | 460 | logger.info(f"Optimization candidate {candidate_index}/{original_len}:") |
@@ -580,27 +630,6 @@ def write_code_and_helpers(original_code: str, original_helper_code: dict[Path, |
580 | 630 | with Path(module_abspath).open("w", encoding="utf8") as f: |
581 | 631 | f.write(original_helper_code[module_abspath]) |
582 | 632 |
|
583 | | - def reformat_code_and_helpers( |
584 | | - self, helper_functions: list[FunctionSource], path: Path, original_code: str |
585 | | - ) -> tuple[str, dict[Path, str]]: |
586 | | - should_sort_imports = not self.args.disable_imports_sorting |
587 | | - if should_sort_imports and isort.code(original_code) != original_code: |
588 | | - should_sort_imports = False |
589 | | - |
590 | | - new_code = format_code(self.args.formatter_cmds, path) |
591 | | - if should_sort_imports: |
592 | | - new_code = sort_imports(new_code) |
593 | | - |
594 | | - new_helper_code: dict[Path, str] = {} |
595 | | - helper_functions_paths = {hf.file_path for hf in helper_functions} |
596 | | - for module_abspath in helper_functions_paths: |
597 | | - formatted_helper_code = format_code(self.args.formatter_cmds, module_abspath) |
598 | | - if should_sort_imports: |
599 | | - formatted_helper_code = sort_imports(formatted_helper_code) |
600 | | - new_helper_code[module_abspath] = formatted_helper_code |
601 | | - |
602 | | - return new_code, new_helper_code |
603 | | - |
604 | 633 | def replace_function_and_helpers_with_optimized_code( |
605 | 634 | self, code_context: CodeOptimizationContext, optimized_code: str |
606 | 635 | ) -> bool: |
|
0 commit comments