44import os
55import re
66from pathlib import Path
7- from textwrap import dedent
87from typing import TYPE_CHECKING
98
109import libcst as cst
10+ from libcst import MetadataWrapper
11+ from libcst .metadata import PositionProvider
1112
1213from codeflash .cli_cmds .console import logger
1314from codeflash .code_utils .time_utils import format_perf , format_time
1617
1718if TYPE_CHECKING :
1819 from codeflash .models .models import InvocationId
19- from codeflash .verification .verification_utils import TestConfig
2020
2121
22- def remove_functions_from_generated_tests (
23- generated_tests : GeneratedTestsList , test_functions_to_remove : list [str ]
24- ) -> GeneratedTestsList :
25- new_generated_tests = []
26- for generated_test in generated_tests .generated_tests :
27- for test_function in test_functions_to_remove :
28- function_pattern = re .compile (
29- rf"(@pytest\.mark\.parametrize\(.*?\)\s*)?def\s+{ re .escape (test_function )} \(.*?\):.*?(?=\ndef\s|$)" ,
30- re .DOTALL ,
31- )
32-
33- match = function_pattern .search (generated_test .generated_original_test_source )
34-
35- if match is None or "@pytest.mark.parametrize" in match .group (0 ):
36- continue
37-
38- generated_test .generated_original_test_source = function_pattern .sub (
39- "" , generated_test .generated_original_test_source
40- )
22+ class CommentMapper (ast .NodeVisitor ):
23+ def __init__ (
24+ self , test : GeneratedTests , original_runtimes : dict [str , int ], optimized_runtimes : dict [str , int ]
25+ ) -> None :
26+ self .results : dict [int , str ] = {}
27+ self .test : GeneratedTests = test
28+ self .original_runtimes = original_runtimes
29+ self .optimized_runtimes = optimized_runtimes
30+ self .abs_path = test .behavior_file_path .with_suffix ("" )
31+ self .context_stack : list [str ] = []
4132
42- new_generated_tests .append (generated_test )
33+ def visit_ClassDef (self , node : ast .ClassDef ) -> ast .ClassDef :
34+ self .context_stack .append (node .name )
35+ for inner_node in ast .walk (node ):
36+ if isinstance (inner_node , ast .FunctionDef ):
37+ self .visit_FunctionDef (inner_node )
38+ self .context_stack .pop ()
39+ return node
40+
41+ def get_comment (self , match_key : str ) -> str :
42+ # calculate speedup and output comment
43+ original_time = self .original_runtimes [match_key ]
44+ optimized_time = self .optimized_runtimes [match_key ]
45+ perf_gain = format_perf (
46+ abs (performance_gain (original_runtime_ns = original_time , optimized_runtime_ns = optimized_time ) * 100 )
47+ )
48+ status = "slower" if optimized_time > original_time else "faster"
49+ # Create the runtime comment
50+ return f"# { format_time (original_time )} -> { format_time (optimized_time )} ({ perf_gain } % { status } )"
51+
52+ def visit_FunctionDef (self , node : ast .FunctionDef ) -> ast .FunctionDef :
53+ self .context_stack .append (node .name )
54+ i = len (node .body ) - 1
55+ test_qualified_name = "." .join (self .context_stack )
56+ key = test_qualified_name + "#" + str (self .abs_path )
57+ while i >= 0 :
58+ line_node = node .body [i ]
59+ if isinstance (line_node , (ast .With , ast .For , ast .While , ast .If )):
60+ j = len (line_node .body ) - 1
61+ while j >= 0 :
62+ compound_line_node : ast .stmt = line_node .body [j ]
63+ internal_node : ast .AST
64+ for internal_node in ast .walk (compound_line_node ):
65+ if isinstance (internal_node , (ast .stmt , ast .Assign )):
66+ inv_id = str (i ) + "_" + str (j )
67+ match_key = key + "#" + inv_id
68+ if match_key in self .original_runtimes and match_key in self .optimized_runtimes :
69+ self .results [internal_node .lineno ] = self .get_comment (match_key )
70+ j -= 1
71+ else :
72+ inv_id = str (i )
73+ match_key = key + "#" + inv_id
74+ if match_key in self .original_runtimes and match_key in self .optimized_runtimes :
75+ self .results [line_node .lineno ] = self .get_comment (match_key )
76+ i -= 1
77+ self .context_stack .pop ()
78+ return node
4379
44- return GeneratedTestsList (generated_tests = new_generated_tests )
4580
def get_fn_call_linenos(
    test: GeneratedTests, original_runtimes: dict[str, int], optimized_runtimes: dict[str, int]
) -> dict[int, str]:
    """Return a mapping of line number to runtime comment for one generated test.

    Parses ``test.generated_original_test_source`` and runs a ``CommentMapper``
    over the resulting tree.
    """
    mapper = CommentMapper(test, original_runtimes, optimized_runtimes)
    mapper.visit(ast.parse(test.generated_original_test_source))
    return mapper.results
