Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/unit-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
run: uvx poetry install --with dev

- name: Unit tests
run: uvx poetry run pytest tests/ --cov --cov-report=xml
run: uvx poetry run pytest tests/ --cov --cov-report=xml --disable-warnings

- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v5
Expand Down
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ celerybeat.pid
# Environments
.env
**/.env
.venv
.venv*
env/
venv/
ENV/
Expand Down
25 changes: 18 additions & 7 deletions codeflash/code_utils/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,19 @@
import os
import shlex
import subprocess
from functools import partial
from typing import TYPE_CHECKING

import black
import isort

from codeflash.cli_cmds.console import console, logger

if TYPE_CHECKING:
from pathlib import Path

imports_sort = partial(isort.code, float_to_top=True)


def format_code(formatter_cmds: list[str], path: Path) -> str:
# TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution
Expand Down Expand Up @@ -46,12 +50,19 @@ def format_code(formatter_cmds: list[str], path: Path) -> str:
return path.read_text(encoding="utf8")


def sort_imports(code: str) -> str:
def format_code_in_memory(code: str, *, imports_only: bool = False) -> str:
if imports_only:
try:
sorted_code = imports_sort(code)
except Exception: # noqa: BLE001
logger.debug("Failed to sort imports with isort.")
return code
return sorted_code
try:
# Deduplicate and sort imports, modify the code in memory, not on disk
sorted_code = isort.code(code)
except Exception:
logger.exception("Failed to sort imports with isort.")
return code # Fall back to original code if isort fails
formatted_code = black.format_str(code, mode=black.FileMode())
formatted_code = imports_sort(formatted_code)
except Exception: # noqa: BLE001
logger.debug("Failed to format code with black.")
return code

return sorted_code
return formatted_code
6 changes: 2 additions & 4 deletions codeflash/code_utils/instrument_existing_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
from pathlib import Path
from typing import TYPE_CHECKING

import isort

from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
from codeflash.code_utils.formatter import format_code_in_memory
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionParent, TestingMode, VerificationType

Expand Down Expand Up @@ -355,8 +354,7 @@ def inject_profiling_into_existing_test(
if test_framework == "unittest":
new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")]))
tree.body = [*new_imports, create_wrapper_function(mode), *tree.body]
return True, isort.code(ast.unparse(tree), float_to_top=True)

return True, format_code_in_memory(ast.unparse(tree))

def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.FunctionDef:
lineno = 1
Expand Down
6 changes: 3 additions & 3 deletions codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
N_TESTS_TO_GENERATE,
TOTAL_LOOPING_TIME,
)
from codeflash.code_utils.formatter import format_code, sort_imports
from codeflash.code_utils.formatter import format_code, format_code_in_memory
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
from codeflash.code_utils.line_profile_utils import add_decorator_imports
from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests
Expand Down Expand Up @@ -541,14 +541,14 @@ def reformat_code_and_helpers(

new_code = format_code(self.args.formatter_cmds, path)
if should_sort_imports:
new_code = sort_imports(new_code)
new_code = format_code_in_memory(new_code, imports_only=True)

new_helper_code: dict[Path, str] = {}
helper_functions_paths = {hf.file_path for hf in helper_functions}
for module_abspath in helper_functions_paths:
formatted_helper_code = format_code(self.args.formatter_cmds, module_abspath)
if should_sort_imports:
formatted_helper_code = sort_imports(formatted_helper_code)
formatted_helper_code = format_code_in_memory(formatted_helper_code, imports_only=True)
new_helper_code[module_abspath] = formatted_helper_code

return new_code, new_helper_code
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ lxml = ">=5.3.0"
crosshair-tool = ">=0.0.78"
coverage = ">=7.6.4"
line_profiler=">=4.2.0" #this is the minimum version which supports python 3.13
black = "^25.1.0"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we stopped requiring black since many users don't use black. Why do we require black, and not have it as an option?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this let's us not worry about fstring mismatch across all python versions, and instead let's black take care of it.

[tool.poetry.group.dev]
optional = true

Expand Down
14 changes: 7 additions & 7 deletions tests/test_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,29 @@
import pytest

from codeflash.code_utils.config_parser import parse_config_file
from codeflash.code_utils.formatter import format_code, sort_imports
from codeflash.code_utils.formatter import format_code, format_code_in_memory


def test_remove_duplicate_imports():
"""Test that duplicate imports are removed when should_sort_imports is True."""
original_code = "import os\nimport os\n"
new_code = sort_imports(original_code)
new_code = format_code_in_memory(original_code, imports_only=True)
assert new_code == "import os\n"


def test_remove_multiple_duplicate_imports():
"""Test that multiple duplicate imports are removed when should_sort_imports is True."""
original_code = "import sys\nimport os\nimport sys\n"

new_code = sort_imports(original_code)
new_code = format_code_in_memory(original_code, imports_only=True)
assert new_code == "import os\nimport sys\n"


def test_sorting_imports():
"""Test that imports are sorted when should_sort_imports is True."""
original_code = "import sys\nimport unittest\nimport os\n"

new_code = sort_imports(original_code)
new_code = format_code_in_memory(original_code, imports_only=True)
assert new_code == "import os\nimport sys\nimport unittest\n"


Expand All @@ -40,7 +40,7 @@ def test_sort_imports_without_formatting():

new_code = format_code(formatter_cmds=["disabled"], path=tmp_path)
assert new_code is not None
new_code = sort_imports(new_code)
new_code = format_code_in_memory(new_code, imports_only=True)
assert new_code == "import os\nimport sys\nimport unittest\n"


Expand All @@ -63,7 +63,7 @@ def foo():
return os.path.join(sys.path[0], 'bar')
"""

actual = sort_imports(original_code)
actual = format_code_in_memory(original_code, imports_only=True)

assert actual == expected

Expand All @@ -90,7 +90,7 @@ def foo():
return os.path.join(sys.path[0], 'bar')
"""

actual = sort_imports(original_code)
actual = format_code_in_memory(original_code, imports_only=True)

assert actual == expected

Expand Down
11 changes: 7 additions & 4 deletions tests/test_instrument_all_and_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from codeflash.models.models import CodePosition, FunctionParent, TestFile, TestFiles, TestingMode, TestType
from codeflash.optimization.optimizer import Optimizer
from codeflash.verification.equivalence import compare_test_results
from codeflash.code_utils.formatter import format_code_in_memory
from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture

# Used by cli instrumentation
Expand Down Expand Up @@ -119,10 +120,12 @@ def test_sort():
os.chdir(original_cwd)
assert success
assert new_test is not None
assert new_test.replace('"', "'") == expected.format(
assert format_code_in_memory(new_test) == format_code_in_memory(
expected.format(
module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
).replace('"', "'")
)
)

with test_path.open("w") as f:
f.write(new_test)
Expand Down Expand Up @@ -307,9 +310,9 @@ def test_sort():
Path(f.name), [CodePosition(7, 13), CodePosition(12, 13)], fto, Path(f.name).parent, "pytest"
)
assert success
assert new_test.replace('"', "'") == expected.format(
assert format_code_in_memory(new_test) == format_code_in_memory(expected.format(
module_path=Path(f.name).name, tmp_dir_path=get_run_tmp_file(Path("test_return_values"))
).replace('"', "'")
))
tests_root = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/").resolve()
test_path = tests_root / "test_class_method_behavior_results_temp.py"
test_path_perf = tests_root / "test_class_method_behavior_results_perf_temp.py"
Expand Down
Loading
Loading