Skip to content

Commit c75bbf6

Browse files
Check large diffs with black, and skip formatting in such cases (after optimizing)
1 parent 47f6c02 commit c75bbf6

File tree

4 files changed

+300
-2
lines changed

4 files changed

+300
-2
lines changed
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import os
2+
3+
class BadlyFormattedClass:
    """Plain record of a person's contact details.

    Every constructor argument is stored as an attribute of the same name.
    ``data`` is a separate dict seeded with only ``name``, ``age`` and
    ``email``.
    """

    def __init__(
        self,
        name,
        age=None,
        email=None,
        phone=None,
        address=None,
        city=None,
        state=None,
        zip_code=None,
    ):
        # Identity and contact fields.
        self.name, self.age, self.email, self.phone = name, age, email, phone
        # Postal fields.
        self.address, self.city, self.state, self.zip_code = address, city, state, zip_code
        self.data = dict(name=name, age=age, email=email)

    def get_info(self):
        """Return a one-line summary containing the name and age."""
        return f"Name: {self.name}, Age: {self.age}"

    def update_data(self, **kwargs):
        """Apply keyword updates to the instance.

        Only keys that already exist as attributes are set on the instance;
        every key, known or not, is merged into ``self.data``.
        """
        for attr_name, new_value in kwargs.items():
            if not hasattr(self, attr_name):
                continue
            setattr(self, attr_name, new_value)
        self.data.update(kwargs)
33+
34+
35+
def process_data(
    data_list, filter_func=None, transform_func=None, sort_key=None, reverse=False
):
    """Optionally filter, transform and sort ``data_list``, in that order.

    A falsy ``data_list`` yields a new empty list. Each stage runs only when
    its callable is truthy; ``reverse`` has an effect only when ``sort_key``
    is supplied. With no stage active the input object itself is returned.
    """
    if not data_list:
        return []

    items = data_list
    if filter_func:
        items = list(filter(filter_func, items))
    if transform_func:
        items = list(map(transform_func, items))
    if sort_key:
        items = sorted(items, key=sort_key, reverse=reverse)
    return items
47+
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
import os,sys,json,datetime,math,random;import requests;from collections import defaultdict,OrderedDict
2+
from typing import List,Dict,Optional,Union,Tuple,Any;import numpy as np;import pandas as pd
3+
4+
# This is a poorly formatted Python file with many style violations
5+
6+
class BadlyFormattedClass:
    """Container for a person's basic details.

    All constructor arguments become same-named attributes, and ``data``
    holds a dict with the ``name``, ``age`` and ``email`` values.
    """

    def __init__(
        self,
        name,
        age=None,
        email=None,
        phone=None,
        address=None,
        city=None,
        state=None,
        zip_code=None,
    ):
        self.name = name
        self.age = age
        self.email = email
        self.phone = phone
        self.address = address
        self.city = city
        self.state = state
        self.zip_code = zip_code
        self.data = {
            "name": name,
            "age": age,
            "email": email,
        }

    def get_info(self):
        """Return the name and age formatted as a single string."""
        return f"Name: {self.name}, Age: {self.age}"

    def update_data(self, **kwargs):
        """Set existing attributes from ``kwargs`` and merge all of them into ``data``.

        Keys that are not already attributes are not set on the instance,
        but they are still added to ``self.data``.
        """
        for field, replacement in kwargs.items():
            if hasattr(self, field):
                setattr(self, field, replacement)
        self.data.update(kwargs)
19+
20+
def process_data(data_list, filter_func=None, transform_func=None, sort_key=None, reverse=False):
    """Run the optional filter, transform and sort stages over ``data_list``.

    Returns a new empty list for falsy input. Stages are applied in the
    order filter, transform, sort, and each is skipped when its callable is
    falsy. ``reverse`` is consulted only when ``sort_key`` is given.
    """
    if not data_list:
        return []

    current = data_list
    if filter_func:
        current = [entry for entry in current if filter_func(entry)]
    if transform_func:
        current = [transform_func(entry) for entry in current]
    if sort_key:
        # Copy, then sort in place: same result as sorted().
        current = list(current)
        current.sort(key=sort_key, reverse=reverse)
    return current
26+
27+
def calculate_statistics(numbers):
    """Compute summary statistics for a collection of numbers.

    Args:
        numbers: Numeric values. A falsy or empty input yields ``None``.

    Returns:
        ``None`` when there is nothing to summarise, otherwise a dict with
        the keys ``mean``, ``median``, ``variance``, ``std_dev``, ``min``
        and ``max``. The variance is the population variance (divides by
        the number of values, not by ``n - 1``).
    """
    if not numbers:
        return None

    # Sort once; the median, minimum and maximum are all read from this list.
    ordered = sorted(numbers)
    count = len(ordered)
    if count == 0:
        return None

    middle = count // 2
    if count % 2:
        median = ordered[middle]
    else:
        # Even number of values: the median is the mean of the two central
        # elements, not the upper one alone.
        median = (ordered[middle - 1] + ordered[middle]) / 2

    mean = sum(ordered) / count
    variance = sum((value - mean) ** 2 for value in ordered) / count
    return {
        "mean": mean,
        "median": median,
        "variance": variance,
        "std_dev": math.sqrt(variance),
        "min": ordered[0],
        "max": ordered[-1],
    }
