Skip to content

Commit 82a4ee1

Browse files
use user pre-defined formatting commands, instead of using black
1 parent 6504cc4 commit 82a4ee1

File tree

6 files changed

+106
-130
lines changed

6 files changed

+106
-130
lines changed

code_to_optimize/no_formatting_errors.py

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,8 @@
1-
import os, sys, json, datetime, math, random
2-
import requests
3-
from collections import defaultdict, OrderedDict
4-
from typing import List, Dict, Optional, Union, Tuple, Any
5-
import numpy as np
6-
import pandas as pd
1+
import math
72

8-
# This is a poorly formatted Python file with many style violations
93

10-
11-
class UnformattedExampleClass(object):
12-
def __init__(
13-
self,
14-
name,
15-
age=None,
16-
email=None,
17-
phone=None,
18-
address=None,
19-
city=None,
20-
state=None,
21-
zip_code=None,
22-
):
4+
class UnformattedExampleClass:
5+
def __init__(self, name, age=None, email=None, phone=None, address=None, city=None, state=None, zip_code=None):
236
self.name = name
247
self.age = age
258
self.email = email
@@ -40,9 +23,7 @@ def update_data(self, **kwargs):
4023
self.data.update(kwargs)
4124

4225

43-
def process_data(
44-
data_list, filter_func=None, transform_func=None, sort_key=None, reverse=False
45-
):
26+
def process_data(data_list, filter_func=None, transform_func=None, sort_key=None, reverse=False):
4627
if not data_list:
4728
return []
4829
if filter_func:

codeflash/code_utils/formatter.py

Lines changed: 84 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,69 +1,73 @@
11
from __future__ import annotations
22

3+
import difflib
34
import os
5+
import re
46
import shlex
7+
import shutil
58
import subprocess
6-
from typing import TYPE_CHECKING, Optional
9+
import tempfile
10+
from pathlib import Path
11+
from typing import Optional
712

813
import isort
914

1015
from codeflash.cli_cmds.console import console, logger
1116

12-
if TYPE_CHECKING:
13-
from pathlib import Path
1417

18+
def generate_unified_diff(original: str, modified: str, from_file: str, to_file: str) -> str:
19+
line_pattern = re.compile(r"(.*?(?:\r\n|\n|\r|$))")
1520

16-
def get_diff_output_by_black(filepath: str, unformatted_content: str) -> Optional[str]:
17-
try:
18-
from black import Mode, format_file_contents, output, report
21+
def split_lines(text: str) -> list[str]:
22+
lines = [match[0] for match in line_pattern.finditer(text)]
23+
if lines and lines[-1] == "":
24+
lines.pop()
25+
return lines
1926

20-
formatted_content = format_file_contents(src_contents=unformatted_content, fast=True, mode=Mode())
21-
return output.diff(unformatted_content, formatted_content, a_name=filepath, b_name=filepath)
22-
except (ImportError, report.NothingChanged):
23-
return None
27+
original_lines = split_lines(original)
28+
modified_lines = split_lines(modified)
2429

30+
diff_output = []
31+
for line in difflib.unified_diff(original_lines, modified_lines, fromfile=from_file, tofile=to_file, n=5):
32+
if line.endswith("\n"):
33+
diff_output.append(line)
34+
else:
35+
diff_output.append(line + "\n")
36+
diff_output.append("\\ No newline at end of file\n")
2537

26-
def get_diff_lines_count(diff_output: str) -> int:
27-
lines = diff_output.split("\n")
28-
29-
def is_diff_line(line: str) -> bool:
30-
return line.startswith(("+", "-")) and not line.startswith(("+++", "---"))
31-
32-
diff_lines = [line for line in lines if is_diff_line(line)]
33-
return len(diff_lines)
38+
return "".join(diff_output)
3439

3540

36-
def is_safe_to_format(filepath: str, content: str, max_diff_lines: int = 100) -> bool:
37-
diff_changes_str = None
38-
39-
diff_changes_str = get_diff_output_by_black(filepath, unformatted_content=content)
40-
41-
if diff_changes_str is None:
42-
logger.warning("Looks like black formatter not found, make sure it is installed.")
43-
return False
44-
45-
diff_lines_count = get_diff_lines_count(diff_changes_str)
46-
if diff_lines_count > max_diff_lines:
47-
logger.debug(f"Skipping formatting {filepath}: {diff_lines_count} lines would change (max: {max_diff_lines})")
48-
return False
41+
def apply_formatter_cmds(
42+
cmds: list[str],
43+
path: Path,
44+
test_dir_str: Optional[str],
45+
print_status: bool, # noqa
46+
) -> tuple[Path, str]:
47+
# TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution
48+
formatter_name = cmds[0].lower()
49+
should_make_copy = False
50+
file_path = path
4951

50-
return True
52+
if test_dir_str:
53+
should_make_copy = True
54+
file_path = Path(test_dir_str) / "temp.py"
5155

56+
if not cmds or formatter_name == "disabled":
57+
return path, path.read_text(encoding="utf8")
5258

53-
def format_code(formatter_cmds: list[str], path: Path, print_status: bool = True) -> str: # noqa
54-
# TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution
55-
formatter_name = formatter_cmds[0].lower()
5659
if not path.exists():
57-
msg = f"File {path} does not exist. Cannot format the file."
60+
msg = f"File {path} does not exist. Cannot apply formatter commands."
5861
raise FileNotFoundError(msg)
59-
file_content = path.read_text(encoding="utf8")
60-
if formatter_name == "disabled" or not is_safe_to_format(filepath=str(path), content=file_content):
61-
return file_content
62+
63+
if should_make_copy:
64+
shutil.copy2(path, file_path)
6265

