|
55 | 55 | remove_functions_from_generated_tests, |
56 | 56 | ) |
57 | 57 | from codeflash.code_utils.env_utils import get_pr_number |
58 | | -from codeflash.code_utils.formatter import format_code, sort_imports |
| 58 | +from codeflash.code_utils.formatter import format_code, format_generated_code, sort_imports |
59 | 59 | from codeflash.code_utils.git_utils import git_root_dir |
60 | 60 | from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test |
61 | 61 | from codeflash.code_utils.line_profile_utils import add_decorator_imports |
@@ -1417,11 +1417,15 @@ def process_review( |
1417 | 1417 | generated_tests_str = "" |
1418 | 1418 | for test in generated_tests.generated_tests: |
1419 | 1419 | if map_gen_test_file_to_no_of_tests[test.behavior_file_path] > 0: |
1420 | | - generated_tests_str += f"```python\n{test.generated_original_test_source}\n```" |
| 1420 | + formatted_generated_test = format_generated_code( |
| 1421 | + test.generated_original_test_source, self.args.formatter_cmds |
| 1422 | + ) |
| 1423 | + generated_tests_str += f"```python\n{formatted_generated_test}\n```" |
1421 | 1424 | generated_tests_str += "\n\n" |
1422 | 1425 |
|
1423 | 1426 | if concolic_test_str: |
1424 | | - generated_tests_str += f"```python\n{concolic_test_str}\n```\n\n" |
| 1427 | + formatted_generated_test = format_generated_code(concolic_test_str, self.args.formatter_cmds) |
| 1428 | + generated_tests_str += f"```python\n{formatted_generated_test}\n```\n\n" |
1425 | 1429 |
|
1426 | 1430 | existing_tests, replay_tests, concolic_tests = existing_tests_source_for( |
1427 | 1431 | self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root), |
@@ -1569,8 +1573,7 @@ def establish_original_code_baseline( |
1569 | 1573 | ) -> Result[tuple[OriginalCodeBaseline, list[str]], str]: |
1570 | 1574 | line_profile_results = {"timings": {}, "unit": 0, "str_out": ""} |
1571 | 1575 | # For the original function - run the tests and get the runtime, plus coverage |
1572 | | - test_framework = self.args.test_framework |
1573 | | - assert test_framework in {"pytest", "unittest"} |
| 1576 | + assert (test_framework := self.args.test_framework) in {"pytest", "unittest"} # noqa: RUF018 |
1574 | 1577 | success = True |
1575 | 1578 |
|
1576 | 1579 | test_env = self.get_test_env(codeflash_loop_index=0, codeflash_test_iteration=0, codeflash_tracer_disable=1) |
@@ -1748,8 +1751,7 @@ def run_optimized_candidate( |
1748 | 1751 | original_helper_code: dict[Path, str], |
1749 | 1752 | file_path_to_helper_classes: dict[Path, set[str]], |
1750 | 1753 | ) -> Result[OptimizedCandidateResult, str]: |
1751 | | - test_framework = self.args.test_framework |
1752 | | - assert test_framework in {"pytest", "unittest"} |
| 1754 | + assert (test_framework := self.args.test_framework) in {"pytest", "unittest"} # noqa: RUF018 |
1753 | 1755 |
|
1754 | 1756 | with progress_bar("Testing optimization candidate"): |
1755 | 1757 | test_env = self.get_test_env( |
|
0 commit comments