diff --git a/code_to_optimize/bad_formatting.py b/code_to_optimize/bad_formatting.py new file mode 100644 index 000000000..00e3c5070 --- /dev/null +++ b/code_to_optimize/bad_formatting.py @@ -0,0 +1,43 @@ +import sys + + +def lol(): + print( "lol" ) + + + + + + + + + +class BubbleSorter: + def __init__(self, x=0): + self.x = x + + def lol(self): + print( "lol" ) + + + + + + + + + def sorter (self, arr): + + + print ("codeflash stdout : BubbleSorter.sorter() called") + n = len(arr) + for i in range(n): + swapped = False + for j in range(0, n - i - 1): + if arr[j] > arr[j + 1]: + arr[j], arr[j + 1] = arr[j + 1], arr[j] # Faster swap + swapped = True + if not swapped: + break + print ("stderr test", file=sys.stderr) + return arr \ No newline at end of file diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 3f0d72bcd..37bd298cb 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -80,6 +80,29 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> None: ending_line=pos.end.line, ) ) +class CodeRangeFunctionVisitor(cst.CSTVisitor): + METADATA_DEPENDENCIES = ( + cst.metadata.PositionProvider, + cst.metadata.QualifiedNameProvider, + ) + + def __init__(self, target_function_name: str) -> None: + super().__init__() + self.target_func = target_function_name + self.start_line: Optional[int] = None + self.end_line: Optional[int] = None + + def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: + qualified_names = [ + str(qn.name).replace(".", "") for qn in + self.get_metadata(cst.metadata.QualifiedNameProvider, node) + ] + if self.target_func in qualified_names: + func_position = self.get_metadata(cst.metadata.PositionProvider, node) + decorators_count = len(node.decorators) + self.start_line = func_position.start.line - decorators_count + self.end_line = func_position.end.line + return False class FunctionWithReturnStatement(ast.NodeVisitor): diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index fe4357839..1d7448400 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -9,7 +9,7 @@ from collections import defaultdict, deque from pathlib import Path from typing import TYPE_CHECKING - +import tempfile import isort import libcst as cst from rich.console import Group @@ -72,6 +72,8 @@ from codeflash.verification.verification_utils import get_test_file_path from codeflash.verification.verifier import generate_tests +from codeflash.discovery.functions_to_optimize import CodeRangeFunctionVisitor + if TYPE_CHECKING: from argparse import Namespace @@ -301,7 +303,8 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 ) new_code, new_helper_code = self.reformat_code_and_helpers( - code_context.helper_functions, explanation.file_path, self.function_to_optimize_source_code + code_context.helper_functions, explanation.file_path, self.function_to_optimize_source_code, + opt_func_name=explanation.function_name ) existing_tests = existing_tests_source_for( @@ -590,25 +593,66 @@ def write_code_and_helpers(original_code: str, original_helper_code: dict[Path, f.write(helper_code) def reformat_code_and_helpers( - self, helper_functions: list[FunctionSource], path: Path, original_code: str + self, helper_functions: list[FunctionSource], + path: Path, + original_code: str, + opt_func_name: str ) -> tuple[str, dict[Path, str]]: should_sort_imports = not self.args.disable_imports_sorting if should_sort_imports and isort.code(original_code) != original_code: should_sort_imports = False - new_code = format_code(self.args.formatter_cmds, path) - if should_sort_imports: - new_code = sort_imports(new_code) - - new_helper_code: dict[Path, str] = {} - helper_functions_paths = {hf.file_path for hf in helper_functions} - for module_abspath in helper_functions_paths: - formatted_helper_code = format_code(self.args.formatter_cmds, module_abspath) - if should_sort_imports: - formatted_helper_code = sort_imports(formatted_helper_code) - new_helper_code[module_abspath] = formatted_helper_code - - return new_code, new_helper_code + whole_file_content = path.read_text(encoding="utf8") + wrapper: cst.metadata.MetadataWrapper | None = None + try: + wrapper = cst.metadata.MetadataWrapper(cst.parse_module(whole_file_content)) + except cst.ParserSyntaxError as e: + logger.error(f"Syntax error detected, aborting reformatting.") + return original_code, {} + + visitor = CodeRangeFunctionVisitor(target_function_name=opt_func_name) + wrapper.visit(visitor) + + lines = whole_file_content.splitlines(keepends=True) + if visitor.start_line == None: + logger.error(f"Could not find function {opt_func_name} in {path}, aborting reformatting.") + return original_code, {} + else: + opt_func_source_lines = lines[visitor.start_line-1:visitor.end_line] + + # fix opt func identation + first_line = opt_func_source_lines[0] + first_line_indent = len(first_line) - len(first_line.lstrip()) # number of spaces before the first character + opt_func_source_lines[0] = opt_func_source_lines[0][first_line_indent:] # remove first line ident, so when we save the function code into a temp file, we don't get syntax errors + + with tempfile.NamedTemporaryFile(mode='w+', delete=True) as f: + f.write("".join(opt_func_source_lines)) + f.flush() + tmp_file = Path(f.name) + formatted_func = format_code(self.args.formatter_cmds, tmp_file) + # apply the identation back to all lines of the formatted function + formatted_lines = formatted_func.splitlines(keepends=True) + for i in range(len(formatted_lines)): + formatted_lines[i] = (" " * first_line_indent) + formatted_lines[i] + + # replace the unformatted code with formatted ones + new_code = ( + "".join(lines[:visitor.start_line-1]) + + "".join(formatted_lines) + + "".join(lines[visitor.end_line:]) + ) + if should_sort_imports: + new_code = sort_imports(new_code) + + new_helper_code: dict[Path, str] = {} + helper_functions_paths = {hf.file_path for hf in helper_functions} + for module_abspath in helper_functions_paths: + formatted_helper_code = format_code(self.args.formatter_cmds, module_abspath) + if should_sort_imports: + formatted_helper_code = sort_imports(formatted_helper_code) + new_helper_code[module_abspath] = formatted_helper_code + + return new_code, new_helper_code def replace_function_and_helpers_with_optimized_code( self, code_context: CodeOptimizationContext, optimized_code: str diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 5c0a91c38..261ca4233 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -1,3 +1,4 @@ +import argparse import os import tempfile from pathlib import Path @@ -7,6 +8,9 @@ from codeflash.code_utils.config_parser import parse_config_file from codeflash.code_utils.formatter import format_code, sort_imports +from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash.optimization.function_optimizer import FunctionOptimizer +from codeflash.verification.verification_utils import TestConfig def test_remove_duplicate_imports(): """Test that duplicate imports are removed when should_sort_imports is True.""" @@ -209,3 +213,255 @@ def foo(): tmp_path = tmp.name with pytest.raises(FileNotFoundError): format_code(formatter_cmds=["exit 1"], path=Path(tmp_path)) + +############################################################ +################ CST based formatting tests ################ +############################################################ +@pytest.fixture +def setup_cst_formatter_args(): + """Common setup for reformat_code_and_helpers tests.""" + def _setup(unformatted_code, function_name): + test_dir = Path(tempfile.mkdtemp()) + target_path = test_dir / "target.py" + target_path.write_text(unformatted_code, encoding="utf-8") + + function_to_optimize = FunctionToOptimize( + function_name=function_name, + parents=[], + file_path=target_path + ) + + test_cfg = TestConfig( + tests_root=test_dir, + project_root_path=test_dir, + test_framework="pytest", + tests_project_rootdir=test_dir, + ) + + args = argparse.Namespace( + disable_imports_sorting=False, + formatter_cmds=[ + "ruff check --exit-zero --fix $file", + "ruff format $file" + ], + ) + + optimizer = FunctionOptimizer( + function_to_optimize=function_to_optimize, + test_cfg=test_cfg, + args=args, + ) + + return optimizer, target_path, function_to_optimize + + yield _setup + + +def test_reformat_code_and_helpers(setup_cst_formatter_args): + """ + reformat_code_and_helpers should only format the code that is optimized not the whole file, to avoid large diffing + """ + unformatted_code = """import sys + + +def lol(): + print( "lol" ) + + + + +class MyClass: + def __init__(self, x=0): + self.x = x + + def lol(self): + print( "lol" ) + + def lol2 (self): + print( " lol2" )""" + + expected_code = """import sys + + +def lol(): + print( "lol" ) + + + + +class MyClass: + def __init__(self, x=0): + self.x = x + + def lol(self): + print( "lol" ) + + def lol2(self): + print(" lol2") +""" + + optimizer, target_path, function_to_optimize = setup_cst_formatter_args( + unformatted_code, "MyClass.lol2" + ) + + formatted_code, _ = optimizer.reformat_code_and_helpers( + helper_functions=[], + path=target_path, + original_code=optimizer.function_to_optimize_source_code, + opt_func_name=function_to_optimize.function_name + ) + + assert formatted_code == expected_code + + +def test_reformat_code_and_helpers_with_duplicated_target_function_names(setup_cst_formatter_args): + unformatted_code = """import sys +def lol(): + print( "lol" ) + +class MyClass: + def __init__(self, x=0): + self.x = x + + def lol(self): + print( "lol" )""" + + expected_code = """import sys +def lol(): + print( "lol" ) + +class MyClass: + def __init__(self, x=0): + self.x = x + + def lol(self): + print("lol") +""" + + optimizer, target_path, function_to_optimize = setup_cst_formatter_args( + unformatted_code, "MyClass.lol" + ) + + formatted_code, _ = optimizer.reformat_code_and_helpers( + helper_functions=[], + path=target_path, + original_code=optimizer.function_to_optimize_source_code, + opt_func_name=function_to_optimize.function_name + ) + + assert formatted_code == expected_code + + + +def test_formatting_nested_functions(setup_cst_formatter_args): + unformatted_code = """def hello(): + print("Hello") + def nested_function() : + print ("This is a nested function") + def another_nested_function(): + print ("This is another nested function")""" + + expected_code = """def hello(): + print("Hello") + def nested_function(): + print("This is a nested function") + def another_nested_function(): + print ("This is another nested function")""" + + optimizer, target_path, function_to_optimize = setup_cst_formatter_args( + unformatted_code, "hello.nested_function" + ) + + formatted_code, _ = optimizer.reformat_code_and_helpers( + helper_functions=[], + path=target_path, + original_code=optimizer.function_to_optimize_source_code, + opt_func_name=function_to_optimize.function_name + ) + + assert formatted_code == expected_code + + +def test_formatting_standalone_functions(setup_cst_formatter_args): + unformatted_code = """def func1 (): + print( "This is a function with bad formatting") +def func2() : + print ( "This is another function with bad formatting" ) +""" + + expected_code = """def func1 (): + print( "This is a function with bad formatting") +def func2(): + print("This is another function with bad formatting") +""" + + optimizer, target_path, function_to_optimize = setup_cst_formatter_args( + unformatted_code, "func2" + ) + + formatted_code, _ = optimizer.reformat_code_and_helpers( + helper_functions=[], + path=target_path, + original_code=optimizer.function_to_optimize_source_code, + opt_func_name=function_to_optimize.function_name + ) + + assert formatted_code == expected_code + + +def test_formatting_function_with_decorators(setup_cst_formatter_args): + unformatted_code = """@decorator1 +@decorator2( arg1 , arg2 ) +def func1 (): + print( "This is a function with bad formatting") + +@another_decorator( arg) +def func2 ( x,y ): + print ( "This is another function with bad formatting" )""" + + expected_code = """@decorator1 +@decorator2( arg1 , arg2 ) +def func1 (): + print( "This is a function with bad formatting") + +@another_decorator(arg) +def func2(x, y): + print("This is another function with bad formatting") +""" + + optimizer, target_path, function_to_optimize = setup_cst_formatter_args( + unformatted_code, "func2" + ) + + formatted_code, _ = optimizer.reformat_code_and_helpers( + helper_functions=[], + path=target_path, + original_code=optimizer.function_to_optimize_source_code, + opt_func_name=function_to_optimize.function_name + ) + + assert formatted_code == expected_code + + +def test_formatting_function_with_syntax_error(setup_cst_formatter_args): + """shouldn't happen anyway, but just in case""" + unformatted_code = """def func1(): + print("This is a function with a syntax error" +def func2(): + print("This is another function with a syntax error") +""" + + expected_code = unformatted_code # No formatting should be applied due to syntax error + + optimizer, target_path, function_to_optimize = setup_cst_formatter_args( + unformatted_code, "func2" + ) + + formatted_code, _ = optimizer.reformat_code_and_helpers( + helper_functions=[], + path=target_path, + original_code=optimizer.function_to_optimize_source_code, + opt_func_name=function_to_optimize.function_name + ) + + assert formatted_code == expected_code