@@ -2630,13 +2630,12 @@ def test_performance(benchmark):
26302630def test_normal():
26312631 assert True
26322632"""
2633+ expected = """def test_normal():
2634+ assert True"""
26332635 tree = ast .parse (source )
26342636 result = self .remover .visit (tree )
26352637
2636- # Should only have one function left
2637- functions = [node for node in result .body if isinstance (node , ast .FunctionDef )]
2638- assert len (functions ) == 1
2639- assert functions [0 ].name == "test_normal"
2638+ assert ast .unparse (result ) == expected
26402639
26412640 def test_removes_async_function_with_benchmark_parameter (self ):
26422641 """Test that async functions with 'benchmark' parameter are removed."""
@@ -2648,13 +2647,13 @@ async def test_async_performance(benchmark):
26482647async def test_async_normal():
26492648 assert True
26502649"""
2650+ expected = """async def test_async_normal():
2651+ assert True"""
26512652 tree = ast .parse (source )
26522653 result = self .remover .visit (tree )
26532654
26542655 # Should only have one async function left
2655- functions = [node for node in result .body if isinstance (node , ast .AsyncFunctionDef )]
2656- assert len (functions ) == 1
2657- assert functions [0 ].name == "test_async_normal"
2656+ assert ast .unparse (result ) == expected
26582657
26592658 def test_removes_function_with_pytest_mark_benchmark_decorator (self ):
26602659 """Test that functions with @pytest.mark.benchmark decorator are removed."""
@@ -2668,13 +2667,14 @@ def test_with_benchmark_marker():
26682667def test_normal():
26692668 pass
26702669"""
2670+ expected = """import pytest
2671+
2672+ def test_normal():
2673+ pass"""
26712674 tree = ast .parse (source )
26722675 result = self .remover .visit (tree )
26732676
2674- # Should have import and one function
2675- functions = [node for node in result .body if isinstance (node , ast .FunctionDef )]
2676- assert len (functions ) == 1
2677- assert functions [0 ].name == "test_normal"
2677+ assert ast .unparse (result ) == expected
26782678
26792679 def test_removes_function_with_benchmark_decorator_call (self ):
26802680 """Test that functions with @pytest.mark.benchmark() decorator are removed."""
@@ -2689,12 +2689,15 @@ def test_with_benchmark_marker_call():
26892689def test_normal_with_marker():
26902690 pass
26912691"""
2692+ expected = """import pytest
2693+
2694+ @pytest.mark.parametrize('x', [1, 2, 3])
2695+ def test_normal_with_marker():
2696+ pass"""
26922697 tree = ast .parse (source )
26932698 result = self .remover .visit (tree )
26942699
2695- functions = [node for node in result .body if isinstance (node , ast .FunctionDef )]
2696- assert len (functions ) == 1
2697- assert functions [0 ].name == "test_normal_with_marker"
2700+ assert ast .unparse (result ) == expected
26982701
26992702 def test_removes_function_with_simple_benchmark_decorator (self ):
27002703 """Test that functions with @benchmark decorator are removed."""
@@ -2706,12 +2709,12 @@ def test_simple_benchmark():
27062709def test_normal():
27072710 pass
27082711"""
2712+ expected = """def test_normal():
2713+ pass"""
27092714 tree = ast .parse (source )
27102715 result = self .remover .visit (tree )
27112716
2712- functions = [node for node in result .body if isinstance (node , ast .FunctionDef )]
2713- assert len (functions ) == 1
2714- assert functions [0 ].name == "test_normal"
2717+ assert ast .unparse (result ) == expected
27152718
27162719 def test_removes_function_with_benchmark_call_in_body (self ):
27172720 """Test that functions calling benchmark() in body are removed."""
@@ -2724,12 +2727,13 @@ def test_normal():
27242727 some_other_function()
27252728 assert True
27262729"""
2730+ expected = """def test_normal():
2731+ some_other_function()
2732+ assert True"""
27272733 tree = ast .parse (source )
27282734 result = self .remover .visit (tree )
27292735
2730- functions = [node for node in result .body if isinstance (node , ast .FunctionDef )]
2731- assert len (functions ) == 1
2732- assert functions [0 ].name == "test_normal"
2736+ assert ast .unparse (result ) == expected
27332737
27342738 def test_removes_benchmark_methods_from_class (self ):
27352739 """Test that benchmark methods are removed from classes."""
@@ -2746,13 +2750,11 @@ def test_benchmark_method(self, benchmark):
27462750 def test_decorated_benchmark(self):
27472751 pass
27482752"""
2753+ expected = """class TestClass:\n \n def test_normal_method(self):\n assert True"""
27492754 tree = ast .parse (source )
27502755 result = self .remover .visit (tree )
27512756
2752- class_node = result .body [0 ]
2753- methods = [node for node in class_node .body if isinstance (node , ast .FunctionDef )]
2754- assert len (methods ) == 1
2755- assert methods [0 ].name == "test_normal_method"
2757+ assert ast .dump (result ) == ast .dump (ast .parse (expected ))
27562758
27572759 def test_preserves_non_benchmark_functions (self ):
27582760 """Test that non-benchmark functions are preserved."""
@@ -2763,18 +2765,26 @@ def test_normal_function():
27632765def helper_function(param1, param2):
27642766 return param1 + param2
27652767
2768+ @pytest.mark.parametrize("x", [1, 2, 3])
2769+ def test_parametrized(x):
2770+ assert x > 0
2771+ """
2772+ expected = """
2773+ def test_normal_function():
2774+ assert True
2775+
2776+ def helper_function(param1, param2):
2777+ return param1 + param2
2778+
27662779@pytest.mark.parametrize("x", [1, 2, 3])
27672780def test_parametrized(x):
27682781 assert x > 0
27692782"""
27702783 tree = ast .parse (source )
2771- original_functions = [node .name for node in tree .body if isinstance (node , ast .FunctionDef )]
27722784
27732785 result = self .remover .visit (tree )
2774- result_functions = [node .name for node in result .body if isinstance (node , ast .FunctionDef )]
27752786
2776- assert len (result_functions ) == 3
2777- assert set (result_functions ) == set (original_functions )
2787+ assert ast .dump (result ) == ast .dump (ast .parse (expected ))
27782788
27792789 def test_handles_empty_class (self ):
27802790 """Test handling of classes that become empty after removing benchmark methods."""
@@ -2784,11 +2794,11 @@ class TestBenchmarks:
27842794 def test_only_benchmark(self):
27852795 pass
27862796"""
2797+ expected = """class TestBenchmarks:"""
27872798 tree = ast .parse (source )
27882799 result = self .remover .visit (tree )
27892800
2790- class_node = result .body [0 ]
2791- assert len (class_node .body ) == 0
2801+ assert ast .unparse (result ) == expected
27922802
27932803 def test_handles_mixed_decorators (self ):
27942804 """Test functions with multiple decorators including benchmark."""
@@ -2802,12 +2812,13 @@ def test_multiple_decorators(x):
28022812def test_normal_with_decorator(y):
28032813 pass
28042814"""
2815+ expected = """@pytest.mark.parametrize('y', [3, 4])
2816+ def test_normal_with_decorator(y):
2817+ pass"""
28052818 tree = ast .parse (source )
28062819 result = self .remover .visit (tree )
28072820
2808- functions = [node for node in result .body if isinstance (node , ast .FunctionDef )]
2809- assert len (functions ) == 1
2810- assert functions [0 ].name == "test_normal_with_decorator"
2821+ assert ast .unparse (result ) == expected
28112822
28122823
28132824class TestRemoveBenchmarkFunctions :
@@ -2822,13 +2833,15 @@ def test_normal():
28222833def test_benchmark(benchmark):
28232834 result = benchmark(some_function)
28242835 assert result
2836+ """
2837+ expected = """
2838+ def test_normal():
2839+ assert True
28252840"""
28262841 tree = ast .parse (source )
28272842 result = remove_benchmark_functions (tree )
28282843
2829- functions = [node for node in result .body if isinstance (node , ast .FunctionDef )]
2830- assert len (functions ) == 1
2831- assert functions [0 ].name == "test_normal"
2844+ assert ast .dump (result ) == ast .dump (ast .parse (expected ))
28322845
28332846 def test_remove_benchmark_functions_handles_exception (self , capsys ):
28342847 """Test that exceptions are handled gracefully."""
@@ -2884,30 +2897,24 @@ def standalone_function():
28842897async def test_async_benchmark():
28852898 await some_async_function()
28862899"""
2887- tree = ast .parse (source )
2888- result = remove_benchmark_functions (tree )
2889-
2890- # Check that imports and standalone function are preserved
2891- imports = [node for node in result .body if isinstance (node , (ast .Import , ast .ImportFrom ))]
2892- assert len (imports ) == 2
2900+ expected = """
2901+ import pytest
2902+ from some_module import some_function
28932903
2894- # Check class methods
2895- class_node = [node for node in result .body if isinstance (node , ast .ClassDef )][0 ]
2896- methods = [node for node in class_node .body if isinstance (node , ast .FunctionDef )]
2897- method_names = [method .name for method in methods ]
2898- assert "setup_method" in method_names
2899- assert "test_normal_operation" in method_names
2900- assert "test_benchmark_operation" not in method_names
2901- assert "test_with_benchmark_param" not in method_names
2904+ class TestPerformance:
2905+ def setup_method(self):
2906+ self.data = [1, 2, 3, 4, 5]
29022907
2903- # Check standalone function is preserved
2904- functions = [node for node in result .body if isinstance (node , ast .FunctionDef )]
2905- assert any (func .name == "standalone_function" for func in functions )
2908+ def test_normal_operation(self):
2909+ assert len(self.data) == 5
29062910
2907- # Check async benchmark function is removed
2908- async_functions = [node for node in result .body if isinstance (node , ast .AsyncFunctionDef )]
2909- assert len (async_functions ) == 0
2911+ def standalone_function():
2912+ return "not a test"
2913+ """
2914+ tree = ast .parse (source )
2915+ result = remove_benchmark_functions (tree )
29102916
2917+ assert ast .dump (result ) == ast .dump (ast .parse (expected ))
29112918
29122919class TestBenchmarkDetectionMethods :
29132920 """Test the individual detection methods."""
0 commit comments