|
31 | 31 | if TYPE_CHECKING: |
32 | 32 | from argparse import Namespace |
33 | 33 |
|
| 34 | + from codeflash.benchmarking.function_ranker import FunctionRanker |
34 | 35 | from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint |
35 | 36 | from codeflash.discovery.functions_to_optimize import FunctionToOptimize |
36 | 37 | from codeflash.models.models import BenchmarkKey, FunctionCalledInTest |
@@ -251,6 +252,143 @@ def discover_tests( |
251 | 252 | ph("cli-optimize-discovered-tests", {"num_tests": num_discovered_tests}) |
252 | 253 | return function_to_tests, num_discovered_tests |
253 | 254 |
|
def display_global_ranking(
    self, globally_ranked: list[tuple[Path, FunctionToOptimize]], ranker: FunctionRanker, show_top_n: int = 15
) -> None:
    """Print a rich table summarising the highest-ranked functions.

    Args:
        globally_ranked: (file path, function) pairs, already in ranked order.
        ranker: Object queried via ``get_function_ttx_score`` for each displayed function.
        show_top_n: Maximum number of rows to show; the rest are summarised in a
            trailing "... and N more functions" line.

    Nothing is printed when ``globally_ranked`` is empty.
    """
    from rich.table import Table

    if not globally_ranked:
        return

    total = len(globally_ranked)
    display_count = min(show_top_n, total)

    def format_ttx(score: float) -> str:
        # The suffixes treat the raw score as nanoseconds and scale it up to
        # the largest unit whose threshold it reaches.
        for threshold, suffix, precision in ((1e9, "s", 2), (1e6, "ms", 1), (1e3, "µs", 1)):
            if score >= threshold:
                return f"{score / threshold:.{precision}f}{suffix}"
        return f"{score:.0f}ns"

    table = Table(
        title=f"Function Ranking (Top {display_count} of {total})",
        title_style="bold cyan",
        border_style="cyan",
        show_lines=False,
    )

    column_specs = (
        ("Priority", {"style": "bold yellow", "justify": "center", "width": 8}),
        ("Function", {"style": "cyan", "width": 40}),
        ("File", {"style": "dim", "width": 25}),
        ("ttX Score", {"justify": "right", "style": "green", "width": 12}),
        ("Impact", {"justify": "center", "style": "bold", "width": 8}),
    )
    for header, options in column_specs:
        table.add_column(header, **options)

    for rank, (file_path, func) in enumerate(globally_ranked[:display_count], 1):
        ttx_score = ranker.get_function_ttx_score(func)

        # Shorten long names so they fit the fixed column widths: function
        # names keep their head, file names keep their tail.
        func_name = func.qualified_name
        if len(func_name) > 38:
            func_name = f"{func_name[:35]}..."

        file_name = file_path.name
        if len(file_name) > 23:
            file_name = f"...{file_name[-20:]}"

        ttx_display = format_ttx(ttx_score)

        # Three tiers of impact icon; only the top tier also styles the whole row.
        in_top_tier = rank <= 5
        if in_top_tier:
            impact = "🔥"
        elif rank <= 10:
            impact = "⚡"
        else:
            impact = "💡"
        row_style = "bold red" if in_top_tier else None

        table.add_row(f"#{rank}", func_name, file_name, ttx_display, impact, style=row_style)

    console.print(table)

    remaining = total - display_count
    if remaining > 0:
        console.print(f"[dim]... and {remaining} more functions[/dim]")
| 320 | + |
def rank_all_functions_globally(
    self, file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]], trace_file_path: Path | None
) -> list[tuple[Path, FunctionToOptimize]]:
    """Rank all functions globally across all files based on trace data.

    This performs global ranking instead of per-file ranking, ensuring that
    high-impact functions are optimized first regardless of which file they're in.

    Args:
        file_to_funcs_to_optimize: Mapping of file paths to functions to optimize
        trace_file_path: Path to trace file with performance data

    Returns:
        List of (file_path, function) tuples in the order produced by the ranker.
        Only functions returned by the ranker are included, so the result may be
        shorter than the input. If there is no trace file, or ranking raises,
        every function is returned in original file order.

    """
    all_functions: list[tuple[Path, FunctionToOptimize]] = [
        (file_path, func) for file_path, functions in file_to_funcs_to_optimize.items() for func in functions
    ]

    # If no trace file, return in original order
    if not trace_file_path or not trace_file_path.exists():
        logger.debug("No trace file available, using original function order")
        return all_functions

    def identity_key(func: FunctionToOptimize) -> tuple[Path, str, int | None]:
        # Value-based key used instead of the function object itself as the
        # dict key (the original note says FunctionToOptimize is unhashable).
        return (func.file_path, func.qualified_name, func.starting_line)

    try:
        from codeflash.benchmarking.function_ranker import FunctionRanker

        console.rule()
        logger.info("loading|Ranking functions globally by performance impact...")
        console.rule()
        # Create ranker with trace data
        ranker = FunctionRanker(trace_file_path)

        # The ranker works on bare functions, so strip the file paths first.
        functions_only = [func for _, func in all_functions]
        ranked_functions = ranker.rank_functions(functions_only)

        # Re-attach each ranked function to the file path it was discovered under.
        func_to_file_map = {identity_key(func): file_path for file_path, func in all_functions}
        globally_ranked = []
        for func in ranked_functions:
            file_path = func_to_file_map.get(identity_key(func))
            # Explicit None check: only a missing key should drop a function.
            if file_path is not None:
                globally_ranked.append((file_path, func))
        console.rule()
        logger.info(
            f"Globally ranked {len(ranked_functions)} functions by ttX score "
            f"(filtered {len(functions_only) - len(ranked_functions)} low-importance functions)"
        )

        self.display_global_ranking(globally_ranked, ranker)
        console.rule()

        return globally_ranked

    except Exception as e:
        # Ranking is best-effort: any failure falls back to the unranked order.
        # The traceback is attached to the debug record so it is not lost.
        logger.warning(f"Could not perform global ranking: {e}")
        logger.debug("Falling back to original function order", exc_info=True)
        return all_functions
