3636 N_TESTS_TO_GENERATE ,
3737 TOTAL_LOOPING_TIME ,
3838)
39+ from codeflash .code_utils .edit_generated_tests import remove_functions_from_generated_tests
3940from codeflash .code_utils .formatter import format_code , sort_imports
4041from codeflash .code_utils .instrument_existing_tests import inject_profiling_into_existing_test
4142from codeflash .code_utils .line_profile_utils import add_decorator_imports
42- from codeflash .code_utils .remove_generated_tests import remove_functions_from_generated_tests
4343from codeflash .code_utils .static_analysis import get_first_top_level_function_or_method_ast
4444from codeflash .code_utils .time_utils import humanize_runtime
4545from codeflash .context import code_context_extractor
@@ -265,10 +265,6 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
265265 },
266266 )
267267
268- generated_tests = remove_functions_from_generated_tests (
269- generated_tests = generated_tests , test_functions_to_remove = test_functions_to_remove
270- )
271-
272268 if best_optimization :
273269 logger .info ("Best candidate:" )
274270 code_print (best_optimization .candidate .source_code )
@@ -295,8 +291,6 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
295291 benchmark_details = processed_benchmark_info .benchmark_details if processed_benchmark_info else None ,
296292 )
297293
298- self .log_successful_optimization (explanation , generated_tests , exp_type )
299-
300294 self .replace_function_and_helpers_with_optimized_code (
301295 code_context = code_context , optimized_code = best_optimization .candidate .source_code
302296 )
@@ -321,6 +315,15 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
321315 if original_code_baseline .coverage_results
322316 else "Coverage data not available"
323317 )
318+ generated_tests = remove_functions_from_generated_tests (
319+ generated_tests = generated_tests , test_functions_to_remove = test_functions_to_remove
320+ )
321+ # Add runtime comments to generated tests before creating the PR
322+ generated_tests = self .add_runtime_comments_to_generated_tests (
323+ generated_tests ,
324+ original_code_baseline .benchmarking_test_results ,
325+ best_optimization .winning_benchmarking_test_results ,
326+ )
324327 generated_tests_str = "\n \n " .join (
325328 [test .generated_original_test_source for test in generated_tests .generated_tests ]
326329 )
@@ -345,6 +348,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
345348 original_helper_code ,
346349 self .function_to_optimize .file_path ,
347350 )
351+ self .log_successful_optimization (explanation , generated_tests , exp_type )
348352
349353 if not best_optimization :
350354 return Failure (f"No best optimizations found for function { self .function_to_optimize .qualified_name } " )
@@ -1266,3 +1270,154 @@ def cleanup_generated_files(self) -> None:
12661270 cleanup_paths (paths_to_cleanup )
12671271 if hasattr (get_run_tmp_file , "tmpdir" ):
12681272 get_run_tmp_file .tmpdir .cleanup ()
1273+
1274+ def add_runtime_comments_to_generated_tests (
1275+ self ,
1276+ generated_tests : GeneratedTestsList ,
1277+ original_test_results : TestResults ,
1278+ optimized_test_results : TestResults ,
1279+ ) -> GeneratedTestsList :
1280+ """Add runtime performance comments to function calls in generated tests."""
1281+
1282+ def format_time (nanoseconds : int ) -> str :
1283+ """Format nanoseconds into a human-readable string with 3 significant digits when needed."""
1284+
1285+ def count_significant_digits (num : int ) -> int :
1286+ """Count significant digits in an integer."""
1287+ return len (str (abs (num )))
1288+
1289+ def format_with_precision (value : float , unit : str ) -> str :
1290+ """Format a value with 3 significant digits precision."""
1291+ if value >= 100 :
1292+ return f"{ value :.0f} { unit } "
1293+ if value >= 10 :
1294+ return f"{ value :.1f} { unit } "
1295+ return f"{ value :.2f} { unit } "
1296+
1297+ if nanoseconds < 1_000 :
1298+ return f"{ nanoseconds } ns"
1299+ if nanoseconds < 1_000_000 :
1300+ # Convert to microseconds
1301+ microseconds_int = nanoseconds // 1_000
1302+ if count_significant_digits (microseconds_int ) >= 3 :
1303+ return f"{ microseconds_int } μs"
1304+ microseconds_float = nanoseconds / 1_000
1305+ return format_with_precision (microseconds_float , "μs" )
1306+ if nanoseconds < 1_000_000_000 :
1307+ # Convert to milliseconds
1308+ milliseconds_int = nanoseconds // 1_000_000
1309+ if count_significant_digits (milliseconds_int ) >= 3 :
1310+ return f"{ milliseconds_int } ms"
1311+ milliseconds_float = nanoseconds / 1_000_000
1312+ return format_with_precision (milliseconds_float , "ms" )
1313+ # Convert to seconds
1314+ seconds_int = nanoseconds // 1_000_000_000
1315+ if count_significant_digits (seconds_int ) >= 3 :
1316+ return f"{ seconds_int } s"
1317+ seconds_float = nanoseconds / 1_000_000_000
1318+ return format_with_precision (seconds_float , "s" )
1319+
1320+ # Create dictionaries for fast lookup of runtime data
1321+ original_runtime_by_test = original_test_results .usable_runtime_data_by_test_case ()
1322+ optimized_runtime_by_test = optimized_test_results .usable_runtime_data_by_test_case ()
1323+
1324+ class RuntimeCommentTransformer (cst .CSTTransformer ):
1325+ def __init__ (self ):
1326+ self .in_test_function = False
1327+ self .current_test_name = None
1328+
1329+ def visit_FunctionDef (self , node : cst .FunctionDef ) -> None :
1330+ if node .name .value .startswith ("test_" ):
1331+ self .in_test_function = True
1332+ self .current_test_name = node .name .value
1333+ else :
1334+ self .in_test_function = False
1335+ self .current_test_name = None
1336+
1337+ def leave_FunctionDef (
1338+ self , original_node : cst .FunctionDef , updated_node : cst .FunctionDef
1339+ ) -> cst .FunctionDef :
1340+ if original_node .name .value .startswith ("test_" ):
1341+ self .in_test_function = False
1342+ self .current_test_name = None
1343+ return updated_node
1344+
1345+ def leave_SimpleStatementLine (
1346+ self , original_node : cst .SimpleStatementLine , updated_node : cst .SimpleStatementLine
1347+ ) -> cst .SimpleStatementLine :
1348+ if not self .in_test_function or not self .current_test_name :
1349+ return updated_node
1350+
1351+ # Look for assignment statements that assign to codeflash_output
1352+ # Handle both single statements and multiple statements on one line
1353+ codeflash_assignment_found = False
1354+ for stmt in updated_node .body :
1355+ if isinstance (stmt , cst .Assign ):
1356+ if (
1357+ len (stmt .targets ) == 1
1358+ and isinstance (stmt .targets [0 ].target , cst .Name )
1359+ and stmt .targets [0 ].target .value == "codeflash_output"
1360+ ):
1361+ codeflash_assignment_found = True
1362+ break
1363+
1364+ if codeflash_assignment_found :
1365+ # Find matching test cases by looking for this test function name in the test results
1366+ matching_original_times = []
1367+ matching_optimized_times = []
1368+
1369+ for invocation_id , runtimes in original_runtime_by_test .items ():
1370+ if invocation_id .test_function_name == self .current_test_name :
1371+ matching_original_times .extend (runtimes )
1372+
1373+ for invocation_id , runtimes in optimized_runtime_by_test .items ():
1374+ if invocation_id .test_function_name == self .current_test_name :
1375+ matching_optimized_times .extend (runtimes )
1376+
1377+ if matching_original_times and matching_optimized_times :
1378+ original_time = min (matching_original_times )
1379+ optimized_time = min (matching_optimized_times )
1380+
1381+ # Create the runtime comment
1382+ comment_text = f"# { format_time (original_time )} -> { format_time (optimized_time )} "
1383+
1384+ # Add comment to the trailing whitespace
1385+ new_trailing_whitespace = cst .TrailingWhitespace (
1386+ whitespace = cst .SimpleWhitespace (" " ),
1387+ comment = cst .Comment (comment_text ),
1388+ newline = updated_node .trailing_whitespace .newline ,
1389+ )
1390+
1391+ return updated_node .with_changes (trailing_whitespace = new_trailing_whitespace )
1392+
1393+ return updated_node
1394+
1395+ # Process each generated test
1396+ modified_tests = []
1397+ for test in generated_tests .generated_tests :
1398+ try :
1399+ # Parse the test source code
1400+ tree = cst .parse_module (test .generated_original_test_source )
1401+
1402+ # Transform the tree to add runtime comments
1403+ transformer = RuntimeCommentTransformer ()
1404+ modified_tree = tree .visit (transformer )
1405+
1406+ # Convert back to source code
1407+ modified_source = modified_tree .code
1408+
1409+ # Create a new GeneratedTests object with the modified source
1410+ modified_test = GeneratedTests (
1411+ generated_original_test_source = modified_source ,
1412+ instrumented_behavior_test_source = test .instrumented_behavior_test_source ,
1413+ instrumented_perf_test_source = test .instrumented_perf_test_source ,
1414+ behavior_file_path = test .behavior_file_path ,
1415+ perf_file_path = test .perf_file_path ,
1416+ )
1417+ modified_tests .append (modified_test )
1418+ except Exception as e :
1419+ # If parsing fails, keep the original test
1420+ logger .debug (f"Failed to add runtime comments to test: { e } " )
1421+ modified_tests .append (test )
1422+
1423+ return GeneratedTestsList (generated_tests = modified_tests )
0 commit comments