Skip to content

Commit 6180c9d

Browse files
committed
refactored get_function_benchmark_timings and get_benchmark_timings into BenchmarkDatabaseUtils class
1 parent 684acf8 commit 6180c9d

File tree

4 files changed

+123
-130
lines changed

4 files changed

+123
-130
lines changed

codeflash/benchmarking/benchmark_database_utils.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,3 +177,120 @@ def close(self) -> None:
177177
self.connection.close()
178178
self.connection = None
179179

180+
181+
@staticmethod
182+
def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[str, int]]:
183+
"""Process the trace file and extract timing data for all functions.
184+
185+
Args:
186+
trace_path: Path to the trace file
187+
188+
Returns:
189+
A nested dictionary where:
190+
- Outer keys are module_name.qualified_name (module.class.function)
191+
- Inner keys are benchmark filename :: benchmark test function :: line number
192+
- Values are function timing in milliseconds
193+
194+
"""
195+
# Initialize the result dictionary
196+
result = {}
197+
198+
# Connect to the SQLite database
199+
connection = sqlite3.connect(trace_path)
200+
cursor = connection.cursor()
201+
202+
try:
203+
# Query the function_calls table for all function calls
204+
cursor.execute(
205+
"SELECT module_name, class_name, function_name, "
206+
"benchmark_file_name, benchmark_function_name, benchmark_line_number, time_ns "
207+
"FROM function_calls"
208+
)
209+
210+
# Process each row
211+
for row in cursor.fetchall():
212+
module_name, class_name, function_name, benchmark_file, benchmark_func, benchmark_line, time_ns = row
213+
214+
# Create the function key (module_name.class_name.function_name)
215+
if class_name:
216+
qualified_name = f"{module_name}.{class_name}.{function_name}"
217+
else:
218+
qualified_name = f"{module_name}.{function_name}"
219+
220+
# Create the benchmark key (file::function::line)
221+
benchmark_key = f"{benchmark_file}::{benchmark_func}::{benchmark_line}"
222+
223+
# Initialize the inner dictionary if needed
224+
if qualified_name not in result:
225+
result[qualified_name] = {}
226+
227+
# If multiple calls to the same function in the same benchmark,
228+
# add the times together
229+
if benchmark_key in result[qualified_name]:
230+
result[qualified_name][benchmark_key] += time_ns
231+
else:
232+
result[qualified_name][benchmark_key] = time_ns
233+
234+
finally:
235+
# Close the connection
236+
connection.close()
237+
238+
return result
239+
240+
@staticmethod
241+
def get_benchmark_timings(trace_path: Path) -> dict[str, int]:
242+
"""Extract total benchmark timings from trace files.
243+
244+
Args:
245+
trace_path: Path to the trace file
246+
247+
Returns:
248+
A dictionary mapping where:
249+
- Keys are benchmark filename :: benchmark test function :: line number
250+
- Values are total benchmark timing in milliseconds (with overhead subtracted)
251+
252+
"""
253+
# Initialize the result dictionary
254+
result = {}
255+
overhead_by_benchmark = {}
256+
257+
# Connect to the SQLite database
258+
connection = sqlite3.connect(trace_path)
259+
cursor = connection.cursor()
260+
261+
try:
262+
# Query the function_calls table to get total overhead for each benchmark
263+
cursor.execute(
264+
"SELECT benchmark_file_name, benchmark_function_name, benchmark_line_number, SUM(overhead_time_ns) "
265+
"FROM function_calls "
266+
"GROUP BY benchmark_file_name, benchmark_function_name, benchmark_line_number"
267+
)
268+
269+
# Process overhead information
270+
for row in cursor.fetchall():
271+
benchmark_file, benchmark_func, benchmark_line, total_overhead_ns = row
272+
benchmark_key = f"{benchmark_file}::{benchmark_func}::{benchmark_line}"
273+
overhead_by_benchmark[benchmark_key] = total_overhead_ns or 0 # Handle NULL sum case
274+
275+
# Query the benchmark_timings table for total times
276+
cursor.execute(
277+
"SELECT benchmark_file_name, benchmark_function_name, benchmark_line_number, time_ns "
278+
"FROM benchmark_timings"
279+
)
280+
281+
# Process each row and subtract overhead
282+
for row in cursor.fetchall():
283+
benchmark_file, benchmark_func, benchmark_line, time_ns = row
284+
285+
# Create the benchmark key (file::function::line)
286+
benchmark_key = f"{benchmark_file}::{benchmark_func}::{benchmark_line}"
287+
288+
# Subtract overhead from total time
289+
overhead = overhead_by_benchmark.get(benchmark_key, 0)
290+
result[benchmark_key] = time_ns - overhead
291+
292+
finally:
293+
# Close the connection
294+
connection.close()
295+
296+
return result