| 391 | + |
254 | 392 | def run(self) -> None: |
255 | 393 | from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint |
256 | 394 |
|
@@ -297,84 +435,77 @@ def run(self) -> None: |
297 | 435 | if self.args.all: |
298 | 436 | self.functions_checkpoint = CodeflashRunCheckpoint(self.args.module_root) |
299 | 437 |
|
300 | | - for original_module_path in file_to_funcs_to_optimize: |
301 | | - module_prep_result = self.prepare_module_for_optimization(original_module_path) |
302 | | - if module_prep_result is None: |
303 | | - continue |
| 438 | + # GLOBAL RANKING: Rank all functions together before optimizing |
| 439 | + globally_ranked_functions = self.rank_all_functions_globally(file_to_funcs_to_optimize, trace_file_path) |
| 440 | + # Cache for module preparation (avoid re-parsing same files) |
| 441 | + prepared_modules: dict[Path, tuple[dict[Path, ValidCode], ast.Module]] = {} |
304 | 442 |
|
305 | | - validated_original_code, original_module_ast = module_prep_result |
| 443 | + # Optimize functions in globally ranked order |
| 444 | + for i, (original_module_path, function_to_optimize) in enumerate(globally_ranked_functions): |
| 445 | + # Prepare module if not already cached |
| 446 | + if original_module_path not in prepared_modules: |
| 447 | + module_prep_result = self.prepare_module_for_optimization(original_module_path) |
| 448 | + if module_prep_result is None: |
| 449 | + logger.warning(f"Skipping functions in {original_module_path} due to preparation error") |
| 450 | + continue |
| 451 | + prepared_modules[original_module_path] = module_prep_result |
306 | 452 |
|
307 | | - functions_to_optimize = file_to_funcs_to_optimize[original_module_path] |
308 | | - if trace_file_path and trace_file_path.exists() and len(functions_to_optimize) > 1: |
309 | | - try: |
310 | | - from codeflash.benchmarking.function_ranker import FunctionRanker |
| 453 | + validated_original_code, original_module_ast = prepared_modules[original_module_path] |
311 | 454 |
|
312 | | - ranker = FunctionRanker(trace_file_path) |
313 | | - functions_to_optimize = ranker.rank_functions(functions_to_optimize) |
314 | | - logger.info( |
315 | | - f"Ranked {len(functions_to_optimize)} functions by performance impact in {original_module_path}" |
316 | | - ) |
317 | | - console.rule() |
318 | | - except Exception as e: |
319 | | - logger.debug(f"Could not rank functions in {original_module_path}: {e}") |
320 | | - |
321 | | - for i, function_to_optimize in enumerate(functions_to_optimize): |
322 | | - function_iterator_count = i + 1 |
323 | | - logger.info( |
324 | | - f"Optimizing function {function_iterator_count} of {num_optimizable_functions}: " |
325 | | - f"{function_to_optimize.qualified_name}" |
| 455 | + function_iterator_count = i + 1 |
| 456 | + logger.info( |
| 457 | + f"Optimizing function {function_iterator_count} of {len(globally_ranked_functions)}: " |
| 458 | + f"{function_to_optimize.qualified_name} (in {original_module_path.name})" |
| 459 | + ) |
| 460 | + console.rule() |
| 461 | + function_optimizer = None |
| 462 | + try: |
| 463 | + function_optimizer = self.create_function_optimizer( |
| 464 | + function_to_optimize, |
| 465 | + function_to_tests=function_to_tests, |
| 466 | + function_to_optimize_source_code=validated_original_code[original_module_path].source_code, |
| 467 | + function_benchmark_timings=function_benchmark_timings, |
| 468 | + total_benchmark_timings=total_benchmark_timings, |
| 469 | + original_module_ast=original_module_ast, |
| 470 | + original_module_path=original_module_path, |
326 | 471 | ) |
327 | | - console.rule() |
328 | | - function_optimizer = None |
329 | | - try: |
330 | | - function_optimizer = self.create_function_optimizer( |
331 | | - function_to_optimize, |
332 | | - function_to_tests=function_to_tests, |
333 | | - function_to_optimize_source_code=validated_original_code[original_module_path].source_code, |
334 | | - function_benchmark_timings=function_benchmark_timings, |
335 | | - total_benchmark_timings=total_benchmark_timings, |
336 | | - original_module_ast=original_module_ast, |
337 | | - original_module_path=original_module_path, |
338 | | - ) |
339 | | - if function_optimizer is None: |
340 | | - continue |
| 472 | + if function_optimizer is None: |
| 473 | + continue |
341 | 474 |
|
342 | | - self.current_function_optimizer = ( |
343 | | - function_optimizer # needed to clean up from the outside of this function |
| 475 | + self.current_function_optimizer = ( |
| 476 | + function_optimizer # needed to clean up from the outside of this function |
| 477 | + ) |
| 478 | + best_optimization = function_optimizer.optimize_function() |
| 479 | + if self.functions_checkpoint: |
| 480 | + self.functions_checkpoint.add_function_to_checkpoint( |
| 481 | + function_to_optimize.qualified_name_with_modules_from_root(self.args.project_root) |
344 | 482 | ) |
345 | | - best_optimization = function_optimizer.optimize_function() |
346 | | - if self.functions_checkpoint: |
347 | | - self.functions_checkpoint.add_function_to_checkpoint( |
348 | | - function_to_optimize.qualified_name_with_modules_from_root(self.args.project_root) |
| 483 | + if is_successful(best_optimization): |
| 484 | + optimizations_found += 1 |
| 485 | + # create a diff patch for successful optimization |
| 486 | + if self.current_worktree: |
| 487 | + best_opt = best_optimization.unwrap() |
| 488 | + read_writable_code = best_opt.code_context.read_writable_code |
| 489 | + relative_file_paths = [ |
| 490 | + code_string.file_path for code_string in read_writable_code.code_strings |
| 491 | + ] |
| 492 | + patch_path = create_diff_patch_from_worktree( |
| 493 | + self.current_worktree, relative_file_paths, fto_name=function_to_optimize.qualified_name |
349 | 494 | ) |
350 | | - if is_successful(best_optimization): |
351 | | - optimizations_found += 1 |
352 | | - # create a diff patch for successful optimization |
353 | | - if self.current_worktree: |
354 | | - best_opt = best_optimization.unwrap() |
355 | | - read_writable_code = best_opt.code_context.read_writable_code |
356 | | - relative_file_paths = [ |
357 | | - code_string.file_path for code_string in read_writable_code.code_strings |
358 | | - ] |
359 | | - patch_path = create_diff_patch_from_worktree( |
360 | | - self.current_worktree, |
361 | | - relative_file_paths, |
362 | | - fto_name=function_to_optimize.qualified_name, |
| 495 | + self.patch_files.append(patch_path) |
| 496 | + if i < len(globally_ranked_functions) - 1: |
| 497 | + next_file, next_func = globally_ranked_functions[i + 1] |
| 498 | + create_worktree_snapshot_commit( |
| 499 | + self.current_worktree, f"Optimizing {next_func.qualified_name}" |
363 | 500 | ) |
364 | | - self.patch_files.append(patch_path) |
365 | | - if i < len(functions_to_optimize) - 1: |
366 | | - create_worktree_snapshot_commit( |
367 | | - self.current_worktree, |
368 | | - f"Optimizing {functions_to_optimize[i + 1].qualified_name}", |
369 | | - ) |
370 | | - else: |
371 | | - logger.warning(best_optimization.failure()) |
372 | | - console.rule() |
373 | | - continue |
374 | | - finally: |
375 | | - if function_optimizer is not None: |
376 | | - function_optimizer.executor.shutdown(wait=True) |
377 | | - function_optimizer.cleanup_generated_files() |
| 501 | + else: |
| 502 | + logger.warning(best_optimization.failure()) |
| 503 | + console.rule() |
| 504 | + continue |
| 505 | + finally: |
| 506 | + if function_optimizer is not None: |
| 507 | + function_optimizer.executor.shutdown(wait=True) |
| 508 | + function_optimizer.cleanup_generated_files() |
378 | 509 |
|
379 | 510 | ph("cli-optimize-run-finished", {"optimizations_found": optimizations_found}) |
380 | 511 | if len(self.patch_files) > 0: |
|
0 commit comments