1+ import os
12import re
3+ from pathlib import Path
24
35import libcst as cst
46
57from codeflash .cli_cmds .console import logger
68from codeflash .code_utils .time_utils import format_time
79from codeflash .models .models import GeneratedTests , GeneratedTestsList
10+ from codeflash .verification .verification_utils import TestConfig
811
912
1013def remove_functions_from_generated_tests (
@@ -33,38 +36,43 @@ def remove_functions_from_generated_tests(
3336
3437
3538def add_runtime_comments_to_generated_tests (
36- generated_tests : GeneratedTestsList , original_runtimes : dict , optimized_runtimes : dict
39+ test_cfg : TestConfig , generated_tests : GeneratedTestsList , original_runtimes : dict , optimized_runtimes : dict
3740) -> GeneratedTestsList :
3841 """Add runtime performance comments to function calls in generated tests."""
42+ tests_root = test_cfg .tests_root
43+ module_root = test_cfg .project_root_path
44+ rel_tests_root = tests_root .relative_to (module_root )
3945
4046 # TODO: reduce for loops to one
4147 class RuntimeCommentTransformer (cst .CSTTransformer ):
42- def __init__ (self ) -> None :
43- self .in_test_function = False
44- self .current_test_name : str | None = None
48+ def __init__ (self , test : GeneratedTests , tests_root : Path , rel_tests_root : Path ) -> None :
49+ self .test = test
50+ self .context_stack = []
51+ self .tests_root = tests_root
52+ self .rel_tests_root = rel_tests_root
53+
54+ def visit_ClassDef (self , node : cst .ClassDef ) -> None :
55+ # Track when we enter a class
56+ self .context_stack .append (node .name .value )
57+
58+ def leave_ClassDef (self , original_node : cst .ClassDef , updated_node : cst .ClassDef ) -> cst .ClassDef : # noqa: ARG002
59+ # Pop the context when we leave a class
60+ self .context_stack .pop ()
61+ return updated_node
4562
4663 def visit_FunctionDef (self , node : cst .FunctionDef ) -> None :
47- if node .name .value .startswith ("test_" ):
48- self .in_test_function = True
49- self .current_test_name = node .name .value
50- else :
51- self .in_test_function = False
52- self .current_test_name = None
53-
54- def leave_FunctionDef (self , original_node : cst .FunctionDef , updated_node : cst .FunctionDef ) -> cst .FunctionDef :
55- if original_node .name .value .startswith ("test_" ):
56- self .in_test_function = False
57- self .current_test_name = None
64+ self .context_stack .append (node .name .value )
65+
66+ def leave_FunctionDef (self , original_node : cst .FunctionDef , updated_node : cst .FunctionDef ) -> cst .FunctionDef : # noqa: ARG002
67+ # Pop the context when we leave a function
68+ self .context_stack .pop ()
5869 return updated_node
5970
6071 def leave_SimpleStatementLine (
6172 self ,
6273 original_node : cst .SimpleStatementLine , # noqa: ARG002
6374 updated_node : cst .SimpleStatementLine ,
6475 ) -> cst .SimpleStatementLine :
65- if not self .in_test_function or not self .current_test_name :
66- return updated_node
67-
6876 # Look for assignment statements that assign to codeflash_output
6977 # Handle both single statements and multiple statements on one line
7078 codeflash_assignment_found = False
@@ -83,11 +91,37 @@ def leave_SimpleStatementLine(
8391 matching_optimized_times = []
8492 # TODO : will not work if there are multiple test cases with the same name, match filename + test class + test function name
8593 for invocation_id , runtimes in original_runtimes .items ():
86- if invocation_id .test_function_name == self .current_test_name :
94+ qualified_name = (
95+ invocation_id .test_class_name + "." + invocation_id .test_function_name
96+ if invocation_id .test_class_name
97+ else invocation_id .test_function_name
98+ )
99+ rel_path = (
100+ Path (invocation_id .test_module_path .replace ("." , os .sep ))
101+ .with_suffix (".py" )
102+ .relative_to (self .rel_tests_root )
103+ )
104+ if qualified_name == "." .join (self .context_stack ) and rel_path in [
105+ self .test .behavior_file_path .relative_to (self .tests_root ),
106+ self .test .perf_file_path .relative_to (self .tests_root ),
107+ ]:
87108 matching_original_times .extend (runtimes )
88109
89110 for invocation_id , runtimes in optimized_runtimes .items ():
90- if invocation_id .test_function_name == self .current_test_name :
111+ qualified_name = (
112+ invocation_id .test_class_name + "." + invocation_id .test_function_name
113+ if invocation_id .test_class_name
114+ else invocation_id .test_function_name
115+ )
116+ rel_path = (
117+ Path (invocation_id .test_module_path .replace ("." , os .sep ))
118+ .with_suffix (".py" )
119+ .relative_to (self .rel_tests_root )
120+ )
121+ if qualified_name == "." .join (self .context_stack ) and rel_path in [
122+ self .test .behavior_file_path .relative_to (self .tests_root ),
123+ self .test .perf_file_path .relative_to (self .tests_root ),
124+ ]:
91125 matching_optimized_times .extend (runtimes )
92126
93127 if matching_original_times and matching_optimized_times :
@@ -116,7 +150,7 @@ def leave_SimpleStatementLine(
116150 tree = cst .parse_module (test .generated_original_test_source )
117151
118152 # Transform the tree to add runtime comments
119- transformer = RuntimeCommentTransformer ()
153+ transformer = RuntimeCommentTransformer (test , tests_root , rel_tests_root )
120154 modified_tree = tree .visit (transformer )
121155
122156 # Convert back to source code
0 commit comments