32+
33+
def complex_nested_function(x, y, z):
    """Return the sum of a two-argument term and a three-level grid total.

    The first term combines ``x`` and ``y`` through two shifted products.
    The second sums one value per ``(i, j, k)`` triple drawn from
    ``range(x)``, ``range(y)`` and ``range(z)``.
    """

    def pair_term(a, b):
        base = a * b

        def shifted(c, d):
            return c * d + base

        return shifted(a + 1, b - 1) + shifted(a - 1, b + 1)

    def grid_values(a, b, c):
        def classify(i, j, k):
            triple = i * j * k
            if triple > 0:
                return triple
            # A zero index sum marks the origin cell with -1; every other
            # non-positive product contributes 0.
            return -1 if i + j + k == 0 else 0

        return [
            classify(i, j, k)
            for i in range(a)
            for j in range(b)
            for k in range(c)
        ]

    return pair_term(x, y) + sum(grid_values(x, y, z))
48+
49+
# Long lines and poor dictionary formatting
50+
user_data = {
    "users": [
        {
            "id": 1,
            "name": "John Doe",
            "email": "[email protected]",
            "preferences": {
                "theme": "dark",
                "notifications": True,
                "language": "en",
            },
            "metadata": {
                "created_at": "2023-01-01",
                "last_login": "2024-01-01",
                "login_count": 150,
            },
        },
        {
            "id": 2,
            "name": "Jane Smith",
            "email": "[email protected]",
            "preferences": {
                "theme": "light",
                "notifications": False,
                "language": "es",
            },
            "metadata": {
                "created_at": "2023-02-15",
                "last_login": "2024-01-15",
                "login_count": 89,
            },
        },
    ]
}

# Poor list formatting and string concatenation
# Twenty sequential labels: 'item_1' through 'item_20'.
long_list_of_items = [f"item_{number}" for number in range(1, 21)]
54+
55+
def generate_report(data, include_stats=True, include_charts=False, format_type='json', output_file=None):
    """Summarise ``data`` and render the report in the requested format.

    Dict items contribute each value to a per-key list in the summary;
    ints and floats are collected under ``'numbers'``; anything else is
    stringified under ``'other'``. When ``include_stats`` is true and
    numbers were collected, ``calculate_statistics`` fills in
    ``'statistics'``. ``include_charts`` is accepted but not used here.

    ``format_type`` selects compact JSON (``'json'``), indented JSON
    (``'pretty_json'``) or ``str()`` (``'string'``); any other value returns
    the report dict itself. If ``output_file`` is truthy the result is also
    written there, JSON-encoded when it is not already a string.

    Raises:
        ValueError: if ``data`` is falsy.
    """
    if not data:
        raise ValueError("Data cannot be empty")

    summary = {}
    report = {
        'timestamp': datetime.datetime.now().isoformat(),
        'data_count': len(data),
        'summary': summary,
    }

    for entry in data:
        if isinstance(entry, dict):
            for field, field_value in entry.items():
                summary.setdefault(field, []).append(field_value)
        elif isinstance(entry, (int, float)):
            summary.setdefault('numbers', []).append(entry)
        else:
            summary.setdefault('other', []).append(str(entry))

    if include_stats and 'numbers' in summary:
        report['statistics'] = calculate_statistics(summary['numbers'])

    if format_type == 'json':
        result = json.dumps(report, indent=None, separators=(',', ':'))
    elif format_type == 'pretty_json':
        result = json.dumps(report, indent=2)
    elif format_type == 'string':
        result = str(report)
    else:
        result = report

    if output_file:
        with open(output_file, 'w') as handle:
            # Non-string results (the raw dict) are serialised before writing.
            if isinstance(result, str):
                handle.write(result)
            else:
                handle.write(json.dumps(result))

    return result
