Skip to content

Commit 00473e8

Browse files
committed
update tests
1 parent 1056b33 commit 00473e8

File tree

6 files changed

+64
-49
lines changed

6 files changed

+64
-49
lines changed

codeflash/code_utils/formatter.py

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
import subprocess
66
from typing import TYPE_CHECKING
77

8-
import isort
8+
from black import Mode, format_str
9+
from isort import code as imports_format
910

1011
from codeflash.cli_cmds.console import console, logger
1112

@@ -24,19 +25,31 @@ def format_code(formatter_cmds: list[str], path: Path) -> str:
2425
file_token = "$file" # noqa: S105
2526
for command in set(formatter_cmds):
2627
formatter_cmd_list = shlex.split(command, posix=os.name != "nt")
27-
formatter_cmd_list = [path.as_posix() if chunk == file_token else chunk for chunk in formatter_cmd_list]
28+
formatter_cmd_list = [
29+
path.as_posix() if chunk == file_token else chunk
30+
for chunk in formatter_cmd_list
31+
]
2832
try:
29-
result = subprocess.run(formatter_cmd_list, capture_output=True, check=False)
33+
result = subprocess.run(
34+
formatter_cmd_list, capture_output=True, check=False
35+
)
3036
if result.returncode == 0:
31-
console.rule(f"Formatted Successfully with: {formatter_name.replace('$file', path.name)}")
37+
console.rule(
38+
f"Formatted Successfully with: {formatter_name.replace('$file', path.name)}"
39+
)
3240
else:
33-
logger.error(f"Failed to format code with {' '.join(formatter_cmd_list)}")
41+
logger.error(
42+
f"Failed to format code with {' '.join(formatter_cmd_list)}"
43+
)
3444
except FileNotFoundError as e:
3545
from rich.panel import Panel
3646
from rich.text import Text
3747

3848
panel = Panel(
39-
Text.from_markup(f"⚠️ Formatter command not found: {' '.join(formatter_cmd_list)}", style="bold red"),
49+
Text.from_markup(
50+
f"⚠️ Formatter command not found: {' '.join(formatter_cmd_list)}",
51+
style="bold red",
52+
),
4053
expand=False,
4154
)
4255
console.print(panel)
@@ -46,12 +59,22 @@ def format_code(formatter_cmds: list[str], path: Path) -> str:
4659
return path.read_text(encoding="utf8")
4760

4861

49-
def sort_imports(code: str) -> str:
50-
try:
51-
# Deduplicate and sort imports, modify the code in memory, not on disk
52-
sorted_code = isort.code(code)
53-
except Exception:
54-
logger.exception("Failed to sort imports with isort.")
55-
return code # Fall back to original code if isort fails
62+
def format_code_in_memory(code_to_format: str, *, imports_only: bool = False) -> str:
63+
64+
if imports_only:
65+
return imports_format(code_to_format)
5666

57-
return sorted_code
67+
formatters = [
68+
(format_str, {"mode": Mode()}),
69+
(imports_format, {"float_to_top": True}),
70+
]
71+
72+
formatted_code = code_to_format
73+
try:
74+
for formatter_func, formatter_kwargs in formatters:
75+
formatted_code = formatter_func(formatted_code, **formatter_kwargs)
76+
except (ValueError, TypeError, SyntaxError) as e:
77+
logger.debug(f"Failed to format: {e}")
78+
return code_to_format
79+
else:
80+
return formatted_code

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,9 @@
44
from pathlib import Path
55
from typing import TYPE_CHECKING
66

7-
import isort
8-
97
from codeflash.cli_cmds.console import logger
108
from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
9+
from codeflash.code_utils.formatter import format_code_in_memory
1110
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
1211
from codeflash.models.models import FunctionParent, TestingMode, VerificationType
1312

