@@ -96,19 +96,20 @@ def is_diff_line(line: str) -> bool:
9696 return len (diff_lines )
9797
9898
99- def format_generated_code (generated_test_source : str ) -> str :
100- return re .sub (r"\n{2,}" , "\n \n " , generated_test_source )
101- # formatter_name = formatter_cmds[0].lower() if formatter_cmds else "disabled"
102- # if formatter_name == "disabled":
103- # return re.sub(r"\n{2,}", "\n\n", generated_test_source)
104- # # try running formatter, if nothing changes (could be due to formatting failing or no actual formatting needed)
105- # original_temp, test_dir_str, exit_on_failure = None, None, True
106- # formatted_temp, formatted_code, changed = apply_formatter_cmds(
107- # formatter_cmds, original_temp, test_dir_str, print_status=False, exit_on_failure=exit_on_failure
108- # )
109- # if not changed:
110- # return re.sub(r"\n{2,}", "\n\n", formatted_code)
111- # return formatted_code
99+ def format_generated_code (generated_test_source : str , formatter_cmds : Union [list [str ], None ]) -> str :
100+ formatter_name = formatter_cmds [0 ].lower () if formatter_cmds else "disabled"
101+ if formatter_name == "disabled" :
102+ return re .sub (r"\n{2,}" , "\n \n " , generated_test_source )
103+ with tempfile .TemporaryDirectory () as test_dir_str :
104+ # try running formatter, if nothing changes (could be due to formatting failing or no actual formatting needed)
105+ original_temp = Path (test_dir_str ) / "original_temp.py"
106+ original_temp .write_text (generated_test_source , encoding = "utf8" )
107+ _ , formatted_code , changed = apply_formatter_cmds (
108+ formatter_cmds , original_temp , test_dir_str , print_status = False , exit_on_failure = False
109+ )
110+ if not changed :
111+ return re .sub (r"\n{2,}" , "\n \n " , formatted_code )
112+ return formatted_code
112113
113114
114115def format_code (
0 commit comments