86+
87+
class DataProcessor(BadlyFormattedClass):
    """Loads, validates and post-processes records from a data source.

    The source may be a path string ending in ``.json`` or ``.csv``, a list
    (used as-is), or any other object (wrapped in a one-element list).
    Load failures are recorded in ``errors`` and validation notes in
    ``warnings``.
    """

    def __init__(self, data_source, config=None, debug=False):
        super().__init__("DataProcessor")
        self.data_source = data_source
        self.config = config or {}
        self.debug = debug
        self.processed_data = []
        self.errors = []
        self.warnings = []

    def load_data(self):
        """Return the records from ``data_source``, or ``[]`` on any failure.

        Any exception raised while loading is appended to ``self.errors``
        as a string instead of propagating.
        """
        try:
            source = self.data_source
            if isinstance(source, list):
                return source
            if not isinstance(source, str):
                return [source]
            if source.endswith('.json'):
                with open(source, 'r') as handle:
                    return json.load(handle)
            if source.endswith('.csv'):
                return pd.read_csv(source).to_dict('records')
            raise ValueError(f"Unsupported file type: {self.data_source}")
        except Exception as e:
            self.errors.append(str(e))
            return []

    def validate_data(self, data):
        """Return the items that are dicts containing both ``'id'`` and ``'name'``.

        When any item is rejected, one warning with the rejected count is
        appended to ``self.warnings``.
        """
        accepted = []
        rejected = 0
        for candidate in data:
            if isinstance(candidate, dict) and 'id' in candidate and 'name' in candidate:
                accepted.append(candidate)
            else:
                rejected += 1
        if rejected:
            self.warnings.append(f"Found {rejected} invalid items")
        return accepted

    def process(self):
        """Load, validate and process the data; return a result dict.

        When nothing was loaded the result is a failure dict. Otherwise the
        valid records are filtered on ``'active'`` (default true), stamped
        with ``'processed_at'``, sorted by ``'name'``, stored in
        ``self.processed_data`` and returned with their count.
        """
        data = self.load_data()
        if not data:
            return {"success": False, "error": "No data loaded"}

        def is_active(record):
            return record.get('active', True)

        def stamp(record):
            return {**record, 'processed_at': datetime.datetime.now().isoformat()}

        def by_name(record):
            return record.get('name', '')

        processed_result = process_data(
            self.validate_data(data),
            filter_func=is_active,
            transform_func=stamp,
            sort_key=by_name,
        )
        self.processed_data = processed_result
        return {
            "success": True,
            "count": len(processed_result),
            "data": processed_result,
        }
126+
if __name__ == "__main__":
    # Demo run: process three sample records, then print a report preview
    # and statistics for fifty random integers.
    sample_data = [
        {"id": 1, "name": "Alice", "active": True},
        {"id": 2, "name": "Bob", "active": False},
        {"id": 3, "name": "Charlie", "active": True},
    ]

    processor = DataProcessor(sample_data, config={"debug": True})
    result = processor.process()

    if not result["success"]:
        print(f"Processing failed: {result.get('error','Unknown error')}")
    else:
        print(f"Successfully processed {result['count']} items")
        # Show at most the first three processed records.
        for item in result["data"][:3]:
            print(f"- {item['name']} (ID: {item['id']})")

    # Generate report with poor formatting
    report = generate_report(sample_data, include_stats=True, format_type='pretty_json')
    # Truncate long reports to a 100-character preview.
    if len(report) > 100:
        preview = report[:100] + "..."
    else:
        preview = report
    print("Generated report:", preview)

    # Complex calculation with poor spacing
    numbers = [random.randint(1, 100) for _ in range(50)]
    stats = calculate_statistics(numbers)
    complex_result = complex_nested_function(5, 3, 2)

    print(f"Statistics: mean={stats['mean']:.2f}, std_dev={stats['std_dev']:.2f}")
    print(f"Complex calculation result: {complex_result}")

codeflash/code_utils/formatter.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,49 @@
1313
from pathlib import Path
1414

1515

16+
def should_format_file(filepath: Path, max_lines_changed: int = 50) -> bool:
    """Decide whether ``filepath`` should be run through the formatter.

    Runs ``black --diff`` on the file and counts the added and removed lines
    in the resulting diff. Formatting is skipped when black reports nothing
    to change, when the diff exceeds ``max_lines_changed``, or when black is
    missing or fails.

    Args:
        filepath: File to check.
        max_lines_changed: Largest number of changed (``+``/``-``) diff lines
            for which formatting is still allowed.

    Returns:
        True when black would change between 1 and ``max_lines_changed``
        lines; False otherwise.
    """
    try:
        # A single invocation covers both failure modes: a missing executable
        # raises FileNotFoundError, and check=True turns a non-zero exit into
        # CalledProcessError, so a failed run is not mistaken for "0 changes".
        # NOTE(review): relies on `black --diff` (without --check) exiting 0
        # even when changes are needed - confirm against the Black docs.
        result = subprocess.run(
            ["black", "--diff", str(filepath)],
            capture_output=True,
            text=True,
            check=True,
        )
    except subprocess.CalledProcessError:
        logger.warning(f"black command failed for {filepath}")
        return False
    except FileNotFoundError:
        logger.warning("black is not installed. Skipping formatting check.")
        return False

    if not result.stdout:
        # An empty diff means the file is already formatted.
        return False

    # Count changed lines, ignoring the '---' / '+++' file header lines.
    changes_count = sum(
        1
        for line in result.stdout.splitlines()
        if line.startswith(("+", "-")) and not line.startswith(("+++", "---"))
    )

    if changes_count > max_lines_changed:
        logger.debug(f"Skipping {filepath}: {changes_count} lines would change (max: {max_lines_changed})")
        return False

    return True
