Skip to content

Commit bb9c5db

Browse files
committed
benchmark flow is working. Changed paths to use module_path instead of file_path for BenchmarkKey.
1 parent cf00212 commit bb9c5db

File tree

8 files changed

+117
-100
lines changed

8 files changed

+117
-100
lines changed

codeflash/benchmarking/codeflash_trace.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import functools
22
import os
3+
import pickle
34
import sqlite3
45
import sys
6+
import time
7+
from typing import Callable
58

6-
import pickle
79
import dill
810

9-
import time
10-
from typing import Callable, Optional
1111

1212
class CodeflashTrace:
1313
"""Decorator class that traces and profiles function execution."""
@@ -35,7 +35,7 @@ def setup(self, trace_path: str) -> None:
3535
cur.execute(
3636
"CREATE TABLE IF NOT EXISTS benchmark_function_timings("
3737
"function_name TEXT, class_name TEXT, module_name TEXT, file_path TEXT,"
38-
"benchmark_function_name TEXT, benchmark_file_path TEXT, benchmark_line_number INTEGER,"
38+
"benchmark_function_name TEXT, benchmark_module_path TEXT, benchmark_line_number INTEGER,"
3939
"function_time_ns INTEGER, overhead_time_ns INTEGER, args BLOB, kwargs BLOB)"
4040
)
4141
self._connection.commit()
@@ -51,6 +51,7 @@ def write_function_timings(self) -> None:
5151
5252
Args:
5353
data: List of function call data tuples to write
54+
5455
"""
5556
if not self.function_calls_data:
5657
return # No data to write
@@ -64,7 +65,7 @@ def write_function_timings(self) -> None:
6465
cur.executemany(
6566
"INSERT INTO benchmark_function_timings"
6667
"(function_name, class_name, module_name, file_path, benchmark_function_name, "
67-
"benchmark_file_path, benchmark_line_number, function_time_ns, overhead_time_ns, args, kwargs) "
68+
"benchmark_module_path, benchmark_line_number, function_time_ns, overhead_time_ns, args, kwargs) "
6869
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
6970
self.function_calls_data
7071
)
@@ -116,7 +117,7 @@ def wrapper(*args, **kwargs):
116117

117118
# Get benchmark info from environment
118119
benchmark_function_name = os.environ.get("CODEFLASH_BENCHMARK_FUNCTION_NAME", "")
119-
benchmark_file_path = os.environ.get("CODEFLASH_BENCHMARK_FILE_PATH", "")
120+
benchmark_module_path = os.environ.get("CODEFLASH_BENCHMARK_MODULE_PATH", "")
120121
benchmark_line_number = os.environ.get("CODEFLASH_BENCHMARK_LINE_NUMBER", "")
121122
# Get class name
122123
class_name = ""
@@ -143,7 +144,7 @@ def wrapper(*args, **kwargs):
143144

144145
except (TypeError, dill.PicklingError, AttributeError, RecursionError, OSError) as e:
145146
print(f"Error pickling arguments for function {func.__name__}: {e}")
146-
return
147+
return None
147148

148149
if len(self.function_calls_data) > 1000:
149150
self.write_function_timings()
@@ -152,7 +153,7 @@ def wrapper(*args, **kwargs):
152153

153154
self.function_calls_data.append(
154155
(func.__name__, class_name, func.__module__, func.__code__.co_filename,
155-
benchmark_function_name, benchmark_file_path, benchmark_line_number, execution_time,
156+
benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
156157
overhead_time, pickled_args, pickled_kwargs)
157158
)
158159
return result

codeflash/benchmarking/plugin/plugin.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def setup(self, trace_path:str, project_root:str) -> None:
3131
cur.execute("PRAGMA journal_mode = MEMORY")
3232
cur.execute(
3333
"CREATE TABLE IF NOT EXISTS benchmark_timings("
34-
"benchmark_file_path TEXT, benchmark_function_name TEXT, benchmark_line_number INTEGER,"
34+
"benchmark_module_path TEXT, benchmark_function_name TEXT, benchmark_line_number INTEGER,"
3535
"benchmark_time_ns INTEGER)"
3636
)
3737
self._connection.commit()
@@ -54,7 +54,7 @@ def write_benchmark_timings(self) -> None:
5454
cur = self._connection.cursor()
5555
# Insert data into the benchmark_timings table
5656
cur.executemany(
57-
"INSERT INTO benchmark_timings (benchmark_file_path, benchmark_function_name, benchmark_line_number, benchmark_time_ns) VALUES (?, ?, ?, ?)",
57+
"INSERT INTO benchmark_timings (benchmark_module_path, benchmark_function_name, benchmark_line_number, benchmark_time_ns) VALUES (?, ?, ?, ?)",
5858
self.benchmark_timings
5959
)
6060
self._connection.commit()
@@ -93,7 +93,7 @@ def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[Benchmark
9393
# Query the function_calls table for all function calls
9494
cursor.execute(
9595
"SELECT module_name, class_name, function_name, "
96-
"benchmark_file_path, benchmark_function_name, benchmark_line_number, function_time_ns "
96+
"benchmark_module_path, benchmark_function_name, benchmark_line_number, function_time_ns "
9797
"FROM benchmark_function_timings"
9898
)
9999

@@ -108,7 +108,7 @@ def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[Benchmark
108108
qualified_name = f"{module_name}.{function_name}"
109109

110110
# Create the benchmark key (file::function::line)
111-
benchmark_key = BenchmarkKey(file_path=benchmark_file, function_name=benchmark_func)
111+
benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func)
112112
# Initialize the inner dictionary if needed
113113
if qualified_name not in result:
114114
result[qualified_name] = {}
@@ -150,20 +150,20 @@ def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]:
150150
try:
151151
# Query the benchmark_function_timings table to get total overhead for each benchmark
152152
cursor.execute(
153-
"SELECT benchmark_file_path, benchmark_function_name, benchmark_line_number, SUM(overhead_time_ns) "
153+
"SELECT benchmark_module_path, benchmark_function_name, benchmark_line_number, SUM(overhead_time_ns) "
154154
"FROM benchmark_function_timings "
155-
"GROUP BY benchmark_file_path, benchmark_function_name, benchmark_line_number"
155+
"GROUP BY benchmark_module_path, benchmark_function_name, benchmark_line_number"
156156
)
157157

158158
# Process overhead information
159159
for row in cursor.fetchall():
160160
benchmark_file, benchmark_func, benchmark_line, total_overhead_ns = row
161-
benchmark_key = BenchmarkKey(file_path=benchmark_file, function_name=benchmark_func)
161+
benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func)
162162
overhead_by_benchmark[benchmark_key] = total_overhead_ns or 0 # Handle NULL sum case
163163

164164
# Query the benchmark_timings table for total times
165165
cursor.execute(
166-
"SELECT benchmark_file_path, benchmark_function_name, benchmark_line_number, benchmark_time_ns "
166+
"SELECT benchmark_module_path, benchmark_function_name, benchmark_line_number, benchmark_time_ns "
167167
"FROM benchmark_timings"
168168
)
169169

@@ -172,7 +172,7 @@ def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]:
172172
benchmark_file, benchmark_func, benchmark_line, time_ns = row
173173

174174
# Create the benchmark key (file::function::line)
175-
benchmark_key = BenchmarkKey(file_path=benchmark_file, function_name=benchmark_func)
175+
benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func)
176176
# Subtract overhead from total time
177177
overhead = overhead_by_benchmark.get(benchmark_key, 0)
178178
result[benchmark_key] = time_ns - overhead
@@ -244,13 +244,13 @@ def test_something(benchmark):
244244
a
245245
246246
"""
247-
benchmark_file_path = module_name_from_file_path(Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root))
247+
benchmark_module_path = module_name_from_file_path(Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root))
248248
benchmark_function_name = self.request.node.name
249249
line_number = int(str(sys._getframe(1).f_lineno)) # 1 frame up in the call stack
250250

