Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
c75bbf6
check large diffs with black, and skipp formatting in such case (afte…
mohammedahmed18 Jun 3, 2025
5cd13ad
new line
mohammedahmed18 Jun 3, 2025
1522227
better log messages
mohammedahmed18 Jun 3, 2025
d3ca1cb
remove unnecessary check
mohammedahmed18 Jun 3, 2025
dcb084a
new line
mohammedahmed18 Jun 3, 2025
689a2d9
remove unused comment
mohammedahmed18 Jun 3, 2025
44c0f85
the max lines for formatting changes to 100
mohammedahmed18 Jun 3, 2025
73ef518
refactoring
mohammedahmed18 Jun 3, 2025
a5343fd
refactoring and improvements
mohammedahmed18 Jun 3, 2025
395855d
added black as dev dependency
mohammedahmed18 Jun 3, 2025
822d6cc
made some refactor changes that codeflash suggested
mohammedahmed18 Jun 3, 2025
ce15022
remove unused function
mohammedahmed18 Jun 3, 2025
d2a8711
formatting & using internal black dep
mohammedahmed18 Jun 3, 2025
f46b368
fix black import issue
mohammedahmed18 Jun 4, 2025
6504cc4
handle formatting files with no formatting issues
mohammedahmed18 Jun 4, 2025
aed490d
Merge branch 'main' into skip-formatting-for-large-diffs
Saga4 Jun 4, 2025
82a4ee1
use user pre-defined formatting commands, instead of using black
mohammedahmed18 Jun 4, 2025
90014bd
Merge branch 'skip-formatting-for-large-diffs' of github.com:codeflas…
mohammedahmed18 Jun 4, 2025
caeda49
make sure format_code recieves file path as path type not as str
mohammedahmed18 Jun 4, 2025
6967fcb
formatting and linting
mohammedahmed18 Jun 4, 2025
8248c8e
typo
mohammedahmed18 Jun 4, 2025
15aacdb
revert lock file changes
mohammedahmed18 Jun 4, 2025
c24fc90
remove comment
mohammedahmed18 Jun 4, 2025
b48e9e6
pass helper functions source code to the formatter for diff checking
mohammedahmed18 Jun 5, 2025
93070a9
Merge branch 'main' of github.com:codeflash-ai/codeflash into skip-fo…
mohammedahmed18 Jun 6, 2025
64f2dd9
more unit tests
mohammedahmed18 Jun 6, 2025
a1510a3
enhancements
mohammedahmed18 Jun 6, 2025
6f97004
Merge branch 'main' into skip-formatting-for-large-diffs
Saga4 Jun 10, 2025
6cb8469
Update formatter.py
Saga4 Jun 10, 2025
94e64d3
Update formatter.py
Saga4 Jun 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions code_to_optimize/few_formatting_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import os

class UnformattedExampleClass(object):
def __init__(
self,
name,
age= None,
email= None,
phone=None,
address=None,
city=None,
state=None,
zip_code=None,
):
self.name = name
self.age = age
self.email = email
self.phone = phone
self. address = address
self.city = city
self.state = state
self.zip_code = zip_code
self.data = {"name": name, "age": age, "email": email}

def get_info(self):
return f"Name: {self.name}, Age: {self.age}"

def update_data(self, **kwargs):
for key, value in kwargs.items():
if hasattr(self, key):
setattr(self, key, value)
self.data.update(kwargs)


def process_data(
data_list, filter_func=None, transform_func=None, sort_key=None, reverse=False
):
if not data_list:
return []
if filter_func:
data_list = [ item for item in data_list if filter_func(item)]
if transform_func:
data_list = [transform_func(item) for item in data_list]
if sort_key:
data_list = sorted(data_list, key=sort_key, reverse=reverse)
return data_list

147 changes: 147 additions & 0 deletions code_to_optimize/many_formatting_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import os,sys,json,datetime,math,random;import requests;from collections import defaultdict,OrderedDict
from typing import List,Dict,Optional,Union,Tuple,Any;import numpy as np;import pandas as pd

# This is a poorly formatted Python file with many style violations

class UnformattedExampleClass( object ):
def __init__(self,name,age=None,email=None,phone=None,address=None,city=None,state=None,zip_code=None):
self.name=name;self.age=age;self.email=email;self.phone=phone
self.address=address;self.city=city;self.state=state;self.zip_code=zip_code
self.data={"name":name,"age":age,"email":email}

def get_info(self ):
return f"Name: {self.name}, Age: {self.age}"

def update_data(self,**kwargs):
for key,value in kwargs.items():
if hasattr(self,key):setattr(self,key,value)
self.data.update(kwargs)

def process_data(data_list,filter_func=None,transform_func=None,sort_key=None,reverse=False):
if not data_list:return[]
if filter_func:data_list=[item for item in data_list if filter_func(item)]
if transform_func:data_list=[transform_func(item)for item in data_list]
if sort_key:data_list=sorted(data_list,key=sort_key,reverse=reverse)
return data_list

def calculate_statistics(numbers):
if not numbers:return None
mean=sum(numbers)/len(numbers); median=sorted(numbers)[len(numbers)//2]
variance=sum((x-mean)**2 for x in numbers)/len(numbers);std_dev=math.sqrt(variance)
return {"mean":mean,"median":median,"variance":variance,"std_dev":std_dev,"min":min(numbers),"max":max(numbers)}

def complex_nested_function(x,y,z):
def inner_function_1(a,b):
def deeply_nested(c,d):
return c*d+a*b
return deeply_nested(a+1,b-1)+deeply_nested(a-1,b+1)
def inner_function_2 (a,b,c):
result=[]
for i in range(a):
for j in range(b):
for k in range(c):
if i*j*k>0:result.append(i*j*k)
elif i+j+k==0:result.append(-1)
else :result.append(0)
return result
return inner_function_1(x,y)+sum(inner_function_2(x,y,z))

# Long lines and poor dictionary formatting
user_data={"users":[{"id":1,"name":"John Doe","email":"[email protected]","preferences":{"theme":"dark","notifications":True,"language":"en"},"metadata":{"created_at":"2023-01-01","last_login":"2024-01-01","login_count":150}},{"id":2,"name":"Jane Smith","email":"[email protected]","preferences":{"theme":"light","notifications":False,"language":"es"},"metadata":{"created_at":"2023-02-15","last_login":"2024-01-15","login_count":89}}]}

# Poor list formatting and string concatenation
long_list_of_items=['item_1','item_2','item_3','item_4','item_5','item_6','item_7','item_8','item_9','item_10','item_11','item_12','item_13','item_14','item_15','item_16','item_17','item_18','item_19','item_20']

def generate_report(data,include_stats=True,include_charts=False,format_type='json',output_file=None):
if not data:raise ValueError("Data cannot be empty")
report={'timestamp':datetime.datetime.now().isoformat(),'data_count':len(data),'summary':{}}

# Bad formatting in loops and conditionals
for i,item in enumerate(data):
if isinstance(item,dict):
for key,value in item.items():
if key not in report['summary']:report['summary'][key]=[]
report['summary'][key].append(value)
elif isinstance(item,(int,float)):
if 'numbers' not in report['summary']:report['summary']['numbers']=[]
report['summary']['numbers'].append(item)
else:
if 'other' not in report['summary']:report['summary']['other']=[]
report['summary']['other'].append(str(item))

if include_stats and 'numbers' in report['summary']:
numbers=report['summary']['numbers']
report['statistics']=calculate_statistics(numbers)

# Long conditional chain with poor formatting
if format_type=='json':result=json.dumps(report,indent=None,separators=(',',':'))
elif format_type=='pretty_json':result=json.dumps(report,indent=2)
elif format_type=='string':result=str(report)
else:result=report

if output_file:
with open(output_file,'w')as f:f.write(result if isinstance(result,str)else json.dumps(result))

return result

class DataProcessor ( UnformattedExampleClass ) :
def __init__(self,data_source,config=None,debug=False):
super().__init__("DataProcessor")
self.data_source=data_source;self.config=config or{};self.debug=debug
self.processed_data=[];self.errors=[];self.warnings=[]

def load_data ( self ) :
try:
if isinstance(self.data_source,str):
if self.data_source.endswith('.json'):
with open(self.data_source,'r')as f:data=json.load(f)
elif self.data_source.endswith('.csv'):data=pd.read_csv(self.data_source).to_dict('records')
else:raise ValueError(f"Unsupported file type: {self.data_source}")
elif isinstance(self.data_source,list):data=self.data_source
else:data=[self.data_source]
return data
except Exception as e:
self.errors.append(str(e));return[]

def validate_data(self,data):
valid_items=[];invalid_items=[]
for item in data:
if isinstance(item,dict)and'id'in item and'name'in item:valid_items.append(item)
else:invalid_items.append(item)
if invalid_items:self.warnings.append(f"Found {len(invalid_items)} invalid items")
return valid_items

def process(self):
data=self.load_data()
if not data:return{"success":False,"error":"No data loaded"}

validated_data=self.validate_data(data)
processed_result=process_data(validated_data,
filter_func=lambda x:x.get('active',True),
transform_func=lambda x:{**x,'processed_at':datetime.datetime.now().isoformat()},
sort_key=lambda x:x.get('name',''))

self.processed_data=processed_result
return{"success":True,"count":len(processed_result),"data":processed_result}
if __name__=="__main__":
sample_data=[{"id":1,"name":"Alice","active":True},{"id":2,"name":"Bob","active":False},{"id":3,"name":"Charlie","active":True}]

processor=DataProcessor(sample_data,config={"debug":True})
result=processor.process()

if result["success"]:
print(f"Successfully processed {result['count']} items")
for item in result["data"][:3]:print(f"- {item['name']} (ID: {item['id']})")
else:print(f"Processing failed: {result.get('error','Unknown error')}")

# Generate report with poor formatting
report=generate_report(sample_data,include_stats=True,format_type='pretty_json')
print("Generated report:",report[:100]+"..."if len(report)>100 else report)

# Complex calculation with poor spacing
numbers=[random.randint(1,100)for _ in range(50)]
stats=calculate_statistics(numbers)
complex_result=complex_nested_function(5,3,2)

print(f"Statistics: mean={stats['mean']:.2f}, std_dev={stats['std_dev']:.2f}")
print(f"Complex calculation result: {complex_result}")
61 changes: 58 additions & 3 deletions codeflash/code_utils/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import shlex
import subprocess
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

import isort

Expand All @@ -12,15 +12,70 @@
if TYPE_CHECKING:
from pathlib import Path

def get_diff_lines_output_by_black(filepath: str) -> Optional[str]:
try:
subprocess.run(['black', '--version'], check=True,
stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
result = subprocess.run(
['black', '--diff', filepath],
capture_output=True,
text=True
)
return result.stdout.strip() if result.stdout else None
except (FileNotFoundError):
return None


def get_diff_lines_output_by_ruff(filepath: str) -> Optional[str]:
try:
subprocess.run(['ruff', '--version'], check=True,
stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
result = subprocess.run(
['ruff', "format", '--diff', filepath],
capture_output=True,
text=True
)
return result.stdout.strip() if result.stdout else None
except (FileNotFoundError):
return None


def get_diff_lines_count(diff_output: str) -> int:
diff_lines = [line for line in diff_output.split('\n')
if line.startswith(('+', '-')) and not line.startswith(('+++', '---'))]
return len(diff_lines)

def is_safe_to_format(filepath: str, max_diff_lines: int = 100) -> bool:
diff_changes_stdout = None

diff_changes_stdout = get_diff_lines_output_by_black(filepath)

if diff_changes_stdout is None:
logger.warning(f"black formatter not found, trying ruff instead...")
diff_changes_stdout = get_diff_lines_output_by_ruff(filepath)
if diff_changes_stdout is None:
msg = f"Both ruff, black formatters not found, skipping formatting diff check."
logger.warning(msg)
raise FileNotFoundError(msg)

diff_lines_count = get_diff_lines_count(diff_changes_stdout)

if diff_lines_count > max_diff_lines:
logger.debug(f"Skipping {filepath}: {diff_lines_count} lines would change (max: {max_diff_lines})")
return False
else:
return True


def format_code(formatter_cmds: list[str], path: Path, print_status: bool = True) -> str: # noqa
# TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution
formatter_name = formatter_cmds[0].lower()
if not path.exists():
msg = f"File {path} does not exist. Cannot format the file."
raise FileNotFoundError(msg)
if formatter_name == "disabled":
if formatter_name == "disabled" or not is_safe_to_format(path): # few -> False, large -> True
return path.read_text(encoding="utf8")

file_token = "$file" # noqa: S105
for command in formatter_cmds:
formatter_cmd_list = shlex.split(command, posix=os.name != "nt")
Expand All @@ -29,7 +84,7 @@ def format_code(formatter_cmds: list[str], path: Path, print_status: bool = True
result = subprocess.run(formatter_cmd_list, capture_output=True, check=False)
if result.returncode == 0:
if print_status:
console.rule(f"Formatted Successfully with: {formatter_name.replace('$file', path.name)}")
console.rule(f"Formatted Successfully with: {command.replace('$file', path.name)}")
else:
logger.error(f"Failed to format code with {' '.join(formatter_cmd_list)}")
except FileNotFoundError as e:
Expand Down
75 changes: 75 additions & 0 deletions tests/test_formatter.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
import argparse
import os
import tempfile
from pathlib import Path

import pytest
import shutil

from codeflash.code_utils.config_parser import parse_config_file
from codeflash.code_utils.formatter import format_code, sort_imports

from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.verification.verification_utils import TestConfig

def test_remove_duplicate_imports():
"""Test that duplicate imports are removed when should_sort_imports is True."""
Expand Down Expand Up @@ -209,3 +214,73 @@ def foo():
tmp_path = tmp.name
with pytest.raises(FileNotFoundError):
format_code(formatter_cmds=["exit 1"], path=Path(tmp_path))


def _run_formatting_test(source_filename: str, should_content_change: bool):
with tempfile.TemporaryDirectory() as test_dir_str:
test_dir = Path(test_dir_str)
this_file = Path(__file__).resolve()
repo_root_dir = this_file.parent.parent
source_file = repo_root_dir / "code_to_optimize" / source_filename

original = source_file.read_text()
target_path = test_dir / "target.py"

shutil.copy2(source_file, target_path)

function_to_optimize = FunctionToOptimize(
function_name="process_data",
parents=[],
file_path=target_path
)

test_cfg = TestConfig(
tests_root=test_dir,
project_root_path=test_dir,
test_framework="pytest",
tests_project_rootdir=test_dir,
)

args = argparse.Namespace(
disable_imports_sorting=False,
formatter_cmds=[
"ruff check --exit-zero --fix $file",
"ruff format $file"
],
)

optimizer = FunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
args=args,
)

optimizer.reformat_code_and_helpers(
helper_functions=[],
path=target_path,
original_code=optimizer.function_to_optimize_source_code,
)

content = target_path.read_text()

if should_content_change:
assert content != original, f"Expected content to change for {source_filename}"
else:
assert content == original, f"Expected content to remain unchanged for {source_filename}"

def _ruff_or_black_installed() -> bool:
return shutil.which("black") is not None or shutil.which("ruff") is not None


def test_formatting_file_with_many_diffs():
"""Test that files with many formatting errors are skipped (content unchanged)."""
if not _ruff_or_black_installed():
pytest.skip("Neither black nor ruff is installed, skipping formatting tests.")
_run_formatting_test("many_formatting_errors.py", should_content_change=False)


def test_formatting_file_with_few_diffs():
"""Test that files with few formatting errors are formatted (content changed)."""
if not _ruff_or_black_installed():
pytest.skip("Neither black nor ruff is installed, skipping formatting tests.")
_run_formatting_test("few_formatting_errors.py", should_content_change=True)
Loading