 from typing import TYPE_CHECKING
 
 from codeflash.cli_cmds.console import logger
+from codeflash.code_utils.config_consts import DEFAULT_IMPORTANCE_THRESHOLD
 from codeflash.tracing.profile_stats import ProfileStats
 
 if TYPE_CHECKING:
     from pathlib import Path
 
     from codeflash.discovery.functions_to_optimize import FunctionToOptimize
 
 
 class FunctionRanker:
-    """Ranks functions for optimization based on trace data using ttX scoring.
+    """Ranks and filters functions for optimization based on profiling trace data using the ttX scoring method.
 
-    ttX = own_time + (time_spent_in_callees x call_count)
+    The FunctionRanker analyzes function-level timing statistics from a trace file and assigns a ttX score to each function:
 
-    This prioritizes functions that:
-    1. Take significant time themselves (own_time)
-    2. Are called frequently and have expensive subcalls (time_spent_in_callees x call_count)
+    ttX = own_time + (time_spent_in_callees x call_count)
+
+    This scoring prioritizes functions that:
+    1. Consume significant time themselves (own_time)
+    2. Are called frequently and have expensive subcalls (time_spent_in_callees x call_count)
+
+    Functions whose own_time accounts for less than importance_threshold of the total runtime (the minimum fraction of total runtime a function must account for to be considered important) are first filtered out as unimportant for optimization.
+
+    The remaining functions are then ranked in descending order by their ttX score, prioritizing those most likely to yield performance improvements if optimized.
     """
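    # Illustrative ttX arithmetic with made-up numbers: own_time = 5_000_000 ns,
    # time_spent_in_callees = 2_000_000 ns, call_count = 100 gives
    # ttX = 5_000_000 + (2_000_000 x 100) = 205_000_000 ns, so a frequently called
    # function with expensive subcalls outranks one with only a moderate own_time.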
 
     def __init__(self, trace_file_path: Path) -> None:
         self.trace_file_path = trace_file_path
-        self._function_stats = None
-
-    def load_function_stats(self) -> dict[str, dict]:
-        """Load function timing statistics from trace database using ProfileStats."""
-        if self._function_stats is not None:
-            return self._function_stats
-
-        self._function_stats = {}
+        self._profile_stats = ProfileStats(trace_file_path.as_posix())
+        self._function_stats: dict[str, dict] = {}
+        self.load_function_stats()
 
+    def load_function_stats(self) -> None:
         try:
-            profile_stats = ProfileStats(self.trace_file_path.as_posix())
-
-            # Access the stats dictionary directly from ProfileStats
-            for (filename, line_number, function_name), (
+            for (filename, line_number, func_name), (
                 call_count,
                 _num_callers,
                 total_time_ns,
                 cumulative_time_ns,
                 _callers,
-            ) in profile_stats.stats.items():
+            ) in self._profile_stats.stats.items():
                 if call_count <= 0:
                     continue
 
-                if "." in function_name and not function_name.startswith("<"):
-                    parts = function_name.split(".", 1)
+                # Parse function name to handle methods within classes
+                class_name, qualified_name, base_function_name = (None, func_name, func_name)
+                if "." in func_name and not func_name.startswith("<"):
+                    parts = func_name.split(".", 1)
                     if len(parts) == 2:
-                        class_name, method_name = parts
-                        qualified_name = function_name
-                        base_function_name = method_name
-                    else:
-                        class_name = None
-                        qualified_name = function_name
-                        base_function_name = function_name
-                else:
-                    class_name = None
-                    qualified_name = function_name
-                    base_function_name = function_name
+                        class_name, base_function_name = parts
 
                 # Calculate own time (total time - time spent in subcalls)
                 own_time_ns = total_time_ns
@@ -85,54 +77,63 @@ def load_function_stats(self) -> dict[str, dict]:
             logger.debug(f"Loaded timing stats for {len(self._function_stats)} functions from trace using ProfileStats")
 
         except Exception as e:
-            logger.warning(f"Failed to load function stats from trace file {self.trace_file_path}: {e}")
+            logger.warning(f"Failed to process function stats from trace file {self.trace_file_path}: {e}")
             self._function_stats = {}
 
-        return self._function_stats
-
-    def get_function_ttx_score(self, function_to_optimize: FunctionToOptimize) -> float:
-        stats = self.load_function_stats()
-
+    def _get_function_stats(self, function_to_optimize: FunctionToOptimize) -> dict | None:
         possible_keys = [
             f"{function_to_optimize.file_path}:{function_to_optimize.qualified_name}",
             f"{function_to_optimize.file_path}:{function_to_optimize.function_name}",
         ]
-
         for key in possible_keys:
-            if key in stats:
-                return stats[key]["ttx_score"]
+            if key in self._function_stats:
+                return self._function_stats[key]
+        return None
 
-        return 0.0
+    def get_function_ttx_score(self, function_to_optimize: FunctionToOptimize) -> float:
+        stats = self._get_function_stats(function_to_optimize)
+        return stats["ttx_score"] if stats else 0.0
 
     def rank_functions(self, functions_to_optimize: list[FunctionToOptimize]) -> list[FunctionToOptimize]:
-        # Calculate ttX scores for all functions
-        function_scores = []
-        for func in functions_to_optimize:
-            ttx_score = self.get_function_ttx_score(func)
-            function_scores.append((func, ttx_score))
+        return sorted(functions_to_optimize, key=self.get_function_ttx_score, reverse=True)
 
-        # Sort by ttX score descending (highest impact first)
-        function_scores.sort(key=lambda x: x[1], reverse=True)
+    def get_function_stats_summary(self, function_to_optimize: FunctionToOptimize) -> dict | None:
+        return self._get_function_stats(function_to_optimize)
 
-        # logger.info("Function ranking by ttX score:")
-        # for i, (func, score) in enumerate(function_scores[:10]):  # Top 10
-        #     logger.info(f"  {i + 1}. {func.qualified_name} (ttX: {score:.0f}ns)")
+    def rerank_and_filter_functions(self, functions_to_optimize: list[FunctionToOptimize]) -> list[FunctionToOptimize]:
+        """Reranks and filters functions based on their impact on total runtime.
 
-        ranked_functions = [func for func, _ in function_scores]
-        logger.info(f"Ranked {len(ranked_functions)} functions by optimization priority")
+        This method first calculates the total runtime of all profiled functions.
+        It then filters out functions whose own_time is less than a specified
+        percentage of the total runtime (importance_threshold).
 
-        return ranked_functions
+        The remaining 'important' functions are then ranked by their ttX score.
+        """
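        # Illustrative filtering arithmetic with made-up numbers: own_time_ns = 30_000_000
        # against total_program_time = 1_000_000_000 gives importance = 0.03 (3%); the
        # function is kept only if DEFAULT_IMPORTANCE_THRESHOLD (from config_consts; value
        # not shown in this diff) is at or below 0.03, otherwise it is filtered out.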
+        stats_map = self._function_stats
+        if not stats_map:
+            return []
 
-    def get_function_stats_summary(self, function_to_optimize: FunctionToOptimize) -> dict | None:
-        stats = self.load_function_stats()
+        total_program_time = sum(s["own_time_ns"] for s in stats_map.values() if s.get("own_time_ns", 0) > 0)
 
-        possible_keys = [
-            f"{function_to_optimize.file_path}:{function_to_optimize.qualified_name}",
-            f"{function_to_optimize.file_path}:{function_to_optimize.function_name}",
-        ]
+        if total_program_time == 0:
+            logger.warning("Total program time is zero, cannot determine function importance.")
+            return self.rank_functions(functions_to_optimize)
 
-        for key in possible_keys:
-            if key in stats:
-                return stats[key]
+        important_functions = []
+        for func in functions_to_optimize:
+            func_stats = self._get_function_stats(func)
+            if func_stats and func_stats.get("own_time_ns", 0) > 0:
+                importance = func_stats["own_time_ns"] / total_program_time
+                if importance >= DEFAULT_IMPORTANCE_THRESHOLD:
+                    important_functions.append(func)
+                else:
+                    logger.info(
+                        f"Filtering out function {func.qualified_name} with importance "
+                        f"{importance:.2%} (below threshold {DEFAULT_IMPORTANCE_THRESHOLD:.2%})"
+                    )
 
-        return None
+        logger.info(
+            f"Filtered down to {len(important_functions)} important functions from {len(functions_to_optimize)} total functions"
+        )
+
+        return self.rank_functions(important_functions)
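
A minimal usage sketch of the FunctionRanker API shown in this diff; the import paths, the trace file name, and the pick_targets helper are assumptions for illustration:

    from pathlib import Path

    from codeflash.benchmarking.function_ranker import FunctionRanker  # import path assumed
    from codeflash.discovery.functions_to_optimize import FunctionToOptimize


    def pick_targets(trace_file: Path, candidates: list[FunctionToOptimize]) -> list[FunctionToOptimize]:
        # Rank candidates from a codeflash trace, dropping functions below the importance threshold.
        ranker = FunctionRanker(trace_file_path=trace_file)
        ranked = ranker.rerank_and_filter_functions(candidates)
        for func in ranked:
            print(func.qualified_name, ranker.get_function_ttx_score(func))
        return ranked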