251251
# Set env vars so codeflash decorator can identify what benchmark its being run in
252252
os.environ["CODEFLASH_BENCHMARK_FUNCTION_NAME"] = benchmark_function_name
253-
os.environ["CODEFLASH_BENCHMARK_FILE_PATH"] = benchmark_file_path
253+
os.environ["CODEFLASH_BENCHMARK_MODULE_PATH"] = benchmark_module_path
254254
os.environ["CODEFLASH_BENCHMARK_LINE_NUMBER"] = str(line_number)
255255
os.environ["CODEFLASH_BENCHMARKING"] = "True"
256256

@@ -268,7 +268,7 @@ def test_something(benchmark):
268268
codeflash_trace.function_call_count = 0
269269
# Add to the benchmark timings buffer
270270
codeflash_benchmark_plugin.benchmark_timings.append(
271-
(benchmark_file_path, benchmark_function_name, line_number, end - start))
271+
(benchmark_module_path, benchmark_function_name, line_number, end - start))
272272

273273
return result
274274

codeflash/benchmarking/replay_test.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -227,18 +227,18 @@ def generate_replay_test(trace_file_path: Path, output_dir: Path, test_framework
227227

228228
# Get distinct benchmark file paths
229229
cursor.execute(
230-
"SELECT DISTINCT benchmark_file_path FROM benchmark_function_timings"
230+
"SELECT DISTINCT benchmark_module_path FROM benchmark_function_timings"
231231
)
232232
benchmark_files = cursor.fetchall()
233233

234234
# Generate a test for each benchmark file
235235
for benchmark_file in benchmark_files:
236-
benchmark_file_path = benchmark_file[0]
236+
benchmark_module_path = benchmark_file[0]
237237
# Get all benchmarks and functions associated with this file path
238238
cursor.execute(
239239
"SELECT DISTINCT benchmark_function_name, function_name, class_name, module_name, file_path, benchmark_line_number FROM benchmark_function_timings "
240-
"WHERE benchmark_file_path = ?",
241-
(benchmark_file_path,)
240+
"WHERE benchmark_module_path = ?",
241+
(benchmark_module_path,)
242242
)
243243

244244
functions_data = []
@@ -251,7 +251,7 @@ def generate_replay_test(trace_file_path: Path, output_dir: Path, test_framework
251251
"file_path": file_path,
252252
"module_name": module_name,
253253
"benchmark_function_name": benchmark_function_name,
254-
"benchmark_file_path": benchmark_file_path,
254+
"benchmark_module_path": benchmark_module_path,
255255
"benchmark_line_number": benchmark_line_number,
256256
"function_properties": inspect_top_level_functions_or_methods(
257257
file_name=Path(file_path),
@@ -261,7 +261,7 @@ def generate_replay_test(trace_file_path: Path, output_dir: Path, test_framework
261261
})
262262

263263
if not functions_data:
264-
logger.info(f"No benchmark test functions found in {benchmark_file_path}")
264+
logger.info(f"No benchmark test functions found in {benchmark_module_path}")
265265
continue
266266
# Generate the test code for this benchmark
267267
test_code = create_trace_replay_test_code(
@@ -272,7 +272,7 @@ def generate_replay_test(trace_file_path: Path, output_dir: Path, test_framework
272272
)
273273
test_code = isort.code(test_code)
274274
output_file = get_test_file_path(
275-
test_dir=Path(output_dir), function_name=benchmark_file_path, test_type="replay"
275+
test_dir=Path(output_dir), function_name=benchmark_module_path, test_type="replay"
276276
)
277277
# Write test code to file, parents = true
278278
output_dir.mkdir(parents=True, exist_ok=True)

codeflash/benchmarking/utils.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import shutil
34
from typing import Optional
45

56
from rich.console import Console
@@ -35,32 +36,45 @@ def validate_and_format_benchmark_table(function_benchmark_timings: dict[str, di
3536
function_to_result[func_path] = sorted_tests
3637
return function_to_result
3738

39+
3840
def print_benchmark_table(function_to_results: dict[str, list[tuple[BenchmarkKey, float, float, float]]]) -> None:
39-
console = Console()
41+
42+
try:
43+
terminal_width = int(shutil.get_terminal_size().columns * 0.8)
44+
except Exception:
45+
terminal_width = 200 # Fallback width
46+
console = Console(width = terminal_width)
4047
for func_path, sorted_tests in function_to_results.items():
4148
console.print()
4249
function_name = func_path.split(":")[-1]
4350

4451
# Create a table for this function
4552
table = Table(title=f"Function: {function_name}", border_style="blue")
4653

47-
# Add columns
48-
table.add_column("Benchmark Test", style="cyan", no_wrap=True)
54+
# Add columns - split the benchmark test into two columns
55+
table.add_column("Benchmark Module Path", style="cyan", no_wrap=True)
56+
table.add_column("Test Function", style="magenta", no_wrap=True)
4957
table.add_column("Total Time (ms)", justify="right", style="green")
5058
table.add_column("Function Time (ms)", justify="right", style="yellow")
5159
table.add_column("Percentage (%)", justify="right", style="red")
5260

5361
for benchmark_key, total_time, func_time, percentage in sorted_tests:
62+
# Split the benchmark test into module path and function name
63+
module_path = benchmark_key.module_path
64+
test_function = benchmark_key.function_name
65+
5466
if total_time == 0.0:
5567
table.add_row(
56-
f"{benchmark_key.file_path}::{benchmark_key.function_name}",
68+
module_path,
69+
test_function,
5770
"N/A",
5871
"N/A",
5972
"N/A"
6073
)
6174
else:
6275
table.add_row(
63-
f"{benchmark_key.file_path}::{benchmark_key.function_name}",
76+
module_path,
77+
test_function,
6478
f"{total_time:.3f}",
6579
f"{func_time:.3f}",
6680
f"{percentage:.2f}"
@@ -108,7 +122,7 @@ def process_benchmark_data(
108122

109123
benchmark_details.append(
110124
BenchmarkDetail(
111-
benchmark_name=benchmark_key.file_path,
125+
benchmark_name=benchmark_key.module_path,
112126
test_function=benchmark_key.function_name,
113127
original_timing=humanize_runtime(int(total_benchmark_timing)),
114128
expected_new_timing=humanize_runtime(int(expected_new_benchmark_timing)),

codeflash/models/models.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,11 @@ class BestOptimization(BaseModel):
8484

8585
@dataclass(frozen=True)
8686
class BenchmarkKey:
87-
file_path: str
87+
module_path: str
8888
function_name: str
8989

9090
def __str__(self) -> str:
91-
return f"{self.file_path}::{self.function_name}"
91+
return f"{self.module_path}::{self.function_name}"
9292

9393
@dataclass
9494
class BenchmarkDetail:
@@ -484,7 +484,7 @@ def group_by_benchmarks(self, benchmark_keys:list[BenchmarkKey], benchmark_repla
484484
test_results_by_benchmark = defaultdict(TestResults)
485485
benchmark_module_path = {}
486486
for benchmark_key in benchmark_keys:
487-
benchmark_module_path[benchmark_key] = module_name_from_file_path(benchmark_replay_test_dir.resolve() / f"test_{Path(benchmark_key.file_path).name.split('.')[0][5:]}__replay_test_", project_root)
487+
benchmark_module_path[benchmark_key] = module_name_from_file_path(benchmark_replay_test_dir.resolve() / f"test_{benchmark_key.module_path.replace(".", "_")}__replay_test_", project_root)
488488
for test_result in self.test_results:
489489
if (test_result.test_type == TestType.REPLAY_TEST):
490490
for benchmark_key, module_path in benchmark_module_path.items():

0 commit comments

Comments
 (0)