Skip to content

Commit 568074f

Browse files
committed
formatting
1 parent c1302e1 commit 568074f

File tree

7 files changed

+230
-112
lines changed

7 files changed

+230
-112
lines changed

.github/workflows/unit-tests.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ jobs:
3232
run: uvx poetry install --with dev
3333

3434
- name: Unit tests
35-
run: uvx poetry run pytest tests/ --cov --cov-report=xml
35+
run: uvx poetry run pytest tests/ --cov --cov-report=xml --disable-warnings
3636

3737
- name: Upload coverage reports to Codecov
3838
uses: codecov/codecov-action@v5

codeflash/code_utils/formatter.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,19 @@
33
import os
44
import shlex
55
import subprocess
6+
from functools import partial
67
from typing import TYPE_CHECKING
78

9+
import black
810
import isort
911

1012
from codeflash.cli_cmds.console import console, logger
1113

1214
if TYPE_CHECKING:
1315
from pathlib import Path
1416

17+
imports_sort = partial(isort.code, float_to_top=True)
18+
1519

1620
def format_code(formatter_cmds: list[str], path: Path) -> str:
1721
# TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution
@@ -46,12 +50,19 @@ def format_code(formatter_cmds: list[str], path: Path) -> str:
4650
return path.read_text(encoding="utf8")
4751

4852

49-
def sort_imports(code: str) -> str:
53+
def format_code_in_memory(code: str, *, imports_only: bool = False) -> str:
54+
if imports_only:
55+
try:
56+
sorted_code = imports_sort(code)
57+
except Exception: # noqa: BLE001
58+
logger.debug("Failed to sort imports with isort.")
59+
return code
60+
return sorted_code
5061
try:
51-
# Deduplicate and sort imports, modify the code in memory, not on disk
52-
sorted_code = isort.code(code)
53-
except Exception:
54-
logger.exception("Failed to sort imports with isort.")
55-
return code # Fall back to original code if isort fails
62+
formatted_code = black.format_str(code, mode=black.FileMode())
63+
formatted_code = imports_sort(formatted_code)
64+
except Exception: # noqa: BLE001
65+
logger.debug("Failed to format code with black.")
66+
return code
5667

57-
return sorted_code
68+
return formatted_code

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,9 @@
44
from pathlib import Path
55
from typing import TYPE_CHECKING
66

7-
import isort
8-
97
from codeflash.cli_cmds.console import logger
108
from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
9+
from codeflash.code_utils.formatter import format_code_in_memory
1110
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
1211
from codeflash.models.models import FunctionParent, TestingMode, VerificationType
1312

