Skip to content

Commit 6c1c2a4

Browse files
committed
tests
1 parent e353f38 commit 6c1c2a4

File tree

1 file changed

+63
-56
lines changed

1 file changed

+63
-56
lines changed

tests/test_code_replacement.py

Lines changed: 63 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -2630,13 +2630,12 @@ def test_performance(benchmark):
26302630
def test_normal():
26312631
assert True
26322632
"""
2633+
expected = """def test_normal():
2634+
assert True"""
26332635
tree = ast.parse(source)
26342636
result = self.remover.visit(tree)
26352637

2636-
# Should only have one function left
2637-
functions = [node for node in result.body if isinstance(node, ast.FunctionDef)]
2638-
assert len(functions) == 1
2639-
assert functions[0].name == "test_normal"
2638+
assert ast.unparse(result) == expected
26402639

26412640
def test_removes_async_function_with_benchmark_parameter(self):
26422641
"""Test that async functions with 'benchmark' parameter are removed."""
@@ -2648,13 +2647,13 @@ async def test_async_performance(benchmark):
26482647
async def test_async_normal():
26492648
assert True
26502649
"""
2650+
expected = """async def test_async_normal():
2651+
assert True"""
26512652
tree = ast.parse(source)
26522653
result = self.remover.visit(tree)
26532654

26542655
# Should only have one async function left
2655-
functions = [node for node in result.body if isinstance(node, ast.AsyncFunctionDef)]
2656-
assert len(functions) == 1
2657-
assert functions[0].name == "test_async_normal"
2656+
assert ast.unparse(result) == expected
26582657

26592658
def test_removes_function_with_pytest_mark_benchmark_decorator(self):
26602659
"""Test that functions with @pytest.mark.benchmark decorator are removed."""
@@ -2668,13 +2667,14 @@ def test_with_benchmark_marker():
26682667
def test_normal():
26692668
pass
26702669
"""
2670+
expected = """import pytest
2671+
2672+
def test_normal():
2673+
pass"""
26712674
tree = ast.parse(source)
26722675
result = self.remover.visit(tree)
26732676

2674-
# Should have import and one function
2675-
functions = [node for node in result.body if isinstance(node, ast.FunctionDef)]
2676-
assert len(functions) == 1
2677-
assert functions[0].name == "test_normal"
2677+
assert ast.unparse(result) == expected
26782678

26792679
def test_removes_function_with_benchmark_decorator_call(self):
26802680
"""Test that functions with @pytest.mark.benchmark() decorator are removed."""
@@ -2689,12 +2689,15 @@ def test_with_benchmark_marker_call():
26892689
def test_normal_with_marker():
26902690
pass
26912691
"""
2692+
expected = """import pytest
2693+
2694+
@pytest.mark.parametrize('x', [1, 2, 3])
2695+
def test_normal_with_marker():
2696+
pass"""
26922697
tree = ast.parse(source)
26932698
result = self.remover.visit(tree)
26942699

2695-
functions = [node for node in result.body if isinstance(node, ast.FunctionDef)]
2696-
assert len(functions) == 1
2697-
assert functions[0].name == "test_normal_with_marker"
2700+
assert ast.unparse(result) == expected
26982701

26992702
def test_removes_function_with_simple_benchmark_decorator(self):
27002703
"""Test that functions with @benchmark decorator are removed."""
@@ -2706,12 +2709,12 @@ def test_simple_benchmark():
27062709
def test_normal():
27072710
pass
27082711
"""
2712+
expected = """def test_normal():
2713+
pass"""
27092714
tree = ast.parse(source)
27102715
result = self.remover.visit(tree)
27112716

2712-
functions = [node for node in result.body if isinstance(node, ast.FunctionDef)]
2713-
assert len(functions) == 1
2714-
assert functions[0].name == "test_normal"
2717+
assert ast.unparse(result) == expected
27152718

27162719
def test_removes_function_with_benchmark_call_in_body(self):
27172720
"""Test that functions calling benchmark() in body are removed."""
@@ -2724,12 +2727,13 @@ def test_normal():
27242727
some_other_function()
27252728
assert True
27262729
"""
2730+
expected = """def test_normal():
2731+
some_other_function()
2732+
assert True"""
27272733
tree = ast.parse(source)
27282734
result = self.remover.visit(tree)
27292735

2730-
functions = [node for node in result.body if isinstance(node, ast.FunctionDef)]
2731-
assert len(functions) == 1
2732-
assert functions[0].name == "test_normal"
2736+
assert ast.unparse(result) == expected
27332737

