44import os
55import re
66from pathlib import Path
7- from typing import TYPE_CHECKING
7+ from typing import TYPE_CHECKING , Optional
88
99import libcst as cst
1010from libcst import MetadataWrapper
@@ -149,18 +149,19 @@ def leave_SimpleStatementSuite(
149149 return updated_node
150150
151151
152- def unique_inv_id (inv_id_runtimes : dict [InvocationId , list [int ]]) -> dict [str , int ]:
152+ def unique_inv_id (inv_id_runtimes : dict [InvocationId , list [int ]], tests_project_rootdir : Path ) -> dict [str , int ]:
153153 unique_inv_ids : dict [str , int ] = {}
154154 for inv_id , runtimes in inv_id_runtimes .items ():
155155 test_qualified_name = (
156156 inv_id .test_class_name + "." + inv_id .test_function_name # type: ignore[operator]
157157 if inv_id .test_class_name
158158 else inv_id .test_function_name
159159 )
160- abs_path = str (Path (inv_id .test_module_path .replace ("." , os .sep )).with_suffix (".py" ).resolve ().with_suffix ("" ))
161- if "__unit_test_" not in abs_path :
160+ abs_path = tests_project_rootdir / Path (inv_id .test_module_path .replace ("." , os .sep )).with_suffix (".py" )
161+ abs_path_str = str (abs_path .resolve ().with_suffix ("" ))
162+ if "__unit_test_" not in abs_path_str or not test_qualified_name :
162163 continue
163- key = test_qualified_name + "#" + abs_path # type: ignore[operator]
164+ key = test_qualified_name + "#" + abs_path_str
164165 parts = inv_id .iteration_id .split ("_" ).__len__ () # type: ignore[union-attr]
165166 cur_invid = inv_id .iteration_id .split ("_" )[0 ] if parts < 3 else "_" .join (inv_id .iteration_id .split ("_" )[:- 1 ]) # type: ignore[union-attr]
166167 match_key = key + "#" + cur_invid
@@ -174,10 +175,11 @@ def add_runtime_comments_to_generated_tests(
174175 generated_tests : GeneratedTestsList ,
175176 original_runtimes : dict [InvocationId , list [int ]],
176177 optimized_runtimes : dict [InvocationId , list [int ]],
178+ tests_project_rootdir : Optional [Path ] = None ,
177179) -> GeneratedTestsList :
178180 """Add runtime performance comments to function calls in generated tests."""
179- original_runtimes_dict = unique_inv_id (original_runtimes )
180- optimized_runtimes_dict = unique_inv_id (optimized_runtimes )
181+ original_runtimes_dict = unique_inv_id (original_runtimes , tests_project_rootdir or Path () )
182+ optimized_runtimes_dict = unique_inv_id (optimized_runtimes , tests_project_rootdir or Path () )
181183 # Process each generated test
182184 modified_tests = []
183185 for test in generated_tests .generated_tests :
0 commit comments