@@ -96,6 +96,20 @@ def is_diff_line(line: str) -> bool:
9696 return len (diff_lines )
9797
9898
99+ def format_generated_code (generated_test_source : str , formatter_cmds : Union [list [str ], None ] = None ) -> str :
100+ formatter_name = formatter_cmds [0 ].lower () if formatter_cmds else "disabled"
101+ if formatter_name == "disabled" :
102+ return re .sub (r"\n{2,}" , "\n \n " , generated_test_source )
103+ # try running formatter, if nothing changes (could be due to formatting failing or no actual formatting needed)
104+ original_temp , test_dir_str , exit_on_failure = None , None , True
105+ formatted_temp , formatted_code , changed = apply_formatter_cmds (
106+ formatter_cmds , original_temp , test_dir_str , print_status = False , exit_on_failure = exit_on_failure
107+ )
108+ if not changed :
109+ return re .sub (r"\n{2,}" , "\n \n " , formatted_code )
110+ return formatted_code
111+
112+
99113def format_code (
100114 formatter_cmds : list [str ],
101115 path : Union [str , Path ],
@@ -120,7 +134,7 @@ def format_code(
120134 original_code_lines = len (original_code .split ("\n " ))
121135
122136 if check_diff and original_code_lines > 50 :
123- # we dont' count the formatting diff for the optimized function as it should be well-formatted
137+ # we don't count the formatting diff for the optimized function as it should be well-formatted
124138 original_code_without_opfunc = original_code .replace (optimized_code , "" )
125139
126140 original_temp = Path (test_dir_str ) / "original_temp.py"
0 commit comments