Skip to content

Commit c71d2da

Browse files
committed
tidy up
1 parent 389b32c commit c71d2da

File tree

2 files changed

+40
-17
lines changed

2 files changed

+40
-17
lines changed

codeflash/code_utils/env_utils.py

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import json
44
import os
5+
import shlex
6+
import shutil
57
import tempfile
68
from functools import lru_cache
79
from pathlib import Path
@@ -14,21 +16,41 @@
1416

1517

1618
def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool = True) -> bool: # noqa
17-
return_code = True
18-
if formatter_cmds[0] == "disabled":
19-
return return_code
19+
if not formatter_cmds or formatter_cmds[0] == "disabled":
20+
return True
21+
22+
first_cmd = formatter_cmds[0]
23+
cmd_tokens = shlex.split(first_cmd) if isinstance(first_cmd, str) else [first_cmd]
24+
25+
if not cmd_tokens:
26+
return True
27+
28+
exe_name = cmd_tokens[0]
29+
command_str = " ".join(formatter_cmds).replace(" $file", "")
30+
31+
if shutil.which(exe_name) is None:
32+
logger.error(
33+
f"Could not find formatter: {command_str}\n"
34+
f"Please install it or update 'formatter-cmds' in your codeflash configuration"
35+
)
36+
return False
37+
2038
tmp_code = """print("hello world")"""
21-
with tempfile.TemporaryDirectory() as tmpdir:
22-
tmp_file = Path(tmpdir) / "test_codeflash_formatter.py"
23-
tmp_file.write_text(tmp_code, encoding="utf-8")
24-
try:
25-
format_code(formatter_cmds, tmp_file, print_status=False, exit_on_failure=exit_on_failure)
26-
except Exception:
27-
exit_with_message(
28-
"⚠️ Codeflash requires a code formatter to be installed in your environment, but none was found. Please install a supported formatter, verify the formatter-cmds in your codeflash pyproject.toml config and try again.",
29-
error_on_exit=True,
30-
)
31-
return return_code
39+
try:
40+
with tempfile.TemporaryDirectory() as tmpdir:
41+
tmp_file = Path(tmpdir) / "test_codeflash_formatter.py"
42+
tmp_file.write_text(tmp_code, encoding="utf-8")
43+
format_code(formatter_cmds, tmp_file, print_status=False, exit_on_failure=False)
44+
return True
45+
except FileNotFoundError:
46+
logger.error(
47+
f"Could not find formatter: {command_str}\n"
48+
f"Please install it or update 'formatter-cmds' in your codeflash configuration"
49+
)
50+
return False
51+
except Exception as e:
52+
logger.error(f"Formatter failed to run: {command_str}\nError: {e}")
53+
return False
3254

3355

3456
@lru_cache(maxsize=1)
@@ -138,4 +160,4 @@ def is_ci() -> bool:
138160
def is_pr_draft() -> bool:
139161
"""Check if the PR is draft. in the github action context."""
140162
event = get_cached_gh_event_data()
141-
return bool(event.get("pull_request", {}).get("draft", False))
163+
return bool(event.get("pull_request", {}).get("draft", False))

codeflash/verification/hypothesis_testing.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def generate_hypothesis_tests(
138138
qualified_function_path = get_qualified_function_path(
139139
function_to_optimize.file_path, args.project_root, function_to_optimize.qualified_name
140140
)
141-
logger.info(f"command: hypothesis write {function_to_optimize.file_path.stem}")
141+
logger.info(f"command: hypothesis write {qualified_function_path}")
142142

143143
hypothesis_result = subprocess.run(
144144
["hypothesis", "write", qualified_function_path],
@@ -182,8 +182,9 @@ def visit_FunctionDef(self, node): # noqa: ANN001, ANN202
182182
unparsed = ast.unparse(modified_tree)
183183

184184
hypothesis_test_suite_code = format_code(
185-
make_hypothesis_tests_deterministic(remove_functions_with_only_any_type(unparsed)),
185+
args.formatter_cmds,
186186
function_to_optimize.file_path,
187+
optimized_code=make_hypothesis_tests_deterministic(remove_functions_with_only_any_type(unparsed)),
187188
)
188189
with hypothesis_path.open("w", encoding="utf-8") as f:
189190
f.write(hypothesis_test_suite_code)

0 commit comments

Comments
 (0)