diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index db7fa4257..e3a412734 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -44,9 +44,7 @@ def apply_formatter_cmds( test_dir_str: Optional[str], print_status: bool, # noqa exit_on_failure: bool = True, # noqa -) -> tuple[Path, str]: - # TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution - formatter_name = cmds[0].lower() +) -> tuple[Path, str, bool]: should_make_copy = False file_path = path @@ -54,9 +52,6 @@ def apply_formatter_cmds( should_make_copy = True file_path = Path(test_dir_str) / "temp.py" - if not cmds or formatter_name == "disabled": - return path, path.read_text(encoding="utf8") - if not path.exists(): msg = f"File {path} does not exist. Cannot apply formatter commands." raise FileNotFoundError(msg) @@ -66,6 +61,7 @@ def apply_formatter_cmds( file_token = "$file" # noqa: S105 + changed = False for command in cmds: formatter_cmd_list = shlex.split(command, posix=os.name != "nt") formatter_cmd_list = [file_path.as_posix() if chunk == file_token else chunk for chunk in formatter_cmd_list] @@ -74,6 +70,7 @@ def apply_formatter_cmds( if result.returncode == 0: if print_status: console.rule(f"Formatted Successfully with: {command.replace('$file', path.name)}") + changed = True else: logger.error(f"Failed to format code with {' '.join(formatter_cmd_list)}") except FileNotFoundError as e: @@ -88,7 +85,7 @@ def apply_formatter_cmds( if exit_on_failure: raise e from None - return file_path, file_path.read_text(encoding="utf8") + return file_path, file_path.read_text(encoding="utf8"), changed def get_diff_lines_count(diff_output: str) -> int: @@ -112,10 +109,16 @@ def format_code( if console.quiet: # lsp mode exit_on_failure = False - with tempfile.TemporaryDirectory() as test_dir_str: - if isinstance(path, str): - path = Path(path) + if isinstance(path, str): + path = Path(path) + + # TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution + formatter_name = formatter_cmds[0].lower() if formatter_cmds else "disabled" + if formatter_name == "disabled": + return path.read_text(encoding="utf8") + + with tempfile.TemporaryDirectory() as test_dir_str: original_code = path.read_text(encoding="utf8") original_code_lines = len(original_code.split("\n")) @@ -126,10 +129,16 @@ def format_code( original_temp = Path(test_dir_str) / "original_temp.py" original_temp.write_text(original_code_without_opfunc, encoding="utf8") - formatted_temp, formatted_code = apply_formatter_cmds( - formatter_cmds, original_temp, test_dir_str, print_status=False + formatted_temp, formatted_code, changed = apply_formatter_cmds( + formatter_cmds, original_temp, test_dir_str, print_status=False, exit_on_failure=exit_on_failure ) + if not changed: + logger.warning( + f"No changes detected in {path} after formatting, are you sure you have valid formatter commands?" + ) + return original_code + diff_output = generate_unified_diff( original_code_without_opfunc, formatted_code, from_file=str(original_temp), to_file=str(formatted_temp) ) @@ -137,15 +146,22 @@ def format_code( max_diff_lines = min(int(original_code_lines * 0.3), 50) - if diff_lines_count > max_diff_lines and max_diff_lines != -1: - logger.debug( + if diff_lines_count > max_diff_lines: + logger.warning( f"Skipping formatting {path}: {diff_lines_count} lines would change (max: {max_diff_lines})" ) return original_code + # TODO : We can avoid formatting the whole file again and only formatting the optimized code standalone and replace in formatted file above. - _, formatted_code = apply_formatter_cmds( + _, formatted_code, changed = apply_formatter_cmds( formatter_cmds, path, test_dir_str=None, print_status=print_status, exit_on_failure=exit_on_failure ) + if not changed: + logger.warning( + f"No changes detected in {path} after formatting, are you sure you have valid formatter commands?" + ) + return original_code + logger.debug(f"Formatted {path} with commands: {formatter_cmds}") return formatted_code diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 320fc1d77..0253c9ac3 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -745,7 +745,9 @@ def reformat_code_and_helpers( file_to_code_context = optimized_context.file_to_path() optimized_code = file_to_code_context.get(str(path.relative_to(self.project_root)), "") - new_code = format_code(self.args.formatter_cmds, path, optimized_code=optimized_code, check_diff=True) + new_code = format_code( + self.args.formatter_cmds, path, optimized_code=optimized_code, check_diff=True, exit_on_failure=False + ) if should_sort_imports: new_code = sort_imports(new_code) @@ -754,7 +756,11 @@ def reformat_code_and_helpers( module_abspath = hp.file_path hp_source_code = hp.source_code formatted_helper_code = format_code( - self.args.formatter_cmds, module_abspath, optimized_code=hp_source_code, check_diff=True + self.args.formatter_cmds, + module_abspath, + optimized_code=hp_source_code, + check_diff=True, + exit_on_failure=False, ) if should_sort_imports: formatted_helper_code = sort_imports(formatted_helper_code) diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 5afc4630e..81a3c4f14 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -208,13 +208,15 @@ def test_formatter_error(): import sys def foo(): return os.path.join(sys.path[0], 'bar')""" - expected = original_code with tempfile.NamedTemporaryFile("w") as tmp: tmp.write(original_code) tmp.flush() tmp_path = tmp.name - with pytest.raises(FileNotFoundError): - format_code(formatter_cmds=["exit 1"], path=Path(tmp_path)) + try: + new_code = format_code(formatter_cmds=["exit 1"], path=Path(tmp_path), exit_on_failure=False) + assert new_code == original_code + except Exception as e: + assert False, f"Shouldn't throw an exception even if the formatter is not found: {e}" def _run_formatting_test(source_code: str, should_content_change: bool, expected = None, optimized_function: str = ""): @@ -570,12 +572,12 @@ def main(): def test_formatting_edge_case_exactly_100_diffs(): """Test behavior when exactly at the threshold of 100 changes.""" # Create a file with exactly 100 minor formatting issues - source_code = '''import json\n''' + ''' -def func{}(): + snippet = '''import json\n''' + ''' +def func_{i}(): x=1;y=2;z=3 return x+y+z -'''.replace('{}', '_{i}').format(i='{i}') * 33 # This creates exactly 100 potential formatting fixes - +''' + source_code = "".join([snippet.format(i=i) for i in range(100)]) _run_formatting_test(source_code, False)