Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions codeflash/benchmarking/instrument_codeflash_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

from typing import TYPE_CHECKING, Optional, Union

import isort
import libcst as cst

from codeflash.code_utils.formatter import sort_imports

if TYPE_CHECKING:
from pathlib import Path

Expand Down Expand Up @@ -107,7 +108,7 @@ def instrument_codeflash_trace_decorator(file_to_funcs_to_optimize: dict[Path, l
original_code = file_path.read_text(encoding="utf-8")
new_code = add_codeflash_decorator_to_code(original_code, functions_to_optimize)
# Modify the code
modified_code = isort.code(code=new_code, float_to_top=True)
modified_code = sort_imports(code=new_code, float_to_top=True)

# Write the modified code back to the file
file_path.write_text(modified_code, encoding="utf-8")
5 changes: 2 additions & 3 deletions codeflash/benchmarking/replay_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any

import isort

from codeflash.cli_cmds.console import logger
from codeflash.code_utils.formatter import sort_imports
from codeflash.discovery.functions_to_optimize import inspect_top_level_functions_or_methods
from codeflash.verification.verification_utils import get_test_file_path

Expand Down Expand Up @@ -299,7 +298,7 @@ def generate_replay_test(
test_framework=test_framework,
max_run_count=max_run_count,
)
test_code = isort.code(test_code)
test_code = sort_imports(code=test_code)
output_file = get_test_file_path(
test_dir=Path(output_dir), function_name=benchmark_module_path, test_type="replay"
)
Expand Down
4 changes: 2 additions & 2 deletions codeflash/code_utils/code_replacer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
from functools import lru_cache
from typing import TYPE_CHECKING, Optional, TypeVar

import isort
import libcst as cst
from libcst.metadata import PositionProvider

from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_extractor import add_global_assignments, add_needed_imports_from_module
from codeflash.code_utils.config_parser import find_conftest_files
from codeflash.code_utils.formatter import sort_imports
from codeflash.code_utils.line_profile_utils import ImportAdder
from codeflash.models.models import FunctionParent

Expand Down Expand Up @@ -226,7 +226,7 @@ def add_custom_marker_to_all_tests(test_paths: list[Path]) -> None:
module = cst.parse_module(file_content)
importadder = ImportAdder("import pytest")
modified_module = module.visit(importadder)
modified_module = cst.parse_module(isort.code(modified_module.code, float_to_top=True))
modified_module = cst.parse_module(sort_imports(code=modified_module.code, float_to_top=True))
pytest_mark_adder = PytestMarkAdder("codeflash_no_autouse")
modified_module = modified_module.visit(pytest_mark_adder)
test_path.write_text(modified_module.code, encoding="utf-8")
Expand Down
6 changes: 3 additions & 3 deletions codeflash/code_utils/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,11 +166,11 @@ def format_code(
return formatted_code


def sort_imports(code: str) -> str:
def sort_imports(code: str, *, float_to_top: bool = False) -> str:
try:
# Deduplicate and sort imports, modify the code in memory, not on disk
sorted_code = isort.code(code)
except Exception:
sorted_code = isort.code(code=code, float_to_top=float_to_top)
except Exception: # this will also catch the FileSkipComment exception, use this fn everywhere
logger.exception("Failed to sort imports with isort.")
return code # Fall back to original code if isort fails

Expand Down
3 changes: 2 additions & 1 deletion codeflash/code_utils/instrument_existing_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
from codeflash.code_utils.formatter import sort_imports
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionParent, TestingMode, VerificationType

Expand Down Expand Up @@ -1129,7 +1130,7 @@ def add_async_decorator_to_function(
import_transformer = AsyncDecoratorImportAdder(mode)
module = module.visit(import_transformer)

return isort.code(module.code, float_to_top=True), decorator_transformer.added_decorator
return sort_imports(code=module.code, float_to_top=True), decorator_transformer.added_decorator
except Exception as e:
logger.exception(f"Error adding async decorator to function {function.qualified_name}: {e}")
return source_code, False
Expand Down
4 changes: 2 additions & 2 deletions codeflash/code_utils/line_profile_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from pathlib import Path
from typing import TYPE_CHECKING, Union

import isort
import libcst as cst

from codeflash.code_utils.code_utils import get_run_tmp_file
from codeflash.code_utils.formatter import sort_imports

if TYPE_CHECKING:
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
Expand Down Expand Up @@ -213,7 +213,7 @@ def add_decorator_imports(function_to_optimize: FunctionToOptimize, code_context
transformer = ImportAdder("from line_profiler import profile as codeflash_line_profile")
# Apply the transformer to add the import
module_node = module_node.visit(transformer)
modified_code = isort.code(module_node.code, float_to_top=True)
modified_code = sort_imports(code=module_node.code, float_to_top=True)
# write to file
with file_path.open("w", encoding="utf-8") as file:
file.write(modified_code)
Expand Down
3 changes: 1 addition & 2 deletions codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from pathlib import Path
from typing import TYPE_CHECKING

import isort
import libcst as cst
from rich.console import Group
from rich.panel import Panel
Expand Down Expand Up @@ -900,7 +899,7 @@ def reformat_code_and_helpers(
optimized_context: CodeStringsMarkdown,
) -> tuple[str, dict[Path, str]]:
should_sort_imports = not self.args.disable_imports_sorting
if should_sort_imports and isort.code(original_code) != original_code:
if should_sort_imports and sort_imports(code=original_code) != original_code:
should_sort_imports = False

optimized_code = ""
Expand Down
5 changes: 4 additions & 1 deletion codeflash/tracing/tracing_new_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,8 @@ def __exit__(

# These modules have been imported here now the tracer is done. It is safe to import codeflash and external modules here

from contextlib import suppress

import isort

from codeflash.tracing.replay_test import create_trace_replay_test
Expand All @@ -280,7 +282,8 @@ def __exit__(
test_file_path = get_test_file_path(
test_dir=Path(self.config["tests_root"]), function_name=function_path, test_type="replay"
)
replay_test = isort.code(replay_test)
with suppress(Exception):
replay_test = isort.code(replay_test)

with Path(test_file_path).open("w", encoding="utf8") as file:
file.write(replay_test)
Expand Down
5 changes: 2 additions & 3 deletions codeflash/verification/instrument_codeflash_capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
from pathlib import Path
from typing import TYPE_CHECKING

import isort

from codeflash.code_utils.code_utils import get_run_tmp_file
from codeflash.code_utils.formatter import sort_imports

if TYPE_CHECKING:
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
Expand Down Expand Up @@ -70,7 +69,7 @@ def add_codeflash_capture_to_init(
ast.fix_missing_locations(modified_tree)

# Convert back to source code
return isort.code(code=ast.unparse(modified_tree), float_to_top=True)
return sort_imports(code=ast.unparse(modified_tree), float_to_top=True)


class InitDecorator(ast.NodeTransformer):
Expand Down
Loading