Skip to content

Commit a1510a3

Browse files
enhancements
1 parent 64f2dd9 commit a1510a3

File tree

2 files changed

+28
-21
lines changed

2 files changed

+28
-21
lines changed

codeflash/code_utils/formatter.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -104,32 +104,39 @@ def format_code(
104104
formatter_cmds: list[str],
105105
path: Union[str, Path],
106106
optimized_function: str = "",
107+
check_diff: bool = False, # noqa
107108
print_status: bool = True, # noqa
108109
) -> str:
109110
with tempfile.TemporaryDirectory() as test_dir_str:
110-
max_diff_lines = 100
111-
112111
if isinstance(path, str):
113112
path = Path(path)
114113

115114
original_code = path.read_text(encoding="utf8")
116-
# we dont' count the formatting diff for the optimized function as it should be well-formatted
117-
original_code_without_opfunc = original_code.replace(optimized_function, "")
118-
119-
original_temp = Path(test_dir_str) / "original_temp.py"
120-
original_temp.write_text(original_code_without_opfunc, encoding="utf8")
121-
122-
formatted_temp, formatted_code = apply_formatter_cmds(
123-
formatter_cmds, original_temp, test_dir_str, print_status=False
124-
)
125-
126-
diff_output = generate_unified_diff(
127-
original_code_without_opfunc, formatted_code, from_file=str(original_temp), to_file=str(formatted_temp)
128-
)
129-
diff_lines_count = get_diff_lines_count(diff_output)
130-
if diff_lines_count > max_diff_lines:
131-
logger.debug(f"Skipping formatting {path}: {diff_lines_count} lines would change (max: {max_diff_lines})")
132-
return original_code
115+
original_code_lines = len(original_code.split("\n"))
116+
117+
if check_diff and original_code_lines > 50:
118+
# we dont' count the formatting diff for the optimized function as it should be well-formatted
119+
original_code_without_opfunc = original_code.replace(optimized_function, "")
120+
121+
original_temp = Path(test_dir_str) / "original_temp.py"
122+
original_temp.write_text(original_code_without_opfunc, encoding="utf8")
123+
124+
formatted_temp, formatted_code = apply_formatter_cmds(
125+
formatter_cmds, original_temp, test_dir_str, print_status=False
126+
)
127+
128+
diff_output = generate_unified_diff(
129+
original_code_without_opfunc, formatted_code, from_file=str(original_temp), to_file=str(formatted_temp)
130+
)
131+
diff_lines_count = get_diff_lines_count(diff_output)
132+
133+
max_diff_lines = min(int(original_code_lines * 0.3), 50)
134+
135+
if diff_lines_count > max_diff_lines and max_diff_lines != -1:
136+
logger.debug(
137+
f"Skipping formatting {path}: {diff_lines_count} lines would change (max: {max_diff_lines})"
138+
)
139+
return original_code
133140

134141
_, formatted_code = apply_formatter_cmds(formatter_cmds, path, test_dir_str=None, print_status=print_status)
135142
logger.debug(f"Formatted {path} with commands: {formatter_cmds}")

codeflash/optimization/function_optimizer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -612,7 +612,7 @@ def reformat_code_and_helpers(
612612
if should_sort_imports and isort.code(original_code) != original_code:
613613
should_sort_imports = False
614614

615-
new_code = format_code(self.args.formatter_cmds, path, optimized_function=optimized_function)
615+
new_code = format_code(self.args.formatter_cmds, path, optimized_function=optimized_function, check_diff=True)
616616
if should_sort_imports:
617617
new_code = sort_imports(new_code)
618618

@@ -621,7 +621,7 @@ def reformat_code_and_helpers(
621621
module_abspath = hp.file_path
622622
hp_source_code = hp.source_code
623623
formatted_helper_code = format_code(
624-
self.args.formatter_cmds, module_abspath, optimized_function=hp_source_code
624+
self.args.formatter_cmds, module_abspath, optimized_function=hp_source_code, check_diff=True
625625
)
626626
if should_sort_imports:
627627
formatted_helper_code = sort_imports(formatted_helper_code)

0 commit comments

Comments
 (0)