47+
48+
49+
1650
def format_code(formatter_cmds: list[str], path: Path, print_status: bool = True) -> str: # noqa
1751
# TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution
1852
formatter_name = formatter_cmds[0].lower()
1953
if not path.exists():
2054
msg = f"File {path} does not exist. Cannot format the file."
2155
raise FileNotFoundError(msg)
22-
if formatter_name == "disabled":
56+
if formatter_name == "disabled" or not should_format_file(path):
2357
return path.read_text(encoding="utf8")
58+
2459
file_token = "$file" # noqa: S105
2560
for command in formatter_cmds:
2661
formatter_cmd_list = shlex.split(command, posix=os.name != "nt")
@@ -29,7 +64,7 @@ def format_code(formatter_cmds: list[str], path: Path, print_status: bool = True
2964
result = subprocess.run(formatter_cmd_list, capture_output=True, check=False)
3065
if result.returncode == 0:
3166
if print_status:
32-
console.rule(f"Formatted Successfully with: {formatter_name.replace('$file', path.name)}")
67+
console.rule(f"Formatted Successfully with: {command.replace('$file', path.name)}")
3368
else:
3469
logger.error(f"Failed to format code with {' '.join(formatter_cmd_list)}")
3570
except FileNotFoundError as e:

tests/test_formatter.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
1+
import argparse
12
import os
23
import tempfile
34
from pathlib import Path
45

56
import pytest
7+
import shutil
68

79
from codeflash.code_utils.config_parser import parse_config_file
810
from codeflash.code_utils.formatter import format_code, sort_imports
911

12+
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
13+
from codeflash.optimization.function_optimizer import FunctionOptimizer
14+
from codeflash.verification.verification_utils import TestConfig
1015

1116
def test_remove_duplicate_imports():
1217
"""Test that duplicate imports are removed when should_sort_imports is True."""
@@ -209,3 +214,67 @@ def foo():
209214
tmp_path = tmp.name
210215
with pytest.raises(FileNotFoundError):
211216
format_code(formatter_cmds=["exit 1"], path=Path(tmp_path))
217+
218+
219+
def _run_formatting_test(source_filename: str, should_content_change: bool):
    """Copy a fixture into a temp dir, reformat it, and check whether it changed.

    The fixture is read from ``code_to_optimize`` under the repository root
    and copied to ``target.py`` in a temporary directory. The copy is then
    passed through ``FunctionOptimizer.reformat_code_and_helpers`` and
    compared with the original text.
    """
    repo_root_dir = Path(__file__).resolve().parent.parent
    source_file = repo_root_dir / "code_to_optimize" / source_filename

    with tempfile.TemporaryDirectory() as tmp_name:
        workdir = Path(tmp_name)

        original = source_file.read_text()
        target_path = workdir / "target.py"
        shutil.copy2(source_file, target_path)

        optimizer = FunctionOptimizer(
            function_to_optimize=FunctionToOptimize(
                function_name="process_data",
                parents=[],
                file_path=target_path,
            ),
            test_cfg=TestConfig(
                tests_root=workdir,
                project_root_path=workdir,
                test_framework="pytest",
                tests_project_rootdir=workdir,
            ),
            args=argparse.Namespace(
                disable_imports_sorting=False,
                formatter_cmds=[
                    "ruff check --exit-zero --fix $file",
                    "ruff format $file",
                ],
            ),
        )

        optimizer.reformat_code_and_helpers(
            helper_functions=[],
            path=target_path,
            original_code=optimizer.function_to_optimize_source_code,
        )

        content = target_path.read_text()

        if not should_content_change:
            assert content == original, f"Expected content to remain unchanged for {source_filename}"
        else:
            assert content != original, f"Expected content to change for {source_filename}"
271+
272+
273+
def test_formatting_file_with_many_diffs():
    """A fixture with many formatting errors must be left untouched."""
    _run_formatting_test(
        source_filename="many_formatting_errors.py",
        should_content_change=False,
    )
276+
277+
278+
def test_formatting_file_with_few_diffs():
    """A fixture with only a few formatting errors must be reformatted."""
    _run_formatting_test(
        source_filename="few_formatting_errors.py",
        should_content_change=True,
    )

0 commit comments

Comments
 (0)