1+ import os
12import re
3+ from pathlib import Path
24
35import libcst as cst
46
57from codeflash .cli_cmds .console import logger
6- from codeflash .code_utils .time_utils import format_time
7- from codeflash .models .models import GeneratedTests , GeneratedTestsList , TestResults
8+ from codeflash .code_utils .time_utils import format_perf , format_time
9+ from codeflash .models .models import GeneratedTests , GeneratedTestsList , InvocationId
10+ from codeflash .result .critic import performance_gain
11+ from codeflash .verification .verification_utils import TestConfig
812
913
1014def remove_functions_from_generated_tests (
@@ -33,40 +37,46 @@ def remove_functions_from_generated_tests(
3337
3438
3539def add_runtime_comments_to_generated_tests (
36- generated_tests : GeneratedTestsList , original_test_results : TestResults , optimized_test_results : TestResults
40+ test_cfg : TestConfig ,
41+ generated_tests : GeneratedTestsList ,
42+ original_runtimes : dict [InvocationId , list [int ]],
43+ optimized_runtimes : dict [InvocationId , list [int ]],
3744) -> GeneratedTestsList :
3845 """Add runtime performance comments to function calls in generated tests."""
39- # Create dictionaries for fast lookup of runtime data
40- original_runtime_by_test = original_test_results . usable_runtime_data_by_test_case ()
41- optimized_runtime_by_test = optimized_test_results . usable_runtime_data_by_test_case ( )
46+ tests_root = test_cfg . tests_root
47+ module_root = test_cfg . project_root_path
48+ rel_tests_root = tests_root . relative_to ( module_root )
4249
50+ # TODO: reduce for loops to one
4351 class RuntimeCommentTransformer (cst .CSTTransformer ):
44- def __init__ (self ) -> None :
45- self .in_test_function = False
46- self .current_test_name : str | None = None
52+ def __init__ (self , test : GeneratedTests , tests_root : Path , rel_tests_root : Path ) -> None :
53+ self .test = test
54+ self .context_stack : list [str ] = []
55+ self .tests_root = tests_root
56+ self .rel_tests_root = rel_tests_root
57+
58+ def visit_ClassDef (self , node : cst .ClassDef ) -> None :
59+ # Track when we enter a class
60+ self .context_stack .append (node .name .value )
61+
62+ def leave_ClassDef (self , original_node : cst .ClassDef , updated_node : cst .ClassDef ) -> cst .ClassDef : # noqa: ARG002
63+ # Pop the context when we leave a class
64+ self .context_stack .pop ()
65+ return updated_node
4766
4867 def visit_FunctionDef (self , node : cst .FunctionDef ) -> None :
49- if node .name .value .startswith ("test_" ):
50- self .in_test_function = True
51- self .current_test_name = node .name .value
52- else :
53- self .in_test_function = False
54- self .current_test_name = None
55-
56- def leave_FunctionDef (self , original_node : cst .FunctionDef , updated_node : cst .FunctionDef ) -> cst .FunctionDef :
57- if original_node .name .value .startswith ("test_" ):
58- self .in_test_function = False
59- self .current_test_name = None
68+ self .context_stack .append (node .name .value )
69+
70+ def leave_FunctionDef (self , original_node : cst .FunctionDef , updated_node : cst .FunctionDef ) -> cst .FunctionDef : # noqa: ARG002
71+ # Pop the context when we leave a function
72+ self .context_stack .pop ()
6073 return updated_node
6174
6275 def leave_SimpleStatementLine (
6376 self ,
6477 original_node : cst .SimpleStatementLine , # noqa: ARG002
6578 updated_node : cst .SimpleStatementLine ,
6679 ) -> cst .SimpleStatementLine :
67- if not self .in_test_function or not self .current_test_name :
68- return updated_node
69-
7080 # Look for assignment statements that assign to codeflash_output
7181 # Handle both single statements and multiple statements on one line
7282 codeflash_assignment_found = False
@@ -83,30 +93,65 @@ def leave_SimpleStatementLine(
8393 # Find matching test cases by looking for this test function name in the test results
8494 matching_original_times = []
8595 matching_optimized_times = []
86-
87- for invocation_id , runtimes in original_runtime_by_test .items ():
88- if invocation_id .test_function_name == self .current_test_name :
96+ # TODO : will not work if there are multiple test cases with the same name, match filename + test class + test function name
97+ for invocation_id , runtimes in original_runtimes .items ():
98+ qualified_name = (
99+ invocation_id .test_class_name + "." + invocation_id .test_function_name # type: ignore[operator]
100+ if invocation_id .test_class_name
101+ else invocation_id .test_function_name
102+ )
103+ rel_path = (
104+ Path (invocation_id .test_module_path .replace ("." , os .sep ))
105+ .with_suffix (".py" )
106+ .relative_to (self .rel_tests_root )
107+ )
108+ if qualified_name == "." .join (self .context_stack ) and rel_path in [
109+ self .test .behavior_file_path .relative_to (self .tests_root ),
110+ self .test .perf_file_path .relative_to (self .tests_root ),
111+ ]:
89112 matching_original_times .extend (runtimes )
90113
91- for invocation_id , runtimes in optimized_runtime_by_test .items ():
92- if invocation_id .test_function_name == self .current_test_name :
114+ for invocation_id , runtimes in optimized_runtimes .items ():
115+ qualified_name = (
116+ invocation_id .test_class_name + "." + invocation_id .test_function_name # type: ignore[operator]
117+ if invocation_id .test_class_name
118+ else invocation_id .test_function_name
119+ )
120+ rel_path = (
121+ Path (invocation_id .test_module_path .replace ("." , os .sep ))
122+ .with_suffix (".py" )
123+ .relative_to (self .rel_tests_root )
124+ )
125+ if qualified_name == "." .join (self .context_stack ) and rel_path in [
126+ self .test .behavior_file_path .relative_to (self .tests_root ),
127+ self .test .perf_file_path .relative_to (self .tests_root ),
128+ ]:
93129 matching_optimized_times .extend (runtimes )
94130
95131 if matching_original_times and matching_optimized_times :
96132 original_time = min (matching_original_times )
97133 optimized_time = min (matching_optimized_times )
98-
99- # Create the runtime comment
100- comment_text = f"# { format_time (original_time )} -> { format_time (optimized_time )} "
101-
102- # Add comment to the trailing whitespace
103- new_trailing_whitespace = cst .TrailingWhitespace (
104- whitespace = cst .SimpleWhitespace (" " ),
105- comment = cst .Comment (comment_text ),
106- newline = updated_node .trailing_whitespace .newline ,
107- )
108-
109- return updated_node .with_changes (trailing_whitespace = new_trailing_whitespace )
134+ if original_time != 0 and optimized_time != 0 :
135+ perf_gain = format_perf (
136+ abs (
137+ performance_gain (original_runtime_ns = original_time , optimized_runtime_ns = optimized_time )
138+ * 100
139+ )
140+ )
141+ status = "slower" if optimized_time > original_time else "faster"
142+ # Create the runtime comment
143+ comment_text = (
144+ f"# { format_time (original_time )} -> { format_time (optimized_time )} ({ perf_gain } % { status } )"
145+ )
146+
147+ # Add comment to the trailing whitespace
148+ new_trailing_whitespace = cst .TrailingWhitespace (
149+ whitespace = cst .SimpleWhitespace (" " ),
150+ comment = cst .Comment (comment_text ),
151+ newline = updated_node .trailing_whitespace .newline ,
152+ )
153+
154+ return updated_node .with_changes (trailing_whitespace = new_trailing_whitespace )
110155
111156 return updated_node
112157
@@ -118,7 +163,7 @@ def leave_SimpleStatementLine(
118163 tree = cst .parse_module (test .generated_original_test_source )
119164
120165 # Transform the tree to add runtime comments
121- transformer = RuntimeCommentTransformer ()
166+ transformer = RuntimeCommentTransformer (test , tests_root , rel_tests_root )
122167 modified_tree = tree .visit (transformer )
123168
124169 # Convert back to source code
0 commit comments