4689
47- class CfoVisitor (ast .NodeVisitor ):
48- """AST visitor that finds all assignments to a variable named 'codeflash_output'.
4990
50- and reports their location relative to the function they're in.
51- """
class CommentAdder(cst.CSTTransformer):
    """Transformer that adds comments to specified lines."""

    # Declare metadata dependencies
    METADATA_DEPENDENCIES = (PositionProvider,)

    def __init__(self, line_to_comments: dict[int, str]) -> None:
        """Initialize the transformer with target line numbers.

        Args:
            line_to_comments: Mapping of line numbers (1-indexed) to comments

        """
        self.line_to_comments = line_to_comments
        super().__init__()

    def _with_line_comment(
        self,
        original_node: cst.SimpleStatementLine | cst.SimpleStatementSuite,
        updated_node: cst.SimpleStatementLine | cst.SimpleStatementSuite,
    ) -> cst.SimpleStatementLine | cst.SimpleStatementSuite:
        """Return ``updated_node`` with the comment registered for its start line, if any.

        Position metadata is looked up on ``original_node`` and the change is
        applied to ``updated_node``.
        """
        pos = self.get_metadata(PositionProvider, original_node)
        if not pos or pos.start.line not in self.line_to_comments:
            return updated_node

        # Derive the new trailing whitespace from the existing one so the
        # statement keeps its original newline node. Only the spacing and the
        # comment are replaced.
        trailing = updated_node.trailing_whitespace.with_changes(
            whitespace=cst.SimpleWhitespace(" "), comment=cst.Comment(self.line_to_comments[pos.start.line])
        )
        return updated_node.with_changes(trailing_whitespace=trailing)

    def leave_SimpleStatementLine(
        self, original_node: cst.SimpleStatementLine, updated_node: cst.SimpleStatementLine
    ) -> cst.SimpleStatementLine:
        """Add comment to simple statement lines."""
        return self._with_line_comment(original_node, updated_node)

    def leave_SimpleStatementSuite(
        self, original_node: cst.SimpleStatementSuite, updated_node: cst.SimpleStatementSuite
    ) -> cst.SimpleStatementSuite:
        """Add comment to simple statement suites (e.g., after if/for/while)."""
        return self._with_line_comment(original_node, updated_node)