codeflash/benchmarking/get_trace_info.py

Lines changed: 0 additions & 121 deletions
This file was deleted.

codeflash/optimization/optimizer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import TYPE_CHECKING
99

1010
from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient
11+
from codeflash.benchmarking.benchmark_database_utils import BenchmarkDatabaseUtils
1112
from codeflash.benchmarking.replay_test import generate_replay_test
1213
from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
1314
from codeflash.benchmarking.utils import print_benchmark_table, validate_and_format_benchmark_table
@@ -24,7 +25,6 @@
2425
from codeflash.telemetry.posthog_cf import ph
2526
from codeflash.verification.test_results import TestType
2627
from codeflash.verification.verification_utils import TestConfig
27-
from codeflash.benchmarking.get_trace_info import get_function_benchmark_timings, get_benchmark_timings
2828
from codeflash.benchmarking.utils import print_benchmark_table
2929
from codeflash.benchmarking.instrument_codeflash_trace import instrument_codeflash_trace_decorator
3030

@@ -119,8 +119,8 @@ def run(self) -> None:
119119
if replay_count == 0:
120120
logger.info(f"No valid benchmarks found in {self.args.benchmarks_root} for functions to optimize, continuing optimization")
121121
else:
122-
function_benchmark_timings = get_function_benchmark_timings(trace_file)
123-
total_benchmark_timings = get_benchmark_timings(trace_file)
122+
function_benchmark_timings = BenchmarkDatabaseUtils.get_function_benchmark_timings(trace_file)
123+
total_benchmark_timings = BenchmarkDatabaseUtils.get_benchmark_timings(trace_file)
124124
function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
125125
print_benchmark_table(function_to_results)
126126
logger.info("Finished tracing existing benchmarks")

tests/test_trace_benchmarks.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
import sqlite3
22

3-
from codeflash.benchmarking.codeflash_trace import codeflash_trace
4-
from codeflash.benchmarking.get_trace_info import get_function_benchmark_timings, get_benchmark_timings
3+
from codeflash.benchmarking.benchmark_database_utils import BenchmarkDatabaseUtils
54
from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
65
from codeflash.benchmarking.replay_test import generate_replay_test
76
from pathlib import Path
87

98
from codeflash.benchmarking.utils import print_benchmark_table, validate_and_format_benchmark_table
10-
from codeflash.code_utils.code_utils import get_run_tmp_file
119
import shutil
1210

1311

@@ -180,9 +178,8 @@ def test_trace_multithreaded_benchmark() -> None:
180178

181179
# Assert the length of function calls
182180
assert len(function_calls) == 10, f"Expected 10 function calls, but got {len(function_calls)}"
183-
function_benchmark_timings = get_function_benchmark_timings(output_file)
184-
total_benchmark_timings = get_benchmark_timings(output_file)
185-
# This will throw an error if summed function timings exceed total benchmark timing
181+
function_benchmark_timings = BenchmarkDatabaseUtils.get_function_benchmark_timings(output_file)
182+
total_benchmark_timings = BenchmarkDatabaseUtils.get_benchmark_timings(output_file)
186183
function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
187184
assert "code_to_optimize.bubble_sort_codeflash_trace.sorter" in function_to_results
188185

0 commit comments

Comments
 (0)