Skip to content

Commit eb9e0c6

Browse files
committed
allow predict to be included
1 parent 7f9a609 commit eb9e0c6

File tree

3 files changed

+1011
-537
lines changed

3 files changed

+1011
-537
lines changed

code_to_optimize/code_directories/simple_tracer_e2e/workload.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from concurrent.futures import ThreadPoolExecutor
2+
from time import sleep
23

34

45
def funcA(number):
@@ -46,12 +47,20 @@ def _classify(self, features):
4647
class SimpleModel:
4748
@staticmethod
4849
def predict(data):
49-
return [x * 2 for x in data]
50+
result = []
51+
sleep(10)
52+
for i in range(500):
53+
for x in data:
54+
computation = 0
55+
computation += x * i ** 2
56+
result.append(computation)
57+
return result
5058

5159
@classmethod
5260
def create_default(cls):
5361
return cls()
5462

63+
5564
def test_models():
5665
model = AlexNet(num_classes=10)
5766
input_data = [1, 2, 3, 4, 5]

tests/test_function_ranker.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
import pytest
2+
from pathlib import Path
3+
from unittest.mock import patch
4+
5+
from codeflash.benchmarking.function_ranker import FunctionRanker
6+
from codeflash.discovery.functions_to_optimize import FunctionToOptimize, find_all_functions_in_file
7+
from codeflash.models.models import FunctionParent
8+
9+
10+
@pytest.fixture
11+
def trace_file():
12+
return Path(__file__).parent.parent / "code_to_optimize/code_directories/simple_tracer_e2e/codeflash.trace.sqlite3"
13+
14+
15+
@pytest.fixture
16+
def workload_functions():
17+
workloads_file = Path(__file__).parent.parent / "code_to_optimize/code_directories/simple_tracer_e2e/workload.py"
18+
functions_dict = find_all_functions_in_file(workloads_file)
19+
all_functions = []
20+
for functions_list in functions_dict.values():
21+
all_functions.extend(functions_list)
22+
return all_functions
23+
24+
25+
@pytest.fixture
26+
def function_ranker(trace_file):
27+
return FunctionRanker(trace_file)
28+
29+
30+
def test_function_ranker_initialization(trace_file):
31+
ranker = FunctionRanker(trace_file)
32+
assert ranker.trace_file_path == trace_file
33+
assert ranker._profile_stats is not None
34+
assert isinstance(ranker._function_stats, dict)
35+
36+
37+
def test_load_function_stats(function_ranker):
38+
assert len(function_ranker._function_stats) > 0
39+
40+
# Check that funcA is loaded with expected structure
41+
func_a_key = None
42+
for key, stats in function_ranker._function_stats.items():
43+
if stats["function_name"] == "funcA":
44+
func_a_key = key
45+
break
46+
47+
assert func_a_key is not None
48+
func_a_stats = function_ranker._function_stats[func_a_key]
49+
50+
# Verify funcA stats structure
51+
expected_keys = {
52+
"filename", "function_name", "qualified_name", "class_name",
53+
"line_number", "call_count", "own_time_ns", "cumulative_time_ns",
54+
"time_in_callees_ns", "ttx_score"
55+
}
56+
assert set(func_a_stats.keys()) == expected_keys
57+
58+
# Verify funcA specific values
59+
assert func_a_stats["function_name"] == "funcA"
60+
assert func_a_stats["call_count"] == 1
61+
assert func_a_stats["own_time_ns"] == 27000
62+
assert func_a_stats["cumulative_time_ns"] == 1629000
63+
64+
65+
def test_get_function_ttx_score(function_ranker, workload_functions):
66+
func_a = None
67+
for func in workload_functions:
68+
if func.function_name == "funcA":
69+
func_a = func
70+
break
71+
72+
assert func_a is not None
73+
ttx_score = function_ranker.get_function_ttx_score(func_a)
74+
75+
# Expected ttX score: own_time + (time_in_callees * call_count)
76+
# = 27000 + ((1629000 - 27000) * 1) = 1629000
77+
assert ttx_score == 1629000
78+
79+
80+
def test_rank_functions(function_ranker, workload_functions):
81+
ranked_functions = function_ranker.rank_functions(workload_functions)
82+
83+
assert len(ranked_functions) == len(workload_functions)
84+
85+
# Verify functions are sorted by ttX score in descending order
86+
for i in range(len(ranked_functions) - 1):
87+
current_score = function_ranker.get_function_ttx_score(ranked_functions[i])
88+
next_score = function_ranker.get_function_ttx_score(ranked_functions[i + 1])
89+
assert current_score >= next_score
90+
91+
92+
def test_rerank_and_filter_functions(function_ranker, workload_functions):
93+
filtered_ranked = function_ranker.rerank_and_filter_functions(workload_functions)
94+
95+
# Should filter out functions below importance threshold
96+
assert len(filtered_ranked) <= len(workload_functions)
97+
98+
# funcA should pass the importance threshold (0.33% > 0.1%)
99+
func_a_in_results = any(f.function_name == "funcA" for f in filtered_ranked)
100+
assert func_a_in_results
101+
102+
103+
def test_get_function_stats_summary(function_ranker, workload_functions):
104+
func_a = None
105+
for func in workload_functions:
106+
if func.function_name == "funcA":
107+
func_a = func
108+
break
109+
110+
assert func_a is not None
111+
stats = function_ranker.get_function_stats_summary(func_a)
112+
113+
assert stats is not None
114+
assert stats["function_name"] == "funcA"
115+
assert stats["own_time_ns"] == 27000
116+
assert stats["cumulative_time_ns"] == 1629000
117+
assert stats["ttx_score"] == 1629000
118+
119+
120+
121+
122+
def test_importance_calculation(function_ranker):
123+
total_program_time = sum(
124+
s["own_time_ns"] for s in function_ranker._function_stats.values()
125+
if s.get("own_time_ns", 0) > 0
126+
)
127+
128+
func_a_stats = None
129+
for stats in function_ranker._function_stats.values():
130+
if stats["function_name"] == "funcA":
131+
func_a_stats = stats
132+
break
133+
134+
assert func_a_stats is not None
135+
importance = func_a_stats["own_time_ns"] / total_program_time
136+
137+
# funcA importance should be approximately 0.33% (27000/8242000)
138+
assert abs(importance - 0.00327) < 0.001

0 commit comments

Comments
 (0)