142140
143- def leave_SimpleStatementLine (
144- self ,
145- original_node : cst .SimpleStatementLine , # noqa: ARG002
146- updated_node : cst .SimpleStatementLine ,
147- ) -> cst .SimpleStatementLine :
148- # Check if this statement line contains a call to self.name
149- if self ._contains_myfunc_call (updated_node ): # type: ignore[no-untyped-call]
150- # Find matching test cases by looking for this test function name in the test results
151- self .cfo_idx_loc_to_look_at += 1
152- matching_original_times = []
153- matching_optimized_times = []
154- # TODO : will not work if there are multiple test cases with the same name, match filename + test class + test function name + invocationid
155- for invocation_id , runtimes in self .original_runtimes .items ():
156- # get position here and match in if condition
157- qualified_name = (
158- invocation_id .test_class_name + "." + invocation_id .test_function_name # type: ignore[operator]
159- if invocation_id .test_class_name
160- else invocation_id .test_function_name
161- )
162- abs_path = Path (invocation_id .test_module_path .replace ("." , os .sep )).with_suffix (".py" ).resolve ()
163- if (
164- qualified_name == "." .join (self .context_stack )
165- and abs_path in [self .test .behavior_file_path , self .test .perf_file_path ]
166- and int (invocation_id .iteration_id .split ("_" )[0 ]) == self .cfo_locs [self .cfo_idx_loc_to_look_at ] # type:ignore[union-attr]
167- ):
168- matching_original_times .extend (runtimes )
169-
170- for invocation_id , runtimes in self .optimized_runtimes .items ():
171- # get position here and match in if condition
172- qualified_name = (
173- invocation_id .test_class_name + "." + invocation_id .test_function_name # type: ignore[operator]
174- if invocation_id .test_class_name
175- else invocation_id .test_function_name
176- )
177- abs_path = Path (invocation_id .test_module_path .replace ("." , os .sep )).with_suffix (".py" ).resolve ()
178- if (
179- qualified_name == "." .join (self .context_stack )
180- and abs_path in [self .test .behavior_file_path , self .test .perf_file_path ]
181- and int (invocation_id .iteration_id .split ("_" )[0 ]) == self .cfo_locs [self .cfo_idx_loc_to_look_at ] # type:ignore[union-attr]
182- ):
183- matching_optimized_times .extend (runtimes )
184-
185- if matching_original_times and matching_optimized_times :
186- original_time = min (matching_original_times )
187- optimized_time = min (matching_optimized_times )
188- if original_time != 0 and optimized_time != 0 :
189- perf_gain = format_perf (
190- abs (
191- performance_gain (original_runtime_ns = original_time , optimized_runtime_ns = optimized_time )
192- * 100
193- )
194- )
195- status = "slower" if optimized_time > original_time else "faster"
196- # Create the runtime comment
197- comment_text = (
198- f"# { format_time (original_time )} -> { format_time (optimized_time )} ({ perf_gain } % { status } )"
199- )
200- return updated_node .with_changes (
201- trailing_whitespace = cst .TrailingWhitespace (
202- whitespace = cst .SimpleWhitespace (" " ),
203- comment = cst .Comment (comment_text ),
204- newline = updated_node .trailing_whitespace .newline ,
205- )
206- )
207- return updated_node
208141
209- def _contains_myfunc_call (self , node ): # type: ignore[no-untyped-def] # noqa : ANN202, ANN001
210- """Recursively search for any Call node in the statement whose function is named self.name (including obj.myfunc)."""
211- finder = Finder (self .name )
212- node .visit (finder )
213- return finder .found
def unique_inv_id(inv_id_runtimes: dict[InvocationId, list[int]]) -> dict[str, int]:
    """Collapse per-invocation runtimes into one total per lookup key.

    The key is ``<test name>#<resolved module path without suffix>#<id>``.
    ``<test name>`` is ``Class.function`` when the invocation has a class name,
    otherwise just the function name. ``<id>`` is the first ``_``-separated
    part of ``iteration_id`` when it has fewer than three parts, otherwise
    every part except the last.

    Args:
        inv_id_runtimes: Mapping of invocation id to the runtimes measured for it.

    Returns:
        Mapping of lookup key to the sum, over all invocations sharing that
        key, of each invocation's minimum runtime. Invocations whose resolved
        module path does not contain ``__unit_test_`` are ignored, as are
        invocations with no recorded runtimes.

    """
    unique_inv_ids: dict[str, int] = {}
    for inv_id, runtimes in inv_id_runtimes.items():
        if not runtimes:
            # min() raises ValueError on an empty sequence; there is nothing to add.
            continue
        test_qualified_name = (
            inv_id.test_class_name + "." + inv_id.test_function_name  # type: ignore[operator]
            if inv_id.test_class_name
            else inv_id.test_function_name
        )
        abs_path = str(Path(inv_id.test_module_path.replace(".", os.sep)).with_suffix(".py").resolve().with_suffix(""))
        if "__unit_test_" not in abs_path:
            continue
        key = test_qualified_name + "#" + abs_path  # type: ignore[operator]
        id_parts = inv_id.iteration_id.split("_")  # type: ignore[union-attr]
        cur_invid = id_parts[0] if len(id_parts) < 3 else "_".join(id_parts[:-1])
        match_key = key + "#" + cur_invid
        unique_inv_ids[match_key] = unique_inv_ids.get(match_key, 0) + min(runtimes)
    return unique_inv_ids