6366
file_token = "$file" # noqa: S105
64-
for command in formatter_cmds:
67+
68+
for command in cmds:
6569
formatter_cmd_list = shlex.split(command, posix=os.name != "nt")
66-
formatter_cmd_list = [path.as_posix() if chunk == file_token else chunk for chunk in formatter_cmd_list]
70+
formatter_cmd_list = [file_path.as_posix() if chunk == file_token else chunk for chunk in formatter_cmd_list]
6771
try:
6872
result = subprocess.run(formatter_cmd_list, capture_output=True, check=False)
6973
if result.returncode == 0:
@@ -83,7 +87,45 @@ def format_code(formatter_cmds: list[str], path: Path, print_status: bool = True
8387

8488
raise e from None
8589

86-
return path.read_text(encoding="utf8")
90+
return file_path, file_path.read_text(encoding="utf8")
91+
92+
93+
def get_diff_lines_count(diff_output: str) -> int:
94+
lines = diff_output.split("\n")
95+
96+
def is_diff_line(line: str) -> bool:
97+
return line.startswith(("+", "-")) and not line.startswith(("+++", "---"))
98+
99+
diff_lines = [line for line in lines if is_diff_line(line)]
100+
return len(diff_lines)
101+
102+
103+
def format_code(formatter_cmds: list[str], path: Path, optimized_function: str = "", print_status: bool = True) -> str: # noqa
104+
with tempfile.TemporaryDirectory() as test_dir_str:
105+
max_diff_lines = 100
106+
107+
original_code = path.read_text(encoding="utf8")
108+
# we dont' count the formatting diff for the optimized function as it should be well-formatted (if it's provided)
109+
original_code_without_opfunc = original_code.replace(optimized_function, "")
110+
111+
original_temp = Path(test_dir_str) / "original_temp.py"
112+
original_temp.write_text(original_code_without_opfunc, encoding="utf8")
113+
114+
formatted_temp, formatted_code = apply_formatter_cmds(
115+
formatter_cmds, original_temp, test_dir_str, print_status=False
116+
)
117+
118+
diff_output = generate_unified_diff(
119+
original_code_without_opfunc, formatted_code, from_file=str(original_temp), to_file=str(formatted_temp)
120+
)
121+
diff_lines_count = get_diff_lines_count(diff_output)
122+
if diff_lines_count > max_diff_lines:
123+
logger.debug(f"Skipping formatting {path}: {diff_lines_count} lines would change (max: {max_diff_lines})")
124+
return original_code
125+
126+
_, formatted_code = apply_formatter_cmds(formatter_cmds, path, test_dir_str=None, print_status=print_status)
127+
logger.debug(f"Formatted {path} with commands: {formatter_cmds}")
128+
return formatted_code
87129

88130

89131
def sort_imports(code: str) -> str:

codeflash/optimization/function_optimizer.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,10 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
302302
)
303303

304304
new_code, new_helper_code = self.reformat_code_and_helpers(
305-
code_context.helper_functions, explanation.file_path, self.function_to_optimize_source_code
305+
code_context.helper_functions,
306+
explanation.file_path,
307+
self.function_to_optimize_source_code,
308+
optimized_function=best_optimization.candidate.source_code,
306309
)
307310

308311
existing_tests = existing_tests_source_for(
@@ -591,18 +594,18 @@ def write_code_and_helpers(original_code: str, original_helper_code: dict[Path,
591594
f.write(helper_code)
592595

593596
def reformat_code_and_helpers(
594-
self, helper_functions: list[FunctionSource], path: Path, original_code: str
597+
self, helper_functions: list[FunctionSource], path: Path, original_code: str, optimized_function: str
595598
) -> tuple[str, dict[Path, str]]:
596599
should_sort_imports = not self.args.disable_imports_sorting
597600
if should_sort_imports and isort.code(original_code) != original_code:
598601
should_sort_imports = False
599602

600-
new_code = format_code(self.args.formatter_cmds, path)
603+
new_code = format_code(self.args.formatter_cmds, path, optimized_function=optimized_function)
601604
if should_sort_imports:
602605
new_code = sort_imports(new_code)
603606

604607
new_helper_code: dict[Path, str] = {}
605-
helper_functions_paths = {hf.file_path for hf in helper_functions}
608+
helper_functions_paths = {hf.source_code for hf in helper_functions}
606609
for module_abspath in helper_functions_paths:
607610
formatted_helper_code = format_code(self.args.formatter_cmds, module_abspath)
608611
if should_sort_imports:

poetry.lock

Lines changed: 1 addition & 60 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,6 @@ crosshair-tool = ">=0.0.78"
9393
coverage = ">=7.6.4"
9494
line_profiler=">=4.2.0" #this is the minimum version which supports python 3.13
9595
platformdirs = ">=4.3.7"
96-
black = "^25.1.0"
9796
[tool.poetry.group.dev]
9897
optional = true
9998

tests/test_formatter.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,16 @@ def _run_formatting_test(source_filename: str, should_content_change: bool):
263263
helper_functions=[],
264264
path=target_path,
265265
original_code=optimizer.function_to_optimize_source_code,
266+
# this is just for testing, but in practice, this would be an optimized function code and it will be well-formatted
267+
optimized_function=""" def process(self):
268+
data=self.load_data()
269+
if not data:return{"success":False,"error":"No data loaded"}
270+
271+
validated_data=self.validate_data(data)
272+
processed_result=process_data(validated_data,
273+
filter_func=lambda x:x.get('active',True),
274+
transform_func=lambda x:{**x,'processed_at':datetime.datetime.now().isoformat()},
275+
sort_key=lambda x:x.get('name',''))""",
266276
)
267277

268278
if should_content_change:

0 commit comments

Comments
 (0)