27342738
def test_removes_benchmark_methods_from_class(self):
27352739
"""Test that benchmark methods are removed from classes."""
@@ -2746,13 +2750,11 @@ def test_benchmark_method(self, benchmark):
27462750
def test_decorated_benchmark(self):
27472751
pass
27482752
"""
2753+
expected = """class TestClass:\n \n def test_normal_method(self):\n assert True"""
27492754
tree = ast.parse(source)
27502755
result = self.remover.visit(tree)
27512756

2752-
class_node = result.body[0]
2753-
methods = [node for node in class_node.body if isinstance(node, ast.FunctionDef)]
2754-
assert len(methods) == 1
2755-
assert methods[0].name == "test_normal_method"
2757+
assert ast.dump(result) == ast.dump(ast.parse(expected))
27562758

27572759
def test_preserves_non_benchmark_functions(self):
27582760
"""Test that non-benchmark functions are preserved."""
@@ -2763,18 +2765,26 @@ def test_normal_function():
27632765
def helper_function(param1, param2):
27642766
return param1 + param2
27652767
2768+
@pytest.mark.parametrize("x", [1, 2, 3])
2769+
def test_parametrized(x):
2770+
assert x > 0
2771+
"""
2772+
expected = """
2773+
def test_normal_function():
2774+
assert True
2775+
2776+
def helper_function(param1, param2):
2777+
return param1 + param2
2778+
27662779
@pytest.mark.parametrize("x", [1, 2, 3])
27672780
def test_parametrized(x):
27682781
assert x > 0
27692782
"""
27702783
tree = ast.parse(source)
2771-
original_functions = [node.name for node in tree.body if isinstance(node, ast.FunctionDef)]
27722784

27732785
result = self.remover.visit(tree)
2774-
result_functions = [node.name for node in result.body if isinstance(node, ast.FunctionDef)]
27752786

2776-
assert len(result_functions) == 3
2777-
assert set(result_functions) == set(original_functions)
2787+
assert ast.dump(result) == ast.dump(ast.parse(expected))
27782788

27792789
def test_handles_empty_class(self):
27802790
"""Test handling of classes that become empty after removing benchmark methods."""
@@ -2784,11 +2794,11 @@ class TestBenchmarks:
27842794
def test_only_benchmark(self):
27852795
pass
27862796
"""
2797+
expected = """class TestBenchmarks:"""
27872798
tree = ast.parse(source)
27882799
result = self.remover.visit(tree)
27892800

2790-
class_node = result.body[0]
2791-
assert len(class_node.body) == 0
2801+
assert ast.unparse(result) == expected
27922802

27932803
def test_handles_mixed_decorators(self):
27942804
"""Test functions with multiple decorators including benchmark."""
@@ -2802,12 +2812,13 @@ def test_multiple_decorators(x):
28022812
def test_normal_with_decorator(y):
28032813
pass
28042814
"""
2815+
expected = """@pytest.mark.parametrize('y', [3, 4])
2816+
def test_normal_with_decorator(y):
2817+
pass"""
28052818
tree = ast.parse(source)
28062819
result = self.remover.visit(tree)
28072820

2808-
functions = [node for node in result.body if isinstance(node, ast.FunctionDef)]
2809-
assert len(functions) == 1
2810-
assert functions[0].name == "test_normal_with_decorator"
2821+
assert ast.unparse(result) == expected
28112822

28122823

28132824
class TestRemoveBenchmarkFunctions:
@@ -2822,13 +2833,15 @@ def test_normal():
28222833
def test_benchmark(benchmark):
28232834
result = benchmark(some_function)
28242835
assert result
2836+
"""
2837+
expected = """
2838+
def test_normal():
2839+
assert True
28252840
"""
28262841
tree = ast.parse(source)
28272842
result = remove_benchmark_functions(tree)
28282843

2829-
functions = [node for node in result.body if isinstance(node, ast.FunctionDef)]
2830-
assert len(functions) == 1
2831-
assert functions[0].name == "test_normal"
2844+
assert ast.dump(result) == ast.dump(ast.parse(expected))
28322845

28332846
def test_remove_benchmark_functions_handles_exception(self, capsys):
28342847
"""Test that exceptions are handled gracefully."""
@@ -2884,30 +2897,24 @@ def standalone_function():
28842897
async def test_async_benchmark():
28852898
await some_async_function()
28862899
"""
2887-
tree = ast.parse(source)
2888-
result = remove_benchmark_functions(tree)
2889-
2890-
# Check that imports and standalone function are preserved
2891-
imports = [node for node in result.body if isinstance(node, (ast.Import, ast.ImportFrom))]
2892-
assert len(imports) == 2
2900+
expected = """
2901+
import pytest
2902+
from some_module import some_function
28932903
2894-
# Check class methods
2895-
class_node = [node for node in result.body if isinstance(node, ast.ClassDef)][0]
2896-
methods = [node for node in class_node.body if isinstance(node, ast.FunctionDef)]
2897-
method_names = [method.name for method in methods]
2898-
assert "setup_method" in method_names
2899-
assert "test_normal_operation" in method_names
2900-
assert "test_benchmark_operation" not in method_names
2901-
assert "test_with_benchmark_param" not in method_names
2904+
class TestPerformance:
2905+
def setup_method(self):
2906+
self.data = [1, 2, 3, 4, 5]
29022907
2903-
# Check standalone function is preserved
2904-
functions = [node for node in result.body if isinstance(node, ast.FunctionDef)]
2905-
assert any(func.name == "standalone_function" for func in functions)
2908+
def test_normal_operation(self):
2909+
assert len(self.data) == 5
29062910
2907-
# Check async benchmark function is removed
2908-
async_functions = [node for node in result.body if isinstance(node, ast.AsyncFunctionDef)]
2909-
assert len(async_functions) == 0
2911+
def standalone_function():
2912+
return "not a test"
2913+
"""
2914+
tree = ast.parse(source)
2915+
result = remove_benchmark_functions(tree)
29102916

2917+
assert ast.dump(result) == ast.dump(ast.parse(expected))
29112918

29122919
class TestBenchmarkDetectionMethods:
29132920
"""Test the individual detection methods."""

0 commit comments

Comments
 (0)