214161
215162
216163def add_runtime_comments_to_generated_tests (
217- qualified_name : str ,
218- test_cfg : TestConfig ,
219164 generated_tests : GeneratedTestsList ,
220165 original_runtimes : dict [InvocationId , list [int ]],
221166 optimized_runtimes : dict [InvocationId , list [int ]],
222167) -> GeneratedTestsList :
223168 """Add runtime performance comments to function calls in generated tests."""
224- tests_root = test_cfg . tests_root
225-
169+ original_runtimes_dict = unique_inv_id ( original_runtimes )
170+ optimized_runtimes_dict = unique_inv_id ( optimized_runtimes )
226171 # Process each generated test
227172 modified_tests = []
228173 for test in generated_tests .generated_tests :
229174 try :
230- # Parse the test source code
231175 tree = cst .parse_module (test .generated_original_test_source )
232- # Transform the tree to add runtime comments
233- # qualified_name: str, module: cst.Module, test: GeneratedTests, tests_root: Path
234- transformer = RuntimeCommentTransformer (
235- qualified_name , tree , test , tests_root , original_runtimes , optimized_runtimes
236- )
237- modified_tree = tree .visit (transformer )
238-
239- # Convert back to source code
176+ wrapper = MetadataWrapper (tree )
177+ line_to_comments = get_fn_call_linenos (test , original_runtimes_dict , optimized_runtimes_dict )
178+ comment_adder = CommentAdder (line_to_comments )
179+ modified_tree = wrapper .visit (comment_adder )
240180 modified_source = modified_tree .code
241-
242- # Create a new GeneratedTests object with the modified source
243181 modified_test = GeneratedTests (
244182 generated_original_test_source = modified_source ,
245183 instrumented_behavior_test_source = test .instrumented_behavior_test_source ,
@@ -254,3 +192,28 @@ def add_runtime_comments_to_generated_tests(
254192 modified_tests .append (test )
255193
256194 return GeneratedTestsList (generated_tests = modified_tests )
195+
196+
def remove_functions_from_generated_tests(
    generated_tests: GeneratedTestsList, test_functions_to_remove: list[str]
) -> GeneratedTestsList:
    """Strip the named test functions from each generated test's source.

    For every function name, the source is searched for a ``def <name>(...)``
    block that runs up to the next line starting with ``def`` or the end of
    the source. If the first match includes a ``@pytest.mark.parametrize``
    decorator, that name is skipped for this test. Otherwise every match of
    the pattern is removed.

    Args:
        generated_tests: Container whose ``generated_tests`` items are edited.
            Each item's ``generated_original_test_source`` is modified in place.
        test_functions_to_remove: Names of the test functions to delete.

    Returns:
        A new ``GeneratedTestsList`` holding the same (mutated) test objects.

    """
    # Patterns depend only on the function name, so compile them once up
    # front instead of once per (test, function) pair.
    function_patterns = [
        re.compile(
            # ``\s*`` before the parenthesis accepts both ``def name(`` and ``def name (``.
            rf"(@pytest\.mark\.parametrize\(.*?\)\s*)?def\s+{re.escape(test_function)}\s*\(.*?\):.*?(?=\ndef\s|$)",
            re.DOTALL,
        )
        for test_function in test_functions_to_remove
    ]

    new_generated_tests = []
    for generated_test in generated_tests.generated_tests:
        for function_pattern in function_patterns:
            source = generated_test.generated_original_test_source
            match = function_pattern.search(source)

            # Leave the source alone when the function is absent or parametrized.
            if match is None or "@pytest.mark.parametrize" in match.group(0):
                continue

            generated_test.generated_original_test_source = function_pattern.sub("", source)

        new_generated_tests.append(generated_test)

    return GeneratedTestsList(generated_tests=new_generated_tests)
0 commit comments