@@ -36,7 +36,10 @@
     N_TESTS_TO_GENERATE,
     TOTAL_LOOPING_TIME,
 )
-from codeflash.code_utils.edit_generated_tests import remove_functions_from_generated_tests
+from codeflash.code_utils.edit_generated_tests import (
+    add_runtime_comments_to_generated_tests,
+    remove_functions_from_generated_tests,
+)
 from codeflash.code_utils.formatter import format_code, sort_imports
 from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
 from codeflash.code_utils.line_profile_utils import add_decorator_imports
@@ -319,7 +322,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 |
             generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove
         )
         # Add runtime comments to generated tests before creating the PR
-        generated_tests = self.add_runtime_comments_to_generated_tests(
+        generated_tests = add_runtime_comments_to_generated_tests(
             generated_tests,
             original_code_baseline.benchmarking_test_results,
             best_optimization.winning_benchmarking_test_results,
@@ -1270,154 +1273,3 @@ def cleanup_generated_files(self) -> None: |
         cleanup_paths(paths_to_cleanup)
         if hasattr(get_run_tmp_file, "tmpdir"):
             get_run_tmp_file.tmpdir.cleanup()
-
-    def add_runtime_comments_to_generated_tests(
-        self,
-        generated_tests: GeneratedTestsList,
-        original_test_results: TestResults,
-        optimized_test_results: TestResults,
-    ) -> GeneratedTestsList:
-        """Add runtime performance comments to function calls in generated tests."""
-
-        def format_time(nanoseconds: int) -> str:
-            """Format nanoseconds into a human-readable string with 3 significant digits when needed."""
-
-            def count_significant_digits(num: int) -> int:
-                """Count significant digits in an integer."""
-                return len(str(abs(num)))
-
-            def format_with_precision(value: float, unit: str) -> str:
-                """Format a value with 3 significant digits precision."""
-                if value >= 100:
-                    return f"{value:.0f}{unit}"
-                if value >= 10:
-                    return f"{value:.1f}{unit}"
-                return f"{value:.2f}{unit}"
-
-            if nanoseconds < 1_000:
-                return f"{nanoseconds}ns"
-            if nanoseconds < 1_000_000:
-                # Convert to microseconds
-                microseconds_int = nanoseconds // 1_000
-                if count_significant_digits(microseconds_int) >= 3:
-                    return f"{microseconds_int}μs"
-                microseconds_float = nanoseconds / 1_000
-                return format_with_precision(microseconds_float, "μs")
-            if nanoseconds < 1_000_000_000:
-                # Convert to milliseconds
-                milliseconds_int = nanoseconds // 1_000_000
-                if count_significant_digits(milliseconds_int) >= 3:
-                    return f"{milliseconds_int}ms"
-                milliseconds_float = nanoseconds / 1_000_000
-                return format_with_precision(milliseconds_float, "ms")
-            # Convert to seconds
-            seconds_int = nanoseconds // 1_000_000_000
-            if count_significant_digits(seconds_int) >= 3:
-                return f"{seconds_int}s"
-            seconds_float = nanoseconds / 1_000_000_000
-            return format_with_precision(seconds_float, "s")
-
-        # Create dictionaries for fast lookup of runtime data
-        original_runtime_by_test = original_test_results.usable_runtime_data_by_test_case()
-        optimized_runtime_by_test = optimized_test_results.usable_runtime_data_by_test_case()
-
-        class RuntimeCommentTransformer(cst.CSTTransformer):
-            def __init__(self):
-                self.in_test_function = False
-                self.current_test_name = None
-
-            def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
-                if node.name.value.startswith("test_"):
-                    self.in_test_function = True
-                    self.current_test_name = node.name.value
-                else:
-                    self.in_test_function = False
-                    self.current_test_name = None
-
-            def leave_FunctionDef(
-                self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
-            ) -> cst.FunctionDef:
-                if original_node.name.value.startswith("test_"):
-                    self.in_test_function = False
-                    self.current_test_name = None
-                return updated_node
-
-            def leave_SimpleStatementLine(
-                self, original_node: cst.SimpleStatementLine, updated_node: cst.SimpleStatementLine
-            ) -> cst.SimpleStatementLine:
-                if not self.in_test_function or not self.current_test_name:
-                    return updated_node
-
-                # Look for assignment statements that assign to codeflash_output
-                # Handle both single statements and multiple statements on one line
-                codeflash_assignment_found = False
-                for stmt in updated_node.body:
-                    if isinstance(stmt, cst.Assign):
-                        if (
-                            len(stmt.targets) == 1
-                            and isinstance(stmt.targets[0].target, cst.Name)
-                            and stmt.targets[0].target.value == "codeflash_output"
-                        ):
-                            codeflash_assignment_found = True
-                            break
-
-                if codeflash_assignment_found:
-                    # Find matching test cases by looking for this test function name in the test results
-                    matching_original_times = []
-                    matching_optimized_times = []
-
-                    for invocation_id, runtimes in original_runtime_by_test.items():
-                        if invocation_id.test_function_name == self.current_test_name:
-                            matching_original_times.extend(runtimes)
-
-                    for invocation_id, runtimes in optimized_runtime_by_test.items():
-                        if invocation_id.test_function_name == self.current_test_name:
-                            matching_optimized_times.extend(runtimes)
-
-                    if matching_original_times and matching_optimized_times:
-                        original_time = min(matching_original_times)
-                        optimized_time = min(matching_optimized_times)
-
-                        # Create the runtime comment
-                        comment_text = f"# {format_time(original_time)} -> {format_time(optimized_time)}"
-
-                        # Add comment to the trailing whitespace
-                        new_trailing_whitespace = cst.TrailingWhitespace(
-                            whitespace=cst.SimpleWhitespace(" "),
-                            comment=cst.Comment(comment_text),
-                            newline=updated_node.trailing_whitespace.newline,
-                        )
-
-                        return updated_node.with_changes(trailing_whitespace=new_trailing_whitespace)
-
-                return updated_node
-
-        # Process each generated test
-        modified_tests = []
-        for test in generated_tests.generated_tests:
-            try:
-                # Parse the test source code
-                tree = cst.parse_module(test.generated_original_test_source)
-
-                # Transform the tree to add runtime comments
-                transformer = RuntimeCommentTransformer()
-                modified_tree = tree.visit(transformer)
-
-                # Convert back to source code
-                modified_source = modified_tree.code
-
-                # Create a new GeneratedTests object with the modified source
-                modified_test = GeneratedTests(
-                    generated_original_test_source=modified_source,
-                    instrumented_behavior_test_source=test.instrumented_behavior_test_source,
-                    instrumented_perf_test_source=test.instrumented_perf_test_source,
-                    behavior_file_path=test.behavior_file_path,
-                    perf_file_path=test.perf_file_path,
-                )
-                modified_tests.append(modified_test)
-            except Exception as e:
-                # If parsing fails, keep the original test
-                logger.debug(f"Failed to add runtime comments to test: {e}")
-                modified_tests.append(test)
-
-        return GeneratedTestsList(generated_tests=modified_tests)
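
For context, the net effect of the moved helper on a generated test looks roughly like this (hypothetical `sorter` function and timings, shown only to illustrate the transform; the comment lands on the statement line that assigns to `codeflash_output`):

# Generated test as produced by codeflash, before the transform:
def test_sorter_basic():
    codeflash_output = sorter([3, 1, 2])
    assert codeflash_output == [1, 2, 3]

# After the transform, with the best original and optimized runtimes appended:
def test_sorter_basic():
    codeflash_output = sorter([3, 1, 2])  # 1.50ms -> 305μs
    assert codeflash_output == [1, 2, 3]

The two values are the minimum measured runtimes for that test function across invocations, formatted by `format_time`.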
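
The `format_time` helper keeps three significant digits once a value is converted to the largest fitting unit. A few hand-traced examples from the branches above (illustrative values, not output of any codeflash API):

# format_time(999)            == "999ns"   # under 1,000 ns: raw nanoseconds
# format_time(1_234)          == "1.23μs"  # two decimals for values below 10
# format_time(45_600_000)     == "45.6ms"  # one decimal for values from 10 to 100
# format_time(123_456)        == "123μs"   # integer part already has 3 digits
# format_time(2_500_000_000)  == "2.50s"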