Skip to content

Commit 7f9a609

Browse files
committed
implement reranker
1 parent 059b4dc commit 7f9a609

File tree

3 files changed

+68
-66
lines changed

3 files changed

+68
-66
lines changed

codeflash/benchmarking/function_ranker.py

Lines changed: 65 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import TYPE_CHECKING
44

55
from codeflash.cli_cmds.console import logger
6+
from codeflash.code_utils.config_consts import DEFAULT_IMPORTANCE_THRESHOLD
67
from codeflash.tracing.profile_stats import ProfileStats
78

89
if TYPE_CHECKING:
@@ -12,54 +13,45 @@
1213

1314

1415
class FunctionRanker:
15-
"""Ranks functions for optimization based on trace data using ttX scoring.
16+
"""Ranks and filters functions for optimization based on profiling trace data using the ttX scoring method.
1617
17-
ttX = own_time + (time_spent_in_callees x call_count)
18+
The FunctionRanker analyzes function-level timing statistics from a trace file and assigns a ttX score to each function:
1819
19-
This prioritizes functions that:
20-
1. Take significant time themselves (own_time)
21-
2. Are called frequently and have expensive subcalls (time_spent_in_callees x call_count)
20+
ttX = own_time + (time_spent_in_callees x call_count)
21+
22+
This scoring prioritizes functions that:
23+
1. Consume significant time themselves (own_time)
24+
2. Are called frequently and have expensive subcalls (time_spent_in_callees x call_count)
25+
26+
first, filters out functions whose own_time is less than a specified percentage (importance_threshold = minimum fraction of total runtime a function must account for to be considered important) of the total runtime, considering them unimportant for optimization.
27+
28+
The remaining functions are then ranked in descending order by their ttX score, prioritizing those most likely to yield performance improvements if optimized.
2229
"""
2330

2431
def __init__(self, trace_file_path: Path) -> None:
2532
self.trace_file_path = trace_file_path
26-
self._function_stats = None
27-
28-
def load_function_stats(self) -> dict[str, dict]:
29-
"""Load function timing statistics from trace database using ProfileStats."""
30-
if self._function_stats is not None:
31-
return self._function_stats
32-
33-
self._function_stats = {}
33+
self._profile_stats = ProfileStats(trace_file_path.as_posix())
34+
self._function_stats: dict[str, dict] = {}
35+
self.load_function_stats()
3436

37+
def load_function_stats(self) -> None:
3538
try:
36-
profile_stats = ProfileStats(self.trace_file_path.as_posix())
37-
38-
# Access the stats dictionary directly from ProfileStats
39-
for (filename, line_number, function_name), (
39+
for (filename, line_number, func_name), (
4040
call_count,
4141
_num_callers,
4242
total_time_ns,
4343
cumulative_time_ns,
4444
_callers,
45-
) in profile_stats.stats.items():
45+
) in self._profile_stats.stats.items():
4646
if call_count <= 0:
4747
continue
4848

49-
if "." in function_name and not function_name.startswith("<"):
50-
parts = function_name.split(".", 1)
49+
# Parse function name to handle methods within classes
50+
class_name, qualified_name, base_function_name = (None, func_name, func_name)
51+
if "." in func_name and not func_name.startswith("<"):
52+
parts = func_name.split(".", 1)
5153
if len(parts) == 2:
52-
class_name, method_name = parts
53-
qualified_name = function_name
54-
base_function_name = method_name
55-
else:
56-
class_name = None
57-
qualified_name = function_name
58-
base_function_name = function_name
59-
else:
60-
class_name = None
61-
qualified_name = function_name
62-
base_function_name = function_name
54+
class_name, base_function_name = parts
6355

6456
# Calculate own time (total time - time spent in subcalls)
6557
own_time_ns = total_time_ns
@@ -85,54 +77,63 @@ def load_function_stats(self) -> dict[str, dict]:
8577
logger.debug(f"Loaded timing stats for {len(self._function_stats)} functions from trace using ProfileStats")
8678

8779
except Exception as e:
88-
logger.warning(f"Failed to load function stats from trace file {self.trace_file_path}: {e}")
80+
logger.warning(f"Failed to process function stats from trace file {self.trace_file_path}: {e}")
8981
self._function_stats = {}
9082

91-
return self._function_stats
92-
93-
def get_function_ttx_score(self, function_to_optimize: FunctionToOptimize) -> float:
94-
stats = self.load_function_stats()
95-
83+
def _get_function_stats(self, function_to_optimize: FunctionToOptimize) -> dict | None:
9684
possible_keys = [
9785
f"{function_to_optimize.file_path}:{function_to_optimize.qualified_name}",
9886
f"{function_to_optimize.file_path}:{function_to_optimize.function_name}",
9987
]
100-
10188
for key in possible_keys:
102-
if key in stats:
103-
return stats[key]["ttx_score"]
89+
if key in self._function_stats:
90+
return self._function_stats[key]
91+
return None
10492

105-
return 0.0
93+
def get_function_ttx_score(self, function_to_optimize: FunctionToOptimize) -> float:
94+
stats = self._get_function_stats(function_to_optimize)
95+
return stats["ttx_score"] if stats else 0.0
10696

10797
def rank_functions(self, functions_to_optimize: list[FunctionToOptimize]) -> list[FunctionToOptimize]:
108-
# Calculate ttX scores for all functions
109-
function_scores = []
110-
for func in functions_to_optimize:
111-
ttx_score = self.get_function_ttx_score(func)
112-
function_scores.append((func, ttx_score))
98+
return sorted(functions_to_optimize, key=self.get_function_ttx_score, reverse=True)
11399

114-
# Sort by ttX score descending (highest impact first)
115-
function_scores.sort(key=lambda x: x[1], reverse=True)
100+
def get_function_stats_summary(self, function_to_optimize: FunctionToOptimize) -> dict | None:
101+
return self._get_function_stats(function_to_optimize)
116102

117-
# logger.info("Function ranking by ttX score:")
118-
# for i, (func, score) in enumerate(function_scores[:10]): # Top 10
119-
# logger.info(f" {i + 1}. {func.qualified_name} (ttX: {score:.0f}ns)")
103+
def rerank_and_filter_functions(self, functions_to_optimize: list[FunctionToOptimize]) -> list[FunctionToOptimize]:
104+
"""Reranks and filters functions based on their impact on total runtime.
120105
121-
ranked_functions = [func for func, _ in function_scores]
122-
logger.info(f"Ranked {len(ranked_functions)} functions by optimization priority")
106+
This method first calculates the total runtime of all profiled functions.
107+
It then filters out functions whose own_time is less than a specified
108+
percentage of the total runtime (importance_threshold).
123109
124-
return ranked_functions
110+
The remaining 'important' functions are then ranked by their ttX score.
111+
"""
112+
stats_map = self._function_stats
113+
if not stats_map:
114+
return []
125115

126-
def get_function_stats_summary(self, function_to_optimize: FunctionToOptimize) -> dict | None:
127-
stats = self.load_function_stats()
116+
total_program_time = sum(s["own_time_ns"] for s in stats_map.values() if s.get("own_time_ns", 0) > 0)
128117

129-
possible_keys = [
130-
f"{function_to_optimize.file_path}:{function_to_optimize.qualified_name}",
131-
f"{function_to_optimize.file_path}:{function_to_optimize.function_name}",
132-
]
118+
if total_program_time == 0:
119+
logger.warning("Total program time is zero, cannot determine function importance.")
120+
return self.rank_functions(functions_to_optimize)
133121

134-
for key in possible_keys:
135-
if key in stats:
136-
return stats[key]
122+
important_functions = []
123+
for func in functions_to_optimize:
124+
func_stats = self._get_function_stats(func)
125+
if func_stats and func_stats.get("own_time_ns", 0) > 0:
126+
importance = func_stats["own_time_ns"] / total_program_time
127+
if importance >= DEFAULT_IMPORTANCE_THRESHOLD:
128+
important_functions.append(func)
129+
else:
130+
logger.info(
131+
f"Filtering out function {func.qualified_name} with importance "
132+
f"{importance:.2%} (below threshold {DEFAULT_IMPORTANCE_THRESHOLD:.2%})"
133+
)
137134

138-
return None
135+
logger.info(
136+
f"Filtered down to {len(important_functions)} important functions from {len(functions_to_optimize)} total functions"
137+
)
138+
139+
return self.rank_functions(important_functions)

codeflash/code_utils/config_consts.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@
1010
COVERAGE_THRESHOLD = 60.0
1111
MIN_TESTCASE_PASSED_THRESHOLD = 6
1212
REPEAT_OPTIMIZATION_PROBABILITY = 0.1
13+
DEFAULT_IMPORTANCE_THRESHOLD = 0.001

codeflash/discovery/functions_to_optimize.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,8 @@ def get_functions_to_optimize(
218218
all_functions.extend(file_functions)
219219

220220
if all_functions:
221-
ranked_functions = ranker.rank_functions(all_functions)
221+
ranked_functions = ranker.rerank_and_filter_functions(all_functions)
222+
functions_count = len(ranked_functions)
222223

223224
ranked_dict = {}
224225
for func in ranked_functions:
@@ -227,7 +228,6 @@ def get_functions_to_optimize(
227228
ranked_dict[func.file_path].append(func)
228229

229230
filtered_modified_functions = ranked_dict
230-
logger.info(f"Ranked {len(all_functions)} functions by optimization priority using trace data")
231231

232232
logger.info(f"Found {functions_count} function{'s' if functions_count > 1 else ''} to optimize")
233233
if optimize_all:

0 commit comments

Comments
 (0)