Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 28 additions & 12 deletions codeflash/code_utils/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,14 @@ def apply_formatter_cmds(
test_dir_str: Optional[str],
print_status: bool, # noqa
exit_on_failure: bool = True, # noqa
) -> tuple[Path, str]:
# TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution
formatter_name = cmds[0].lower()
) -> tuple[Path, str, bool]:
should_make_copy = False
file_path = path

if test_dir_str:
should_make_copy = True
file_path = Path(test_dir_str) / "temp.py"

if not cmds or formatter_name == "disabled":
return path, path.read_text(encoding="utf8")

if not path.exists():
msg = f"File {path} does not exist. Cannot apply formatter commands."
raise FileNotFoundError(msg)
Expand All @@ -66,6 +61,7 @@ def apply_formatter_cmds(

file_token = "$file" # noqa: S105

changed = False
for command in cmds:
formatter_cmd_list = shlex.split(command, posix=os.name != "nt")
formatter_cmd_list = [file_path.as_posix() if chunk == file_token else chunk for chunk in formatter_cmd_list]
Expand All @@ -74,6 +70,7 @@ def apply_formatter_cmds(
if result.returncode == 0:
if print_status:
console.rule(f"Formatted Successfully with: {command.replace('$file', path.name)}")
changed = True
else:
logger.error(f"Failed to format code with {' '.join(formatter_cmd_list)}")
except FileNotFoundError as e:
Expand All @@ -88,7 +85,7 @@ def apply_formatter_cmds(
if exit_on_failure:
raise e from None

return file_path, file_path.read_text(encoding="utf8")
return file_path, file_path.read_text(encoding="utf8"), changed


def get_diff_lines_count(diff_output: str) -> int:
Expand All @@ -112,6 +109,12 @@ def format_code(
if console.quiet:
# lsp mode
exit_on_failure = False

# TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution
formatter_name = formatter_cmds[0].lower() if formatter_cmds else "disabled"
if formatter_name == "disabled":
return path.read_text(encoding="utf8")

with tempfile.TemporaryDirectory() as test_dir_str:
if isinstance(path, str):
path = Path(path)
Expand All @@ -126,26 +129,39 @@ def format_code(
original_temp = Path(test_dir_str) / "original_temp.py"
original_temp.write_text(original_code_without_opfunc, encoding="utf8")

formatted_temp, formatted_code = apply_formatter_cmds(
formatter_cmds, original_temp, test_dir_str, print_status=False
formatted_temp, formatted_code, changed = apply_formatter_cmds(
formatter_cmds, original_temp, test_dir_str, print_status=False, exit_on_failure=exit_on_failure
)

if not changed:
logger.warning(
f"No changes detected in {path} after formatting, are you sure you have valid formatter commands?"
)
return original_code

diff_output = generate_unified_diff(
original_code_without_opfunc, formatted_code, from_file=str(original_temp), to_file=str(formatted_temp)
)
diff_lines_count = get_diff_lines_count(diff_output)

max_diff_lines = min(int(original_code_lines * 0.3), 50)

if diff_lines_count > max_diff_lines and max_diff_lines != -1:
logger.debug(
if diff_lines_count > max_diff_lines:
logger.warning(
f"Skipping formatting {path}: {diff_lines_count} lines would change (max: {max_diff_lines})"
)
return original_code

# TODO : We can avoid formatting the whole file again and only formatting the optimized code standalone and replace in formatted file above.
_, formatted_code = apply_formatter_cmds(
_, formatted_code, changed = apply_formatter_cmds(
formatter_cmds, path, test_dir_str=None, print_status=print_status, exit_on_failure=exit_on_failure
)
if not changed:
logger.warning(
f"No changes detected in {path} after formatting, are you sure you have valid formatter commands?"
)
return original_code

logger.debug(f"Formatted {path} with commands: {formatter_cmds}")
return formatted_code

Expand Down
10 changes: 8 additions & 2 deletions codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,7 +745,9 @@ def reformat_code_and_helpers(
file_to_code_context = optimized_context.file_to_path()
optimized_code = file_to_code_context.get(str(path.relative_to(self.project_root)), "")

new_code = format_code(self.args.formatter_cmds, path, optimized_code=optimized_code, check_diff=True)
new_code = format_code(
self.args.formatter_cmds, path, optimized_code=optimized_code, check_diff=True, exit_on_failure=False
)
if should_sort_imports:
new_code = sort_imports(new_code)

Expand All @@ -754,7 +756,11 @@ def reformat_code_and_helpers(
module_abspath = hp.file_path
hp_source_code = hp.source_code
formatted_helper_code = format_code(
self.args.formatter_cmds, module_abspath, optimized_code=hp_source_code, check_diff=True
self.args.formatter_cmds,
module_abspath,
optimized_code=hp_source_code,
check_diff=True,
exit_on_failure=False,
)
if should_sort_imports:
formatted_helper_code = sort_imports(formatted_helper_code)
Expand Down
16 changes: 9 additions & 7 deletions tests/test_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,13 +208,15 @@ def test_formatter_error():
import sys
def foo():
return os.path.join(sys.path[0], 'bar')"""
expected = original_code
with tempfile.NamedTemporaryFile("w") as tmp:
tmp.write(original_code)
tmp.flush()
tmp_path = tmp.name
with pytest.raises(FileNotFoundError):
format_code(formatter_cmds=["exit 1"], path=Path(tmp_path))
try:
new_code = format_code(formatter_cmds=["exit 1"], path=Path(tmp_path), exit_on_failure=False)
assert new_code == original_code
except Exception as e:
assert False, f"Shouldn't throw an exception even if the formatter is not found: {e}"


def _run_formatting_test(source_code: str, should_content_change: bool, expected = None, optimized_function: str = ""):
Expand Down Expand Up @@ -570,12 +572,12 @@ def main():
def test_formatting_edge_case_exactly_100_diffs():
"""Test behavior when exactly at the threshold of 100 changes."""
# Create a file with exactly 100 minor formatting issues
source_code = '''import json\n''' + '''
def func{}():
snippet = '''import json\n''' + '''
def func_{i}():
x=1;y=2;z=3
return x+y+z
'''.replace('{}', '_{i}').format(i='{i}') * 33 # This creates exactly 100 potential formatting fixes

'''
source_code = "".join([snippet.format(i=i) for i in range(100)])
_run_formatting_test(source_code, False)


Expand Down
Loading