Skip to content

Commit d904c23

Browse files
aseembits93Codeflash Bot
authored andcommitted
Merge remote-tracking branch 'origin/main' into cf-835
2 parents 4e8187c + 4a6eaab commit d904c23

File tree

3 files changed

+637
-9
lines changed

3 files changed

+637
-9
lines changed

codeflash/code_utils/formatter.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,19 @@ def is_diff_line(line: str) -> bool:
9696
return len(diff_lines)
9797

9898

99+
def format_generated_code(generated_test_source: str, formatter_cmds: list[str]) -> str:
100+
with tempfile.TemporaryDirectory() as test_dir_str:
101+
# try running formatter, if nothing changes (could be due to formatting failing or no actual formatting needed) return code with 2 or more newlines substituted with 2 newlines
102+
original_temp = Path(test_dir_str) / "original_temp.py"
103+
original_temp.write_text(generated_test_source, encoding="utf8")
104+
_, formatted_code, changed = apply_formatter_cmds(
105+
formatter_cmds, original_temp, test_dir_str, print_status=False, exit_on_failure=False
106+
)
107+
if not changed:
108+
return re.sub(r"\n{2,}", "\n\n", formatted_code)
109+
return formatted_code
110+
111+
99112
def format_code(
100113
formatter_cmds: list[str],
101114
path: Union[str, Path],
@@ -120,7 +133,7 @@ def format_code(
120133
original_code_lines = len(original_code.split("\n"))
121134

122135
if check_diff and original_code_lines > 50:
123-
# we dont' count the formatting diff for the optimized function as it should be well-formatted
136+
# we don't count the formatting diff for the optimized function as it should be well-formatted
124137
original_code_without_opfunc = original_code.replace(optimized_code, "")
125138

126139
original_temp = Path(test_dir_str) / "original_temp.py"

codeflash/optimization/function_optimizer.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
remove_functions_from_generated_tests,
5656
)
5757
from codeflash.code_utils.env_utils import get_pr_number
58-
from codeflash.code_utils.formatter import format_code, sort_imports
58+
from codeflash.code_utils.formatter import format_code, format_generated_code, sort_imports
5959
from codeflash.code_utils.git_utils import git_root_dir
6060
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
6161
from codeflash.code_utils.line_profile_utils import add_decorator_imports
@@ -1417,11 +1417,15 @@ def process_review(
14171417
generated_tests_str = ""
14181418
for test in generated_tests.generated_tests:
14191419
if map_gen_test_file_to_no_of_tests[test.behavior_file_path] > 0:
1420-
generated_tests_str += f"```python\n{test.generated_original_test_source}\n```"
1420+
formatted_generated_test = format_generated_code(
1421+
test.generated_original_test_source, self.args.formatter_cmds
1422+
)
1423+
generated_tests_str += f"```python\n{formatted_generated_test}\n```"
14211424
generated_tests_str += "\n\n"
14221425

14231426
if concolic_test_str:
1424-
generated_tests_str += f"```python\n{concolic_test_str}\n```\n\n"
1427+
formatted_generated_test = format_generated_code(concolic_test_str, self.args.formatter_cmds)
1428+
generated_tests_str += f"```python\n{formatted_generated_test}\n```\n\n"
14251429

14261430
existing_tests, replay_tests, concolic_tests = existing_tests_source_for(
14271431
self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root),
@@ -1569,8 +1573,7 @@ def establish_original_code_baseline(
15691573
) -> Result[tuple[OriginalCodeBaseline, list[str]], str]:
15701574
line_profile_results = {"timings": {}, "unit": 0, "str_out": ""}
15711575
# For the original function - run the tests and get the runtime, plus coverage
1572-
test_framework = self.args.test_framework
1573-
assert test_framework in {"pytest", "unittest"}
1576+
assert (test_framework := self.args.test_framework) in {"pytest", "unittest"} # noqa: RUF018
15741577
success = True
15751578

15761579
test_env = self.get_test_env(codeflash_loop_index=0, codeflash_test_iteration=0, codeflash_tracer_disable=1)
@@ -1748,8 +1751,7 @@ def run_optimized_candidate(
17481751
original_helper_code: dict[Path, str],
17491752
file_path_to_helper_classes: dict[Path, set[str]],
17501753
) -> Result[OptimizedCandidateResult, str]:
1751-
test_framework = self.args.test_framework
1752-
assert test_framework in {"pytest", "unittest"}
1754+
assert (test_framework := self.args.test_framework) in {"pytest", "unittest"} # noqa: RUF018
17531755

17541756
with progress_bar("Testing optimization candidate"):
17551757
test_env = self.get_test_env(

0 commit comments

Comments
 (0)