Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions code_to_optimize/bad_formatting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import sys


def lol():
print( "lol" )









class BubbleSorter:
def __init__(self, x=0):
self.x = x

def lol(self):
print( "lol" )








def sorter (self, arr):


print ("codeflash stdout : BubbleSorter.sorter() called")
n = len(arr)
for i in range(n):
swapped = False
for j in range(0, n - i - 1):
if arr[j] > arr[j + 1]:
arr[j], arr[j + 1] = arr[j + 1], arr[j] # Faster swap
swapped = True
if not swapped:
break
print ("stderr test", file=sys.stderr)
return arr
23 changes: 23 additions & 0 deletions codeflash/discovery/functions_to_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,29 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
ending_line=pos.end.line,
)
)
class CodeRangeFunctionVisitor(cst.CSTVisitor):
METADATA_DEPENDENCIES = (
cst.metadata.PositionProvider,
cst.metadata.QualifiedNameProvider,
)

def __init__(self, target_function_name: str) -> None:
super().__init__()
self.target_func = target_function_name
self.start_line: Optional[int] = None
self.end_line: Optional[int] = None

def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
qualified_names = [
str(qn.name).replace(".<locals>", "") for qn in
self.get_metadata(cst.metadata.QualifiedNameProvider, node)
]
if self.target_func in qualified_names:
func_position = self.get_metadata(cst.metadata.PositionProvider, node)
decorators_count = len(node.decorators)
self.start_line = func_position.start.line - decorators_count
self.end_line = func_position.end.line
return False


class FunctionWithReturnStatement(ast.NodeVisitor):
Expand Down
76 changes: 60 additions & 16 deletions codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from collections import defaultdict, deque
from pathlib import Path
from typing import TYPE_CHECKING

import tempfile
import isort
import libcst as cst
from rich.console import Group
Expand Down Expand Up @@ -72,6 +72,8 @@
from codeflash.verification.verification_utils import get_test_file_path
from codeflash.verification.verifier import generate_tests

from codeflash.discovery.functions_to_optimize import CodeRangeFunctionVisitor

if TYPE_CHECKING:
from argparse import Namespace

Expand Down Expand Up @@ -301,7 +303,8 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
)

new_code, new_helper_code = self.reformat_code_and_helpers(
code_context.helper_functions, explanation.file_path, self.function_to_optimize_source_code
code_context.helper_functions, explanation.file_path, self.function_to_optimize_source_code,
opt_func_name=explanation.function_name
)

existing_tests = existing_tests_source_for(
Expand Down Expand Up @@ -590,25 +593,66 @@ def write_code_and_helpers(original_code: str, original_helper_code: dict[Path,
f.write(helper_code)

def reformat_code_and_helpers(
self, helper_functions: list[FunctionSource], path: Path, original_code: str
self, helper_functions: list[FunctionSource],
path: Path,
original_code: str,
opt_func_name: str
) -> tuple[str, dict[Path, str]]:
should_sort_imports = not self.args.disable_imports_sorting
if should_sort_imports and isort.code(original_code) != original_code:
should_sort_imports = False

new_code = format_code(self.args.formatter_cmds, path)
if should_sort_imports:
new_code = sort_imports(new_code)

new_helper_code: dict[Path, str] = {}
helper_functions_paths = {hf.file_path for hf in helper_functions}
for module_abspath in helper_functions_paths:
formatted_helper_code = format_code(self.args.formatter_cmds, module_abspath)
if should_sort_imports:
formatted_helper_code = sort_imports(formatted_helper_code)
new_helper_code[module_abspath] = formatted_helper_code

return new_code, new_helper_code
whole_file_content = path.read_text(encoding="utf8")
wrapper: cst.metadata.MetadataWrapper | None = None
try:
wrapper = cst.metadata.MetadataWrapper(cst.parse_module(whole_file_content))
except cst.ParserSyntaxError as e:
logger.error(f"Syntax error detected, aborting reformatting.")
return original_code, {}

visitor = CodeRangeFunctionVisitor(target_function_name=opt_func_name)
wrapper.visit(visitor)

lines = whole_file_content.splitlines(keepends=True)
if visitor.start_line == None:
logger.error(f"Could not find function {opt_func_name} in {path}, aborting reformatting.")
return original_code, {}
else:
opt_func_source_lines = lines[visitor.start_line-1:visitor.end_line]

# fix opt func identation
first_line = opt_func_source_lines[0]
first_line_indent = len(first_line) - len(first_line.lstrip()) # number of spaces before the first character
opt_func_source_lines[0] = opt_func_source_lines[0][first_line_indent:] # remove first line ident, so when we save the function code into a temp file, we don't get syntax errors

with tempfile.NamedTemporaryFile(mode='w+', delete=True) as f:
f.write("".join(opt_func_source_lines))
f.flush()
tmp_file = Path(f.name)
formatted_func = format_code(self.args.formatter_cmds, tmp_file)
# apply the identation back to all lines of the formatted function
formatted_lines = formatted_func.splitlines(keepends=True)
for i in range(len(formatted_lines)):
formatted_lines[i] = (" " * first_line_indent) + formatted_lines[i]

# replace the unformatted code with formatted ones
new_code = (
"".join(lines[:visitor.start_line-1]) +
"".join(formatted_lines) +
"".join(lines[visitor.end_line:])
)
if should_sort_imports:
new_code = sort_imports(new_code)

new_helper_code: dict[Path, str] = {}
helper_functions_paths = {hf.file_path for hf in helper_functions}
for module_abspath in helper_functions_paths:
formatted_helper_code = format_code(self.args.formatter_cmds, module_abspath)
if should_sort_imports:
formatted_helper_code = sort_imports(formatted_helper_code)
new_helper_code[module_abspath] = formatted_helper_code

return new_code, new_helper_code

def replace_function_and_helpers_with_optimized_code(
self, code_context: CodeOptimizationContext, optimized_code: str
Expand Down
Loading
Loading