55import pytest
66
77from codeflash .code_utils .config_parser import parse_config_file
8- from codeflash .code_utils .formatter import format_code , sort_imports
8+ from codeflash .code_utils .formatter import format_code , format_code_in_memory
99
1010
1111def test_remove_duplicate_imports ():
1212 """Test that duplicate imports are removed when should_sort_imports is True."""
1313 original_code = "import os\n import os\n "
14- new_code = sort_imports (original_code )
14+ new_code = format_code_in_memory (original_code , imports_only = True )
1515 assert new_code == "import os\n "
1616
1717
1818def test_remove_multiple_duplicate_imports ():
1919 """Test that multiple duplicate imports are removed when should_sort_imports is True."""
2020 original_code = "import sys\n import os\n import sys\n "
2121
22- new_code = sort_imports (original_code )
22+ new_code = format_code_in_memory (original_code , imports_only = True )
2323 assert new_code == "import os\n import sys\n "
2424
2525
2626def test_sorting_imports ():
2727 """Test that imports are sorted when should_sort_imports is True."""
2828 original_code = "import sys\n import unittest\n import os\n "
2929
30- new_code = sort_imports (original_code )
30+ new_code = format_code_in_memory (original_code , imports_only = True )
3131 assert new_code == "import os\n import sys\n import unittest\n "
3232
3333
@@ -40,7 +40,7 @@ def test_sort_imports_without_formatting():
4040
4141 new_code = format_code (formatter_cmds = ["disabled" ], path = tmp_path )
4242 assert new_code is not None
43- new_code = sort_imports (new_code )
43+ new_code = format_code_in_memory (new_code , imports_only = True )
4444 assert new_code == "import os\n import sys\n import unittest\n "
4545
4646
@@ -63,7 +63,7 @@ def foo():
6363 return os.path.join(sys.path[0], 'bar')
6464"""
6565
66- actual = sort_imports (original_code )
66+ actual = format_code_in_memory (original_code , imports_only = True )
6767
6868 assert actual == expected
6969
@@ -90,7 +90,7 @@ def foo():
9090 return os.path.join(sys.path[0], 'bar')
9191"""
9292
93- actual = sort_imports (original_code )
93+ actual = format_code_in_memory (original_code , imports_only = True )
9494
9595 assert actual == expected
9696
0 commit comments