@@ -355,8 +354,7 @@ def inject_profiling_into_existing_test(
355354
if test_framework == "unittest":
356355
new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")]))
357356
tree.body = [*new_imports, create_wrapper_function(mode), *tree.body]
358-
return True, isort.code(ast.unparse(tree), float_to_top=True)
359-
357+
return True, format_code_in_memory(ast.unparse(tree))
360358

361359
def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.FunctionDef:
362360
lineno = 1

codeflash/optimization/function_optimizer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
N_TESTS_TO_GENERATE,
3636
TOTAL_LOOPING_TIME,
3737
)
38-
from codeflash.code_utils.formatter import format_code, sort_imports
38+
from codeflash.code_utils.formatter import format_code, format_code_in_memory
3939
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
4040
from codeflash.code_utils.line_profile_utils import add_decorator_imports
4141
from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests
@@ -541,14 +541,14 @@ def reformat_code_and_helpers(
541541

542542
new_code = format_code(self.args.formatter_cmds, path)
543543
if should_sort_imports:
544-
new_code = sort_imports(new_code)
544+
new_code = format_code_in_memory(new_code, imports_only=True)
545545

546546
new_helper_code: dict[Path, str] = {}
547547
helper_functions_paths = {hf.file_path for hf in helper_functions}
548548
for module_abspath in helper_functions_paths:
549549
formatted_helper_code = format_code(self.args.formatter_cmds, module_abspath)
550550
if should_sort_imports:
551-
formatted_helper_code = sort_imports(formatted_helper_code)
551+
formatted_helper_code = format_code_in_memory(formatted_helper_code, imports_only=True)
552552
new_helper_code[module_abspath] = formatted_helper_code
553553

554554
return new_code, new_helper_code

tests/test_formatter.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,29 +5,29 @@
55
import pytest
66

77
from codeflash.code_utils.config_parser import parse_config_file
8-
from codeflash.code_utils.formatter import format_code, sort_imports
8+
from codeflash.code_utils.formatter import format_code, format_code_in_memory
99

1010

1111
def test_remove_duplicate_imports():
1212
"""Test that duplicate imports are removed when should_sort_imports is True."""
1313
original_code = "import os\nimport os\n"
14-
new_code = sort_imports(original_code)
14+
new_code = format_code_in_memory(original_code, imports_only=True)
1515
assert new_code == "import os\n"
1616

1717

1818
def test_remove_multiple_duplicate_imports():
1919
"""Test that multiple duplicate imports are removed when should_sort_imports is True."""
2020
original_code = "import sys\nimport os\nimport sys\n"
2121

22-
new_code = sort_imports(original_code)
22+
new_code = format_code_in_memory(original_code, imports_only=True)
2323
assert new_code == "import os\nimport sys\n"
2424

2525

2626
def test_sorting_imports():
2727
"""Test that imports are sorted when should_sort_imports is True."""
2828
original_code = "import sys\nimport unittest\nimport os\n"
2929

30-
new_code = sort_imports(original_code)
30+
new_code = format_code_in_memory(original_code, imports_only=True)
3131
assert new_code == "import os\nimport sys\nimport unittest\n"
3232

3333

@@ -40,7 +40,7 @@ def test_sort_imports_without_formatting():
4040

4141
new_code = format_code(formatter_cmds=["disabled"], path=tmp_path)
4242
assert new_code is not None
43-
new_code = sort_imports(new_code)
43+
new_code = format_code_in_memory(new_code, imports_only=True)
4444
assert new_code == "import os\nimport sys\nimport unittest\n"
4545

4646

@@ -63,7 +63,7 @@ def foo():
6363
return os.path.join(sys.path[0], 'bar')
6464
"""
6565

66-
actual = sort_imports(original_code)
66+
actual = format_code_in_memory(original_code, imports_only=True)
6767

6868
assert actual == expected
6969

@@ -90,7 +90,7 @@ def foo():
9090
return os.path.join(sys.path[0], 'bar')
9191
"""
9292

93-
actual = sort_imports(original_code)
93+
actual = format_code_in_memory(original_code, imports_only=True)
9494

9595
assert actual == expected
9696

tests/test_instrument_all_and_run.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from codeflash.models.models import CodePosition, FunctionParent, TestFile, TestFiles, TestingMode, TestType
1313
from codeflash.optimization.optimizer import Optimizer
1414
from codeflash.verification.equivalence import compare_test_results
15+
from codeflash.code_utils.formatter import format_code_in_memory
1516
from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture
1617

1718
# Used by cli instrumentation
@@ -119,10 +120,12 @@ def test_sort():
119120
os.chdir(original_cwd)
120121
assert success
121122
assert new_test is not None
122-
assert new_test.replace('"', "'") == expected.format(
123+
assert format_code_in_memory(new_test) == format_code_in_memory(
124+
expected.format(
123125
module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_results_temp",
124126
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
125-
).replace('"', "'")
127+
)
128+
)
126129

127130
with test_path.open("w") as f:
128131
f.write(new_test)
@@ -307,9 +310,9 @@ def test_sort():
307310
Path(f.name), [CodePosition(7, 13), CodePosition(12, 13)], fto, Path(f.name).parent, "pytest"
308311
)
309312
assert success
310-
assert new_test.replace('"', "'") == expected.format(
313+
assert format_code_in_memory(new_test) == format_code_in_memory(expected.format(
311314
module_path=Path(f.name).name, tmp_dir_path=get_run_tmp_file(Path("test_return_values"))
312-
).replace('"', "'")
315+
))
313316
tests_root = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/").resolve()
314317
test_path = tests_root / "test_class_method_behavior_results_temp.py"
315318
test_path_perf = tests_root / "test_class_method_behavior_results_perf_temp.py"

0 commit comments

Comments
 (0)