Skip to content

Commit 4b41ab7

Browse files
fix: unit tests
1 parent 5981d75 commit 4b41ab7

File tree

3 files changed

+20
-14
lines changed

3 files changed

+20
-14
lines changed

tests/test_code_replacement.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -797,7 +797,8 @@ def main_method(self):
797797

798798

799799
def test_code_replacement10() -> None:
800-
get_code_output = """from __future__ import annotations
800+
get_code_output = """# file: test_code_replacement.py
801+
from __future__ import annotations
801802
802803
class HelperClass:
803804
def __init__(self, name):
@@ -827,7 +828,7 @@ def main_method(self):
827828
)
828829
func_optimizer = FunctionOptimizer(function_to_optimize=func_top_optimize, test_cfg=test_config)
829830
code_context = func_optimizer.get_code_optimization_context().unwrap()
830-
assert code_context.testgen_context.rstrip() == get_code_output.rstrip()
831+
assert code_context.testgen_context.flat.rstrip() == get_code_output.rstrip()
831832

832833

833834
def test_code_replacement11() -> None:

tests/test_function_dependencies.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,9 @@ def test_class_method_dependencies() -> None:
160160
)
161161
assert code_context.helper_functions[0].qualified_name == "Graph.topologicalSortUtil"
162162
assert (
163-
code_context.testgen_context
164-
== """from collections import defaultdict
163+
code_context.testgen_context.flat
164+
== """# file: test_function_dependencies.py
165+
from collections import defaultdict
165166
166167
class Graph:
167168
def __init__(self, vertices):
@@ -220,8 +221,9 @@ def test_recursive_function_context() -> None:
220221
assert code_context.helper_functions[0].fully_qualified_name == "test_function_dependencies.C.calculate_something_3"
221222
assert code_context.helper_functions[1].fully_qualified_name == "test_function_dependencies.C.recursive"
222223
assert (
223-
code_context.testgen_context
224-
== """class C:
224+
code_context.testgen_context.flat
225+
== """# file: test_function_dependencies.py
226+
class C:
225227
def calculate_something_3(self, num):
226228
return num + 1
227229

tests/test_get_helper_code.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -241,8 +241,9 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
241241
code_context = ctx_result.unwrap()
242242
assert code_context.helper_functions[0].qualified_name == "AbstractCacheBackend.get_cache_or_call"
243243
assert (
244-
code_context.testgen_context
245-
== f'''_P = ParamSpec("_P")
244+
code_context.testgen_context.flat
245+
== f'''# file: {file_path.relative_to(project_root_path)}
246+
_P = ParamSpec("_P")
246247
_KEY_T = TypeVar("_KEY_T")
247248
_STORE_T = TypeVar("_STORE_T")
248249
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
@@ -394,10 +395,11 @@ def test_bubble_sort_deps() -> None:
394395
function_to_optimize = FunctionToOptimize(
395396
function_name="sorter_deps", file_path=file_path, parents=[], starting_line=None, ending_line=None
396397
)
398+
project_root = file_path.parent.parent.resolve()
397399
test_config = TestConfig(
398400
tests_root=str(file_path.parent / "tests"),
399401
tests_project_rootdir=file_path.parent.resolve(),
400-
project_root_path=file_path.parent.parent.resolve(),
402+
project_root_path=project_root,
401403
test_framework="pytest",
402404
pytest_cmd="pytest",
403405
)
@@ -409,19 +411,20 @@ def test_bubble_sort_deps() -> None:
409411
pytest.fail()
410412
code_context = ctx_result.unwrap()
411413
assert (
412-
code_context.testgen_context
413-
== """from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer
414-
from code_to_optimize.bubble_sort_dep2_swap import dep2_swap
415-
414+
code_context.testgen_context.flat
415+
== f"""# file: code_to_optimize/bubble_sort_dep1_helper.py
416416
def dep1_comparer(arr, j: int) -> bool:
417417
return arr[j] > arr[j + 1]
418418
419+
# file: code_to_optimize/bubble_sort_dep2_swap.py
419420
def dep2_swap(arr, j):
420421
temp = arr[j]
421422
arr[j] = arr[j + 1]
422423
arr[j + 1] = temp
423424
424-
425+
# file: code_to_optimize/bubble_sort_deps.py
426+
from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer
427+
from code_to_optimize.bubble_sort_dep2_swap import dep2_swap
425428
426429
def sorter_deps(arr):
427430
for i in range(len(arr)):

0 commit comments

Comments
 (0)