@@ -355,8 +354,7 @@ def inject_profiling_into_existing_test(
355354
if test_framework == "unittest":
356355
new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")]))
357356
tree.body = [*new_imports, create_wrapper_function(mode), *tree.body]
358-
return True, isort.code(ast.unparse(tree), float_to_top=True)
359-
357+
return True, format_code_in_memory(ast.unparse(tree))
360358

361359
def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.FunctionDef:
362360
lineno = 1

codeflash/optimization/function_optimizer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
N_TESTS_TO_GENERATE,
3636
TOTAL_LOOPING_TIME,
3737
)
38-
from codeflash.code_utils.formatter import format_code, sort_imports
38+
from codeflash.code_utils.formatter import format_code, format_code_in_memory
3939
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
4040
from codeflash.code_utils.line_profile_utils import add_decorator_imports
4141
from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests
@@ -541,14 +541,14 @@ def reformat_code_and_helpers(
541541

542542
new_code = format_code(self.args.formatter_cmds, path)
543543
if should_sort_imports:
544-
new_code = sort_imports(new_code)
544+
new_code = format_code_in_memory(new_code, imports_only=True)
545545

546546
new_helper_code: dict[Path, str] = {}
547547
helper_functions_paths = {hf.file_path for hf in helper_functions}
548548
for module_abspath in helper_functions_paths:
549549
formatted_helper_code = format_code(self.args.formatter_cmds, module_abspath)
550550
if should_sort_imports:
551-
formatted_helper_code = sort_imports(formatted_helper_code)
551+
formatted_helper_code = format_code_in_memory(formatted_helper_code, imports_only=True)
552552
new_helper_code[module_abspath] = formatted_helper_code
553553

554554
return new_code, new_helper_code

tests/test_formatter.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,45 +5,41 @@
55
import pytest
66

77
from codeflash.code_utils.config_parser import parse_config_file
8-
from codeflash.code_utils.formatter import format_code, sort_imports
8+
from codeflash.code_utils.formatter import format_code, format_code_in_memory
99

1010

1111
def test_remove_duplicate_imports():
1212
"""Test that duplicate imports are removed when should_sort_imports is True."""
1313
original_code = "import os\nimport os\n"
14-
new_code = sort_imports(original_code)
14+
new_code = format_code_in_memory(original_code)
1515
assert new_code == "import os\n"
1616

17-
1817
def test_remove_multiple_duplicate_imports():
1918
"""Test that multiple duplicate imports are removed when should_sort_imports is True."""
2019
original_code = "import sys\nimport os\nimport sys\n"
2120

22-
new_code = sort_imports(original_code)
21+
new_code = format_code_in_memory(original_code)
2322
assert new_code == "import os\nimport sys\n"
2423

25-
2624
def test_sorting_imports():
2725
"""Test that imports are sorted when should_sort_imports is True."""
2826
original_code = "import sys\nimport unittest\nimport os\n"
2927

30-
new_code = sort_imports(original_code)
28+
new_code = format_code_in_memory(original_code)
3129
assert new_code == "import os\nimport sys\nimport unittest\n"
3230

33-
3431
def test_sort_imports_without_formatting():
3532
"""Test that imports are sorted when formatting is disabled and should_sort_imports is True."""
3633
with tempfile.NamedTemporaryFile() as tmp:
37-
tmp.write(b"import sys\nimport unittest\nimport os\n")
34+
tmp.write(b"import sys\nimport unittest\nimport os\nimport sys\nimport unittest\n")
3835
tmp.flush()
3936
tmp_path = Path(tmp.name)
4037

4138
new_code = format_code(formatter_cmds=["disabled"], path=tmp_path)
4239
assert new_code is not None
43-
new_code = sort_imports(new_code)
40+
new_code = format_code_in_memory(new_code)
4441
assert new_code == "import os\nimport sys\nimport unittest\n"
4542

46-
4743
def test_dedup_and_sort_imports_deduplicates():
4844
original_code = """
4945
import os
@@ -54,16 +50,15 @@ def foo():
5450
return os.path.join(sys.path[0], 'bar')
5551
"""
5652

57-
expected = """
58-
import os
53+
expected = """import os
5954
import sys
6055
6156
6257
def foo():
63-
return os.path.join(sys.path[0], 'bar')
58+
return os.path.join(sys.path[0], "bar")
6459
"""
6560

66-
actual = sort_imports(original_code)
61+
actual = format_code_in_memory(original_code)
6762

6863
assert actual == expected
6964

@@ -80,17 +75,16 @@ def foo():
8075
return os.path.join(sys.path[0], 'bar')
8176
"""
8277

83-
expected = """
84-
import json
78+
expected = """import json
8579
import os
8680
import sys
8781
8882
8983
def foo():
90-
return os.path.join(sys.path[0], 'bar')
84+
return os.path.join(sys.path[0], "bar")
9185
"""
9286

93-
actual = sort_imports(original_code)
87+
actual = format_code_in_memory(original_code)
9488

9589
assert actual == expected
9690

tests/test_instrument_all_and_run.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from pathlib import Path
88

99
from codeflash.code_utils.code_utils import get_run_tmp_file
10+
from codeflash.code_utils.formatter import format_code_in_memory
1011
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
1112
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
1213
from codeflash.models.models import CodePosition, FunctionParent, TestFile, TestFiles, TestingMode, TestType
@@ -119,10 +120,10 @@ def test_sort():
119120
os.chdir(original_cwd)
120121
assert success
121122
assert new_test is not None
122-
assert new_test.replace('"', "'") == expected.format(
123+
assert format_code_in_memory(new_test) == format_code_in_memory(expected.format(
123124
module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_results_temp",
124125
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
125-
).replace('"', "'")
126+
))
126127

127128
with test_path.open("w") as f:
128129
f.write(new_test)
@@ -307,9 +308,9 @@ def test_sort():
307308
Path(f.name), [CodePosition(7, 13), CodePosition(12, 13)], fto, Path(f.name).parent, "pytest"
308309
)
309310
assert success
310-
assert new_test.replace('"', "'") == expected.format(
311+
assert format_code_in_memory(new_test) == format_code_in_memory(expected.format(
311312
module_path=Path(f.name).name, tmp_dir_path=get_run_tmp_file(Path("test_return_values"))
312-
).replace('"', "'")
313+
))
313314
tests_root = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/").resolve()
314315
test_path = tests_root / "test_class_method_behavior_results_temp.py"
315316
test_path_perf = tests_root / "test_class_method_behavior_results_perf_temp.py"

tests/test_instrument_tests.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pathlib import Path
99

1010
from codeflash.code_utils.code_utils import get_run_tmp_file
11+
from codeflash.code_utils.formatter import format_code_in_memory
1112
from codeflash.code_utils.instrument_existing_tests import (
1213
FunctionImportedAsVisitor,
1314
inject_profiling_into_existing_test,
@@ -127,10 +128,8 @@ def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, functi
127128
invocation_id = f'{{line_id}}_{{codeflash_test_index}}'
128129
"""
129130
if sys.version_info < (3, 12):
130-
print("sys.version_info < (3, 12)")
131131
expected += """print(f"!######{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}######!")"""
132132
else:
133-
print("sys.version_info >= (3, 12)")
134133
expected += """print(f'!######{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}######!')"""
135134
expected += """
136135
exception = None
@@ -315,6 +314,8 @@ def test_sort():
315314
"""
316315
+ codeflash_wrap_string
317316
+ """
317+
318+
318319
def test_sort():
319320
codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
320321
codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION']
@@ -2845,10 +2846,8 @@ def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, functi
28452846
codeflash_test_index = codeflash_wrap.index[test_id]
28462847
invocation_id = f'{{line_id}}_{{codeflash_test_index}}'
28472848
"""
2848-
if sys.version_info < (3, 12):
2849-
expected += """ print(f"!######{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}######!")"""
2850-
else:
2851-
expected += """ print(f'!######{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}######!')"""
2849+
2850+
expected += """ print(f'!######{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}######!')"""
28522851
expected += """
28532852
exception = None
28542853
gc.disable()

0 commit comments

Comments
 (0)