diff --git a/codeflash/benchmarking/instrument_codeflash_trace.py b/codeflash/benchmarking/instrument_codeflash_trace.py index 099954243..f54674cfa 100644 --- a/codeflash/benchmarking/instrument_codeflash_trace.py +++ b/codeflash/benchmarking/instrument_codeflash_trace.py @@ -2,9 +2,10 @@ from typing import TYPE_CHECKING, Optional, Union -import isort import libcst as cst +from codeflash.code_utils.formatter import sort_imports + if TYPE_CHECKING: from pathlib import Path @@ -107,7 +108,7 @@ def instrument_codeflash_trace_decorator(file_to_funcs_to_optimize: dict[Path, l original_code = file_path.read_text(encoding="utf-8") new_code = add_codeflash_decorator_to_code(original_code, functions_to_optimize) # Modify the code - modified_code = isort.code(code=new_code, float_to_top=True) + modified_code = sort_imports(code=new_code, float_to_top=True) # Write the modified code back to the file file_path.write_text(modified_code, encoding="utf-8") diff --git a/codeflash/benchmarking/replay_test.py b/codeflash/benchmarking/replay_test.py index 7b125dfa5..e9f66dc8a 100644 --- a/codeflash/benchmarking/replay_test.py +++ b/codeflash/benchmarking/replay_test.py @@ -6,9 +6,8 @@ from pathlib import Path from typing import TYPE_CHECKING, Any -import isort - from codeflash.cli_cmds.console import logger +from codeflash.code_utils.formatter import sort_imports from codeflash.discovery.functions_to_optimize import inspect_top_level_functions_or_methods from codeflash.verification.verification_utils import get_test_file_path @@ -299,7 +298,7 @@ def generate_replay_test( test_framework=test_framework, max_run_count=max_run_count, ) - test_code = isort.code(test_code) + test_code = sort_imports(code=test_code) output_file = get_test_file_path( test_dir=Path(output_dir), function_name=benchmark_module_path, test_type="replay" ) diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index e05c70922..b3ee6e34e 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -5,13 +5,13 @@ from functools import lru_cache from typing import TYPE_CHECKING, Optional, TypeVar -import isort import libcst as cst from libcst.metadata import PositionProvider from codeflash.cli_cmds.console import logger from codeflash.code_utils.code_extractor import add_global_assignments, add_needed_imports_from_module from codeflash.code_utils.config_parser import find_conftest_files +from codeflash.code_utils.formatter import sort_imports from codeflash.code_utils.line_profile_utils import ImportAdder from codeflash.models.models import FunctionParent @@ -226,7 +226,7 @@ def add_custom_marker_to_all_tests(test_paths: list[Path]) -> None: module = cst.parse_module(file_content) importadder = ImportAdder("import pytest") modified_module = module.visit(importadder) - modified_module = cst.parse_module(isort.code(modified_module.code, float_to_top=True)) + modified_module = cst.parse_module(sort_imports(code=modified_module.code, float_to_top=True)) pytest_mark_adder = PytestMarkAdder("codeflash_no_autouse") modified_module = modified_module.visit(pytest_mark_adder) test_path.write_text(modified_module.code, encoding="utf-8") diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index eff9f4ed4..67fd45e17 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -166,11 +166,11 @@ def format_code( return formatted_code -def sort_imports(code: str) -> str: +def sort_imports(code: str, *, float_to_top: bool = False) -> str: try: # Deduplicate and sort imports, modify the code in memory, not on disk - sorted_code = isort.code(code) - except Exception: + sorted_code = isort.code(code=code, float_to_top=float_to_top) + except Exception: # this will also catch the FileSkipComment exception, use this fn everywhere logger.exception("Failed to sort imports with isort.") return code # Fall back to original code if isort fails diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index 7e9413fb4..ae3d82b57 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -10,6 +10,7 @@ from codeflash.cli_cmds.console import logger from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path +from codeflash.code_utils.formatter import sort_imports from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.models import FunctionParent, TestingMode, VerificationType @@ -1129,7 +1130,7 @@ def add_async_decorator_to_function( import_transformer = AsyncDecoratorImportAdder(mode) module = module.visit(import_transformer) - return isort.code(module.code, float_to_top=True), decorator_transformer.added_decorator + return sort_imports(code=module.code, float_to_top=True), decorator_transformer.added_decorator except Exception as e: logger.exception(f"Error adding async decorator to function {function.qualified_name}: {e}") return source_code, False diff --git a/codeflash/code_utils/line_profile_utils.py b/codeflash/code_utils/line_profile_utils.py index 8f8fdf661..27571dd0b 100644 --- a/codeflash/code_utils/line_profile_utils.py +++ b/codeflash/code_utils/line_profile_utils.py @@ -6,10 +6,10 @@ from pathlib import Path from typing import TYPE_CHECKING, Union -import isort import libcst as cst from codeflash.code_utils.code_utils import get_run_tmp_file +from codeflash.code_utils.formatter import sort_imports if TYPE_CHECKING: from codeflash.discovery.functions_to_optimize import FunctionToOptimize @@ -213,7 +213,7 @@ def add_decorator_imports(function_to_optimize: FunctionToOptimize, code_context transformer = ImportAdder("from line_profiler import profile as codeflash_line_profile") # Apply the transformer to add the import module_node = module_node.visit(transformer) - modified_code = isort.code(module_node.code, float_to_top=True) + modified_code = sort_imports(code=module_node.code, float_to_top=True) # write to file with file_path.open("w", encoding="utf-8") as file: file.write(modified_code) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index e54aac92d..723a6954d 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -12,7 +12,6 @@ from pathlib import Path from typing import TYPE_CHECKING -import isort import libcst as cst from rich.console import Group from rich.panel import Panel @@ -900,7 +899,7 @@ def reformat_code_and_helpers( optimized_context: CodeStringsMarkdown, ) -> tuple[str, dict[Path, str]]: should_sort_imports = not self.args.disable_imports_sorting - if should_sort_imports and isort.code(original_code) != original_code: + if should_sort_imports and sort_imports(code=original_code) != original_code: should_sort_imports = False optimized_code = "" diff --git a/codeflash/tracing/tracing_new_process.py b/codeflash/tracing/tracing_new_process.py index e17ff14e9..ec1794f09 100644 --- a/codeflash/tracing/tracing_new_process.py +++ b/codeflash/tracing/tracing_new_process.py @@ -265,6 +265,8 @@ def __exit__( # These modules have been imported here now the tracer is done. It is safe to import codeflash and external modules here + from contextlib import suppress + import isort from codeflash.tracing.replay_test import create_trace_replay_test @@ -280,7 +282,8 @@ def __exit__( test_file_path = get_test_file_path( test_dir=Path(self.config["tests_root"]), function_name=function_path, test_type="replay" ) - replay_test = isort.code(replay_test) + with suppress(Exception): + replay_test = isort.code(replay_test) with Path(test_file_path).open("w", encoding="utf8") as file: file.write(replay_test) diff --git a/codeflash/verification/instrument_codeflash_capture.py b/codeflash/verification/instrument_codeflash_capture.py index adcb66ef8..19a29013b 100644 --- a/codeflash/verification/instrument_codeflash_capture.py +++ b/codeflash/verification/instrument_codeflash_capture.py @@ -4,9 +4,8 @@ from pathlib import Path from typing import TYPE_CHECKING -import isort - from codeflash.code_utils.code_utils import get_run_tmp_file +from codeflash.code_utils.formatter import sort_imports if TYPE_CHECKING: from codeflash.discovery.functions_to_optimize import FunctionToOptimize @@ -70,7 +69,7 @@ def add_codeflash_capture_to_init( ast.fix_missing_locations(modified_tree) # Convert back to source code - return isort.code(code=ast.unparse(modified_tree), float_to_top=True) + return sort_imports(code=ast.unparse(modified_tree), float_to_top=True) class InitDecorator(ast.NodeTransformer): diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 40b9c3469..1703f572b 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -796,4 +796,12 @@ def _is_valid(self, item): optimization_function = """def process(self,data): '''Single quote docstring with formatting issues.''' return{'result':[item for item in data if self._is_valid(item)]}""" - _run_formatting_test(source_code, True, optimized_function=optimization_function, expected=expected) \ No newline at end of file + _run_formatting_test(source_code, True, optimized_function=optimization_function, expected=expected) + +def test_sort_imports_skip_file(): + """Test that isort skips files with # isort:skip_file.""" + code = """# isort:skip_file + +import sys, os, json # isort will ignore this file completely""" + new_code = sort_imports(code) + assert new_code == code