Skip to content

Commit 3f6e6c9

Browse files
make formatter optional in the optimizer only to not break other logic & fix unit tests
1 parent ecac80d commit 3f6e6c9

File tree

3 files changed

+41
-36
lines changed

3 files changed

+41
-36
lines changed

codeflash/code_utils/formatter.py

Lines changed: 24 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def apply_formatter_cmds(
4444
test_dir_str: Optional[str],
4545
print_status: bool, # noqa
4646
exit_on_failure: bool = True, # noqa
47-
) -> tuple[Path, str]:
47+
) -> tuple[Path, str, bool]:
4848
# TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution
4949
formatter_name = cmds[0].lower()
5050
should_make_copy = False
@@ -55,7 +55,7 @@ def apply_formatter_cmds(
5555
file_path = Path(test_dir_str) / "temp.py"
5656

5757
if not cmds or formatter_name == "disabled":
58-
return path, path.read_text(encoding="utf8")
58+
return path, path.read_text(encoding="utf8"), False
5959

6060
if not path.exists():
6161
msg = f"File {path} does not exist. Cannot apply formatter commands."
@@ -66,6 +66,7 @@ def apply_formatter_cmds(
6666

6767
file_token = "$file" # noqa: S105
6868

69+
changed = False
6970
for command in cmds:
7071
formatter_cmd_list = shlex.split(command, posix=os.name != "nt")
7172
formatter_cmd_list = [file_path.as_posix() if chunk == file_token else chunk for chunk in formatter_cmd_list]
@@ -74,6 +75,7 @@ def apply_formatter_cmds(
7475
if result.returncode == 0:
7576
if print_status:
7677
console.rule(f"Formatted Successfully with: {command.replace('$file', path.name)}")
78+
changed = True
7779
else:
7880
logger.error(f"Failed to format code with {' '.join(formatter_cmd_list)}")
7981
except FileNotFoundError as e:
@@ -85,13 +87,10 @@ def apply_formatter_cmds(
8587
expand=False,
8688
)
8789
console.print(panel)
88-
logger.warning(
89-
f"Formatter command not found: {' '.join(formatter_cmd_list)}, continuing without formatting"
90-
)
9190
if exit_on_failure:
9291
raise e from None
9392

94-
return file_path, file_path.read_text(encoding="utf8")
93+
return file_path, file_path.read_text(encoding="utf8"), changed
9594

9695

9796
def get_diff_lines_count(diff_output: str) -> int:
@@ -129,34 +128,32 @@ def format_code(
129128
original_temp = Path(test_dir_str) / "original_temp.py"
130129
original_temp.write_text(original_code_without_opfunc, encoding="utf8")
131130

132-
try:
133-
formatted_temp, formatted_code = apply_formatter_cmds(
134-
formatter_cmds, original_temp, test_dir_str, print_status=False
131+
formatted_temp, formatted_code, changed = apply_formatter_cmds(
132+
formatter_cmds, original_temp, test_dir_str, print_status=False, exit_on_failure=exit_on_failure
133+
)
134+
135+
if not changed:
136+
logger.warning(
137+
f"No changes detected in {path} after formatting, are you sure you have valid formatter commands?"
135138
)
139+
return original_code
140+
141+
diff_output = generate_unified_diff(
142+
original_code_without_opfunc, formatted_code, from_file=str(original_temp), to_file=str(formatted_temp)
143+
)
144+
diff_lines_count = get_diff_lines_count(diff_output)
145+
146+
max_diff_lines = min(int(original_code_lines * 0.3), 50)
136147

137-
diff_output = generate_unified_diff(
138-
original_code_without_opfunc,
139-
formatted_code,
140-
from_file=str(original_temp),
141-
to_file=str(formatted_temp),
148+
if diff_lines_count > max_diff_lines:
149+
logger.warning(
150+
f"Skipping formatting {path}: {diff_lines_count} lines would change (max: {max_diff_lines})"
142151
)
143-
diff_lines_count = get_diff_lines_count(diff_output)
144-
145-
max_diff_lines = min(int(original_code_lines * 0.3), 50)
146-
147-
if diff_lines_count > max_diff_lines:
148-
logger.debug(
149-
f"Skipping formatting {path}: {diff_lines_count} lines would change (max: {max_diff_lines})"
150-
)
151-
return original_code
152-
except FileNotFoundError as e:
153-
logger.warning(f"Formatter not found, skipping diff check: {e}")
154-
# Continue without formatting checks
155152
return original_code
156153

157154
# TODO : We can avoid formatting the whole file again and only formatting the optimized code standalone and replace in formatted file above.
158155
try:
159-
_, formatted_code = apply_formatter_cmds(
156+
_, formatted_code, _ = apply_formatter_cmds(
160157
formatter_cmds, path, test_dir_str=None, print_status=print_status, exit_on_failure=exit_on_failure
161158
)
162159
except FileNotFoundError as e:

codeflash/optimization/function_optimizer.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -745,7 +745,9 @@ def reformat_code_and_helpers(
745745
file_to_code_context = optimized_context.file_to_path()
746746
optimized_code = file_to_code_context.get(str(path.relative_to(self.project_root)), "")
747747

748-
new_code = format_code(self.args.formatter_cmds, path, optimized_code=optimized_code, check_diff=True)
748+
new_code = format_code(
749+
self.args.formatter_cmds, path, optimized_code=optimized_code, check_diff=True, exit_on_failure=False
750+
)
749751
if should_sort_imports:
750752
new_code = sort_imports(new_code)
751753

@@ -754,7 +756,11 @@ def reformat_code_and_helpers(
754756
module_abspath = hp.file_path
755757
hp_source_code = hp.source_code
756758
formatted_helper_code = format_code(
757-
self.args.formatter_cmds, module_abspath, optimized_code=hp_source_code, check_diff=True
759+
self.args.formatter_cmds,
760+
module_abspath,
761+
optimized_code=hp_source_code,
762+
check_diff=True,
763+
exit_on_failure=False,
758764
)
759765
if should_sort_imports:
760766
formatted_helper_code = sort_imports(formatted_helper_code)

tests/test_formatter.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -208,13 +208,15 @@ def test_formatter_error():
208208
import sys
209209
def foo():
210210
return os.path.join(sys.path[0], 'bar')"""
211-
expected = original_code
212211
with tempfile.NamedTemporaryFile("w") as tmp:
213212
tmp.write(original_code)
214213
tmp.flush()
215214
tmp_path = tmp.name
216-
with pytest.raises(FileNotFoundError):
217-
format_code(formatter_cmds=["exit 1"], path=Path(tmp_path))
215+
try:
216+
new_code = format_code(formatter_cmds=["exit 1"], path=Path(tmp_path), exit_on_failure=False)
217+
assert new_code == original_code
218+
except Exception as e:
219+
assert False, f"Shouldn't throw an exception even if the formatter is not found: {e}"
218220

219221

220222
def _run_formatting_test(source_code: str, should_content_change: bool, expected = None, optimized_function: str = ""):
@@ -570,12 +572,12 @@ def main():
570572
def test_formatting_edge_case_exactly_100_diffs():
571573
"""Test behavior when exactly at the threshold of 100 changes."""
572574
# Create a file with exactly 100 minor formatting issues
573-
source_code = '''import json\n''' + '''
574-
def func{}():
575+
snippet = '''import json\n''' + '''
576+
def func_{i}():
575577
x=1;y=2;z=3
576578
return x+y+z
577-
'''.replace('{}', '_{i}').format(i='{i}') * 33 # This creates exactly 100 potential formatting fixes
578-
579+
'''
580+
source_code = "".join([snippet.format(i=i) for i in range(100)])
579581
_run_formatting_test(source_code, False)
580582

581583

0 commit comments

Comments
 (0)