Skip to content

Commit eae756a

Browse files
committed
Clarified docstring for get_modification_code_ranges
... and started adding tests for that function.
1 parent 6dd72cf commit eae756a

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

codeflash/code_utils/formatter.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,15 +60,14 @@ def sort_imports(code: str) -> str:
6060

6161
return sorted_code
6262

63-
# TODO(zomglings): Write unit tests.
6463
def get_modification_code_ranges(
6564
modified_code: str,
6665
fto: FunctionToOptimize,
6766
preexisting_functions: set[tuple[str, tuple[FunctionParent,...]]],
6867
helper_functions: list[FunctionSource],
6968
) -> list[tuple[int, int]]:
7069
"""
71-
Returns the line number of modified and new functions in a string containing containing the code in a fully modified file.
70+
Returns the starting and ending line numbers of modified and new functions in a file containing edits.
7271
"""
7372
modified_functions = set()
7473
modified_functions.add(fto.qualified_name)

tests/test_formatter.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
import pytest
66

77
from codeflash.code_utils.config_parser import parse_config_file
8-
from codeflash.code_utils.formatter import format_code, sort_imports
8+
from codeflash.code_utils.formatter import format_code, get_modification_code_ranges, sort_imports
9+
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
910

1011

1112
def test_remove_duplicate_imports():
@@ -209,3 +210,15 @@ def foo():
209210
tmp_path = tmp.name
210211
with pytest.raises(FileNotFoundError):
211212
format_code(formatter_cmds=["exit 1"], path=Path(tmp_path))
213+
214+
def test_get_modification_code_ranges_self_contained_fto():
215+
modified_code = """
216+
def hello(name):
217+
print(f"Hello, {{name}}")
218+
"""
219+
220+
fto = FunctionToOptimize(function_name="hello", file_path=Path("hello.py"), parents=[])
221+
code_ranges = get_modification_code_ranges(modified_code, fto, set(), [])
222+
223+
assert len(code_ranges) == 1
224+
assert code_ranges[0] == (2, 3)

0 commit comments

Comments
 (0)