1+ from __future__ import annotations
2+
3+ import ast
14import os
25import re
36from pathlib import Path
7+ from textwrap import dedent
8+ from typing import TYPE_CHECKING , Union
49
510import libcst as cst
611
712from codeflash .cli_cmds .console import logger
813from codeflash .code_utils .time_utils import format_perf , format_time
9- from codeflash .models .models import GeneratedTests , GeneratedTestsList , InvocationId
14+ from codeflash .models .models import GeneratedTests , GeneratedTestsList
1015from codeflash .result .critic import performance_gain
11- from codeflash .verification .verification_utils import TestConfig
16+
17+ if TYPE_CHECKING :
18+ from codeflash .models .models import InvocationId
19+ from codeflash .verification .verification_utils import TestConfig
1220
1321
1422def remove_functions_from_generated_tests (
@@ -36,6 +44,94 @@ def remove_functions_from_generated_tests(
3644 return GeneratedTestsList (generated_tests = new_generated_tests )
3745
3846
47+ class CfoVisitor (ast .NodeVisitor ):
48+ """AST visitor that finds all assignments to a variable named 'codeflash_output'.
49+
50+ and reports their location relative to the function they're in.
51+ """
52+
53+ def __init__ (self , source_code : str ) -> None :
54+ self .source_lines = source_code .splitlines ()
55+ self .results : list [int ] = [] # map actual line number to line number in ast
56+
57+ def _is_codeflash_output_target (self , target : Union [ast .expr , list ]) -> bool : # type: ignore[type-arg]
58+ """Check if the assignment target is the variable 'codeflash_output'."""
59+ if isinstance (target , ast .Name ):
60+ return target .id == "codeflash_output"
61+ if isinstance (target , (ast .Tuple , ast .List )):
62+ # Handle tuple/list unpacking: a, codeflash_output, b = values
63+ return any (self ._is_codeflash_output_target (elt ) for elt in target .elts )
64+ if isinstance (target , (ast .Subscript , ast .Attribute )):
65+ # Not a simple variable assignment
66+ return False
67+ return False
68+
69+ def _record_assignment (self , node : ast .AST ) -> None :
70+ """Record an assignment to codeflash_output."""
71+ relative_line = node .lineno - 1 # type: ignore[attr-defined]
72+ self .results .append (relative_line )
73+
74+ def visit_Assign (self , node : ast .Assign ) -> None :
75+ """Visit assignment statements: codeflash_output = value."""
76+ for target in node .targets :
77+ if self ._is_codeflash_output_target (target ):
78+ self ._record_assignment (node )
79+ break
80+ self .generic_visit (node )
81+
82+ def visit_AnnAssign (self , node : ast .AnnAssign ) -> None :
83+ """Visit annotated assignments: codeflash_output: int = value."""
84+ if self ._is_codeflash_output_target (node .target ):
85+ self ._record_assignment (node )
86+ self .generic_visit (node )
87+
88+ def visit_AugAssign (self , node : ast .AugAssign ) -> None :
89+ """Visit augmented assignments: codeflash_output += value."""
90+ if self ._is_codeflash_output_target (node .target ):
91+ self ._record_assignment (node )
92+ self .generic_visit (node )
93+
94+ def visit_NamedExpr (self , node : ast .NamedExpr ) -> None :
95+ """Visit walrus operator: (codeflash_output := value)."""
96+ if isinstance (node .target , ast .Name ) and node .target .id == "codeflash_output" :
97+ self ._record_assignment (node )
98+ self .generic_visit (node )
99+
100+ def visit_For (self , node : ast .For ) -> None :
101+ """Visit for loops: for codeflash_output in iterable."""
102+ if self ._is_codeflash_output_target (node .target ):
103+ self ._record_assignment (node )
104+ self .generic_visit (node )
105+
106+ def visit_comprehension (self , node : ast .comprehension ) -> None :
107+ """Visit comprehensions: [x for codeflash_output in iterable]."""
108+ if self ._is_codeflash_output_target (node .target ):
109+ # Comprehensions don't have line numbers, so we skip recording
110+ pass
111+ self .generic_visit (node )
112+
113+ def visit_With (self , node : ast .With ) -> None :
114+ """Visit with statements: with expr as codeflash_output."""
115+ for item in node .items :
116+ if item .optional_vars and self ._is_codeflash_output_target (item .optional_vars ):
117+ self ._record_assignment (node )
118+ break
119+ self .generic_visit (node )
120+
121+ def visit_ExceptHandler (self , node : ast .ExceptHandler ) -> None :
122+ """Visit except handlers: except Exception as codeflash_output."""
123+ if node .name == "codeflash_output" :
124+ self ._record_assignment (node )
125+ self .generic_visit (node )
126+
127+
128+ def find_codeflash_output_assignments (source_code : str ) -> list [int ]:
129+ tree = ast .parse (source_code )
130+ visitor = CfoVisitor (source_code )
131+ visitor .visit (tree )
132+ return visitor .results
133+
134+
39135def add_runtime_comments_to_generated_tests (
40136 test_cfg : TestConfig ,
41137 generated_tests : GeneratedTestsList ,
@@ -49,11 +145,15 @@ def add_runtime_comments_to_generated_tests(
49145
50146 # TODO: reduce for loops to one
51147 class RuntimeCommentTransformer (cst .CSTTransformer ):
52- def __init__ (self , test : GeneratedTests , tests_root : Path , rel_tests_root : Path ) -> None :
148+ def __init__ (self , module : cst .Module , test : GeneratedTests , tests_root : Path , rel_tests_root : Path ) -> None :
149+ super ().__init__ ()
53150 self .test = test
54151 self .context_stack : list [str ] = []
55152 self .tests_root = tests_root
56153 self .rel_tests_root = rel_tests_root
154+ self .module = module
155+ self .cfo_locs : list [int ] = []
156+ self .cfo_idx_loc_to_look_at : int = - 1
57157
58158 def visit_ClassDef (self , node : cst .ClassDef ) -> None :
59159 # Track when we enter a class
@@ -65,6 +165,13 @@ def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef
65165 return updated_node
66166
67167 def visit_FunctionDef (self , node : cst .FunctionDef ) -> None :
168+ # convert function body to ast normalized string and find occurrences of codeflash_output
169+ body_code = dedent (self .module .code_for_node (node .body ))
170+ normalized_body_code = ast .unparse (ast .parse (body_code ))
171+ self .cfo_locs = sorted (
172+ find_codeflash_output_assignments (normalized_body_code )
173+ ) # sorted in order we will encounter them
174+ self .cfo_idx_loc_to_look_at = - 1
68175 self .context_stack .append (node .name .value )
69176
70177 def leave_FunctionDef (self , original_node : cst .FunctionDef , updated_node : cst .FunctionDef ) -> cst .FunctionDef : # noqa: ARG002
@@ -91,10 +198,12 @@ def leave_SimpleStatementLine(
91198
92199 if codeflash_assignment_found :
93200 # Find matching test cases by looking for this test function name in the test results
201+ self .cfo_idx_loc_to_look_at += 1
94202 matching_original_times = []
95203 matching_optimized_times = []
96- # TODO : will not work if there are multiple test cases with the same name, match filename + test class + test function name
204+ # TODO : will not work if there are multiple test cases with the same name, match filename + test class + test function name + invocationid
97205 for invocation_id , runtimes in original_runtimes .items ():
206+ # get position here and match in if condition
98207 qualified_name = (
99208 invocation_id .test_class_name + "." + invocation_id .test_function_name # type: ignore[operator]
100209 if invocation_id .test_class_name
@@ -105,13 +214,19 @@ def leave_SimpleStatementLine(
105214 .with_suffix (".py" )
106215 .relative_to (self .rel_tests_root )
107216 )
108- if qualified_name == "." .join (self .context_stack ) and rel_path in [
109- self .test .behavior_file_path .relative_to (self .tests_root ),
110- self .test .perf_file_path .relative_to (self .tests_root ),
111- ]:
217+ if (
218+ qualified_name == "." .join (self .context_stack )
219+ and rel_path
220+ in [
221+ self .test .behavior_file_path .relative_to (self .tests_root ),
222+ self .test .perf_file_path .relative_to (self .tests_root ),
223+ ]
224+ and int (invocation_id .iteration_id .split ("_" )[0 ]) == self .cfo_locs [self .cfo_idx_loc_to_look_at ] # type:ignore[union-attr]
225+ ):
112226 matching_original_times .extend (runtimes )
113227
114228 for invocation_id , runtimes in optimized_runtimes .items ():
229+ # get position here and match in if condition
115230 qualified_name = (
116231 invocation_id .test_class_name + "." + invocation_id .test_function_name # type: ignore[operator]
117232 if invocation_id .test_class_name
@@ -122,10 +237,15 @@ def leave_SimpleStatementLine(
122237 .with_suffix (".py" )
123238 .relative_to (self .rel_tests_root )
124239 )
125- if qualified_name == "." .join (self .context_stack ) and rel_path in [
126- self .test .behavior_file_path .relative_to (self .tests_root ),
127- self .test .perf_file_path .relative_to (self .tests_root ),
128- ]:
240+ if (
241+ qualified_name == "." .join (self .context_stack )
242+ and rel_path
243+ in [
244+ self .test .behavior_file_path .relative_to (self .tests_root ),
245+ self .test .perf_file_path .relative_to (self .tests_root ),
246+ ]
247+ and int (invocation_id .iteration_id .split ("_" )[0 ]) == self .cfo_locs [self .cfo_idx_loc_to_look_at ] # type:ignore[union-attr]
248+ ):
129249 matching_optimized_times .extend (runtimes )
130250
131251 if matching_original_times and matching_optimized_times :
@@ -161,9 +281,8 @@ def leave_SimpleStatementLine(
161281 try :
162282 # Parse the test source code
163283 tree = cst .parse_module (test .generated_original_test_source )
164-
165284 # Transform the tree to add runtime comments
166- transformer = RuntimeCommentTransformer (test , tests_root , rel_tests_root )
285+ transformer = RuntimeCommentTransformer (tree , test , tests_root , rel_tests_root )
167286 modified_tree = tree .visit (transformer )
168287
169288 # Convert back to source code
0 commit comments