55import pytest
66
77from codeflash .code_utils .config_parser import parse_config_file
8- from codeflash .code_utils .formatter import format_code , sort_imports
8+ from codeflash .code_utils .formatter import format_code , format_code_in_memory
99
1010
1111def test_remove_duplicate_imports ():
1212 """Test that duplicate imports are removed when should_sort_imports is True."""
1313 original_code = "import os\n import os\n "
14- new_code = sort_imports (original_code )
14+ new_code = format_code_in_memory (original_code )
1515 assert new_code == "import os\n "
1616
17-
1817def test_remove_multiple_duplicate_imports ():
1918 """Test that multiple duplicate imports are removed when should_sort_imports is True."""
2019 original_code = "import sys\n import os\n import sys\n "
2120
22- new_code = sort_imports (original_code )
21+ new_code = format_code_in_memory (original_code )
2322 assert new_code == "import os\n import sys\n "
2423
25-
2624def test_sorting_imports ():
2725 """Test that imports are sorted when should_sort_imports is True."""
2826 original_code = "import sys\n import unittest\n import os\n "
2927
30- new_code = sort_imports (original_code )
28+ new_code = format_code_in_memory (original_code )
3129 assert new_code == "import os\n import sys\n import unittest\n "
3230
33-
3431def test_sort_imports_without_formatting ():
3532 """Test that imports are sorted when formatting is disabled and should_sort_imports is True."""
3633 with tempfile .NamedTemporaryFile () as tmp :
37- tmp .write (b"import sys\n import unittest\n import os\n " )
34+ tmp .write (b"import sys\n import unittest\n import os\n import sys \n import unittest \ n " )
3835 tmp .flush ()
3936 tmp_path = Path (tmp .name )
4037
4138 new_code = format_code (formatter_cmds = ["disabled" ], path = tmp_path )
4239 assert new_code is not None
43- new_code = sort_imports (new_code )
40+ new_code = format_code_in_memory (new_code )
4441 assert new_code == "import os\n import sys\n import unittest\n "
4542
46-
4743def test_dedup_and_sort_imports_deduplicates ():
4844 original_code = """
4945import os
@@ -54,16 +50,15 @@ def foo():
5450 return os.path.join(sys.path[0], 'bar')
5551"""
5652
57- expected = """
58- import os
53+ expected = """import os
5954import sys
6055
6156
6257def foo():
63- return os.path.join(sys.path[0], ' bar' )
58+ return os.path.join(sys.path[0], " bar" )
6459"""
6560
66- actual = sort_imports (original_code )
61+ actual = format_code_in_memory (original_code )
6762
6863 assert actual == expected
6964
@@ -80,17 +75,16 @@ def foo():
8075 return os.path.join(sys.path[0], 'bar')
8176"""
8277
83- expected = """
84- import json
78+ expected = """import json
8579import os
8680import sys
8781
8882
8983def foo():
90- return os.path.join(sys.path[0], ' bar' )
84+ return os.path.join(sys.path[0], " bar" )
9185"""
9286
93- actual = sort_imports (original_code )
87+ actual = format_code_in_memory (original_code )
9488
9589 assert actual == expected
9690
0 commit comments