
Commit 56e3447

reworked matching benchmark key to test results.
1 parent 8d95b18 commit 56e3447

File tree

11 files changed, +144 -136 lines changed


codeflash/benchmarking/codeflash_trace.py

Lines changed: 6 additions & 6 deletions
@@ -33,8 +33,8 @@ def setup(self, trace_path: str) -> None:
         cur.execute("PRAGMA synchronous = OFF")
         cur.execute(
             "CREATE TABLE IF NOT EXISTS benchmark_function_timings("
-            "function_name TEXT, class_name TEXT, module_name TEXT, file_name TEXT,"
-            "benchmark_function_name TEXT, benchmark_file_name TEXT, benchmark_line_number INTEGER,"
+            "function_name TEXT, class_name TEXT, module_name TEXT, file_path TEXT,"
+            "benchmark_function_name TEXT, benchmark_file_path TEXT, benchmark_line_number INTEGER,"
             "function_time_ns INTEGER, overhead_time_ns INTEGER, args BLOB, kwargs BLOB)"
         )
         self._connection.commit()
@@ -62,8 +62,8 @@ def write_function_timings(self) -> None:
         # Insert data into the benchmark_function_timings table
         cur.executemany(
             "INSERT INTO benchmark_function_timings"
-            "(function_name, class_name, module_name, file_name, benchmark_function_name, "
-            "benchmark_file_name, benchmark_line_number, function_time_ns, overhead_time_ns, args, kwargs) "
+            "(function_name, class_name, module_name, file_path, benchmark_function_name, "
+            "benchmark_file_path, benchmark_line_number, function_time_ns, overhead_time_ns, args, kwargs) "
             "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
             self.function_calls_data
         )
@@ -115,7 +115,7 @@ def wrapper(*args, **kwargs):
 
             # Get benchmark info from environment
             benchmark_function_name = os.environ.get("CODEFLASH_BENCHMARK_FUNCTION_NAME", "")
-            benchmark_file_name = os.environ.get("CODEFLASH_BENCHMARK_FILE_NAME", "")
+            benchmark_file_path = os.environ.get("CODEFLASH_BENCHMARK_FILE_PATH", "")
             benchmark_line_number = os.environ.get("CODEFLASH_BENCHMARK_LINE_NUMBER", "")
             # Get class name
             class_name = ""
@@ -151,7 +151,7 @@ def wrapper(*args, **kwargs):
 
             self.function_calls_data.append(
                 (func.__name__, class_name, func.__module__, func.__code__.co_filename,
-                 benchmark_function_name, benchmark_file_name, benchmark_line_number, execution_time,
+                 benchmark_function_name, benchmark_file_path, benchmark_line_number, execution_time,
                  overhead_time, pickled_args, pickled_kwargs)
             )
             return result
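
The renamed environment variables are the contract between the pytest plugin, which sets them just before a benchmark runs, and the codeflash_trace decorator, which reads them on every traced call; renaming one side without the other would silently break the match. A minimal sketch of that handoff, using only the variable names visible in this diff (the two helper functions are illustrative, not part of the codebase):

import os

def set_benchmark_env(function_name: str, file_path: str, line_number: int) -> None:
    # Plugin side: tell the decorator which benchmark is currently running.
    os.environ["CODEFLASH_BENCHMARK_FUNCTION_NAME"] = function_name
    os.environ["CODEFLASH_BENCHMARK_FILE_PATH"] = file_path
    os.environ["CODEFLASH_BENCHMARK_LINE_NUMBER"] = str(line_number)

def read_benchmark_env() -> tuple[str, str, str]:
    # Decorator side: fall back to "" when no benchmark is running.
    return (
        os.environ.get("CODEFLASH_BENCHMARK_FUNCTION_NAME", ""),
        os.environ.get("CODEFLASH_BENCHMARK_FILE_PATH", ""),
        os.environ.get("CODEFLASH_BENCHMARK_LINE_NUMBER", ""),
    )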

codeflash/benchmarking/plugin/plugin.py

Lines changed: 12 additions & 12 deletions
@@ -24,7 +24,7 @@ def setup(self, trace_path:str) -> None:
         cur.execute("PRAGMA synchronous = OFF")
         cur.execute(
             "CREATE TABLE IF NOT EXISTS benchmark_timings("
-            "benchmark_file_name TEXT, benchmark_function_name TEXT, benchmark_line_number INTEGER,"
+            "benchmark_file_path TEXT, benchmark_function_name TEXT, benchmark_line_number INTEGER,"
             "benchmark_time_ns INTEGER)"
         )
         self._connection.commit()
@@ -47,7 +47,7 @@ def write_benchmark_timings(self) -> None:
         cur = self._connection.cursor()
         # Insert data into the benchmark_timings table
         cur.executemany(
-            "INSERT INTO benchmark_timings (benchmark_file_name, benchmark_function_name, benchmark_line_number, benchmark_time_ns) VALUES (?, ?, ?, ?)",
+            "INSERT INTO benchmark_timings (benchmark_file_path, benchmark_function_name, benchmark_line_number, benchmark_time_ns) VALUES (?, ?, ?, ?)",
             self.benchmark_timings
         )
         self._connection.commit()
@@ -86,7 +86,7 @@ def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[Benchmark
     # Query the function_calls table for all function calls
     cursor.execute(
         "SELECT module_name, class_name, function_name, "
-        "benchmark_file_name, benchmark_function_name, benchmark_line_number, function_time_ns "
+        "benchmark_file_path, benchmark_function_name, benchmark_line_number, function_time_ns "
         "FROM benchmark_function_timings"
     )
 
@@ -101,7 +101,7 @@ def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[Benchmark
         qualified_name = f"{module_name}.{function_name}"
 
         # Create the benchmark key (file::function::line)
-        benchmark_key = BenchmarkKey(file_name=benchmark_file, function_name=benchmark_func)
+        benchmark_key = BenchmarkKey(file_path=benchmark_file, function_name=benchmark_func)
         # Initialize the inner dictionary if needed
         if qualified_name not in result:
             result[qualified_name] = {}
@@ -143,20 +143,20 @@ def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]:
     try:
         # Query the benchmark_function_timings table to get total overhead for each benchmark
         cursor.execute(
-            "SELECT benchmark_file_name, benchmark_function_name, benchmark_line_number, SUM(overhead_time_ns) "
+            "SELECT benchmark_file_path, benchmark_function_name, benchmark_line_number, SUM(overhead_time_ns) "
             "FROM benchmark_function_timings "
-            "GROUP BY benchmark_file_name, benchmark_function_name, benchmark_line_number"
+            "GROUP BY benchmark_file_path, benchmark_function_name, benchmark_line_number"
         )
 
         # Process overhead information
         for row in cursor.fetchall():
             benchmark_file, benchmark_func, benchmark_line, total_overhead_ns = row
-            benchmark_key = BenchmarkKey(file_name=benchmark_file, function_name=benchmark_func)
+            benchmark_key = BenchmarkKey(file_path=benchmark_file, function_name=benchmark_func)
             overhead_by_benchmark[benchmark_key] = total_overhead_ns or 0  # Handle NULL sum case
 
         # Query the benchmark_timings table for total times
         cursor.execute(
-            "SELECT benchmark_file_name, benchmark_function_name, benchmark_line_number, benchmark_time_ns "
+            "SELECT benchmark_file_path, benchmark_function_name, benchmark_line_number, benchmark_time_ns "
             "FROM benchmark_timings"
         )
 
@@ -165,7 +165,7 @@ def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]:
         benchmark_file, benchmark_func, benchmark_line, time_ns = row
 
         # Create the benchmark key (file::function::line)
-        benchmark_key = BenchmarkKey(file_name=benchmark_file, function_name=benchmark_func)
+        benchmark_key = BenchmarkKey(file_path=benchmark_file, function_name=benchmark_func)
         # Subtract overhead from total time
         overhead = overhead_by_benchmark.get(benchmark_key, 0)
         result[benchmark_key] = time_ns - overhead
@@ -236,13 +236,13 @@ def test_something(benchmark):
         The return value of the function
 
         """
-        benchmark_file_name = self.request.node.fspath
+        benchmark_file_path = str(self.request.node.fspath)
         benchmark_function_name = self.request.node.name
         line_number = int(str(sys._getframe(1).f_lineno))  # 1 frame up in the call stack
 
         # Set env vars so codeflash decorator can identify what benchmark its being run in
         os.environ["CODEFLASH_BENCHMARK_FUNCTION_NAME"] = benchmark_function_name
-        os.environ["CODEFLASH_BENCHMARK_FILE_NAME"] = benchmark_file_name
+        os.environ["CODEFLASH_BENCHMARK_FILE_PATH"] = benchmark_file_path
         os.environ["CODEFLASH_BENCHMARK_LINE_NUMBER"] = str(line_number)
         os.environ["CODEFLASH_BENCHMARKING"] = "True"
 
@@ -260,7 +260,7 @@ def test_something(benchmark):
         codeflash_trace.function_call_count = 0
         # Add to the benchmark timings buffer
         codeflash_benchmark_plugin.benchmark_timings.append(
-            (benchmark_file_name, benchmark_function_name, line_number, end - start))
+            (benchmark_file_path, benchmark_function_name, line_number, end - start))
 
         return result
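
All of these queries funnel into the same BenchmarkKey, which after this commit is built from a file_path rather than a bare file name. A minimal sketch of the net-time computation in get_benchmark_timings, assuming BenchmarkKey is a frozen (hashable) dataclass with exactly the two fields constructed above; the example path and timings are invented:

from dataclasses import dataclass

@dataclass(frozen=True)  # frozen makes instances hashable, so they can key the timing dicts
class BenchmarkKey:
    file_path: str
    function_name: str

def net_benchmark_times(
    total_by_benchmark: dict[BenchmarkKey, int],
    overhead_by_benchmark: dict[BenchmarkKey, int],
) -> dict[BenchmarkKey, int]:
    # Mirrors get_benchmark_timings: wall-clock benchmark time minus the
    # summed decorator overhead, with missing overhead defaulting to 0.
    return {key: total - overhead_by_benchmark.get(key, 0) for key, total in total_by_benchmark.items()}

key = BenchmarkKey(file_path="tests/benchmarks/test_example.py", function_name="test_example_speed")
print(net_benchmark_times({key: 1_500_000}, {key: 200_000}))  # nets out to 1_300_000 ns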

codeflash/benchmarking/replay_test.py

Lines changed: 34 additions & 30 deletions
@@ -2,32 +2,35 @@
 
 import sqlite3
 import textwrap
-from collections.abc import Generator
-from typing import Any, Dict
+from pathlib import Path
+from typing import TYPE_CHECKING, Any
 
 import isort
 
 from codeflash.cli_cmds.console import logger
 from codeflash.discovery.functions_to_optimize import inspect_top_level_functions_or_methods
 from codeflash.verification.verification_utils import get_test_file_path
-from pathlib import Path
+
+if TYPE_CHECKING:
+    from collections.abc import Generator
+
 
 def get_next_arg_and_return(
-    trace_file: str, function_name: str, file_name: str, class_name: str | None = None, num_to_get: int = 256
+    trace_file: str, function_name: str, file_path: str, class_name: str | None = None, num_to_get: int = 256
 ) -> Generator[Any]:
     db = sqlite3.connect(trace_file)
     cur = db.cursor()
     limit = num_to_get
 
     if class_name is not None:
         cursor = cur.execute(
-            "SELECT * FROM benchmark_function_timings WHERE function_name = ? AND file_name = ? AND class_name = ? LIMIT ?",
-            (function_name, file_name, class_name, limit),
+            "SELECT * FROM benchmark_function_timings WHERE function_name = ? AND file_path = ? AND class_name = ? LIMIT ?",
+            (function_name, file_path, class_name, limit),
         )
     else:
         cursor = cur.execute(
-            "SELECT * FROM benchmark_function_timings WHERE function_name = ? AND file_name = ? AND class_name = '' LIMIT ?",
-            (function_name, file_name, limit),
+            "SELECT * FROM benchmark_function_timings WHERE function_name = ? AND file_path = ? AND class_name = '' LIMIT ?",
+            (function_name, file_path, limit),
         )
 
     while (val := cursor.fetchone()) is not None:
@@ -88,7 +91,7 @@ def create_trace_replay_test_code(
     # Templates for different types of tests
    test_function_body = textwrap.dedent(
        """\
-        for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, function_name="{orig_function_name}", file_name=r"{file_name}", num_to_get={max_run_count}):
+        for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, function_name="{orig_function_name}", file_path=r"{file_path}", num_to_get={max_run_count}):
             args = pickle.loads(args_pkl)
             kwargs = pickle.loads(kwargs_pkl)
             ret = {function_name}(*args, **kwargs)
@@ -97,7 +100,7 @@ def create_trace_replay_test_code(
 
     test_method_body = textwrap.dedent(
         """\
-        for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, function_name="{orig_function_name}", file_name=r"{file_name}", class_name="{class_name}", num_to_get={max_run_count}):
+        for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):
             args = pickle.loads(args_pkl)
             kwargs = pickle.loads(kwargs_pkl){filter_variables}
             function_name = "{orig_function_name}"
@@ -112,7 +115,7 @@ def create_trace_replay_test_code(
 
     test_class_method_body = textwrap.dedent(
         """\
-        for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, function_name="{orig_function_name}", file_name=r"{file_name}", class_name="{class_name}", num_to_get={max_run_count}):
+        for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):
             args = pickle.loads(args_pkl)
             kwargs = pickle.loads(kwargs_pkl){filter_variables}
             if not args:
@@ -122,7 +125,7 @@ def create_trace_replay_test_code(
     )
     test_static_method_body = textwrap.dedent(
         """\
-        for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, function_name="{orig_function_name}", file_name=r"{file_name}", class_name="{class_name}", num_to_get={max_run_count}):
+        for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):
             args = pickle.loads(args_pkl)
             kwargs = pickle.loads(kwargs_pkl){filter_variables}
             ret = {class_name_alias}{method_name}(*args, **kwargs)
@@ -140,13 +143,13 @@ def create_trace_replay_test_code(
         module_name = func.get("module_name")
         function_name = func.get("function_name")
         class_name = func.get("class_name")
-        file_name = func.get("file_name")
+        file_path = func.get("file_path")
         function_properties = func.get("function_properties")
         if not class_name:
             alias = get_function_alias(module_name, function_name)
             test_body = test_function_body.format(
                 function_name=alias,
-                file_name=file_name,
+                file_path=file_path,
                 orig_function_name=function_name,
                 max_run_count=max_run_count,
             )
@@ -160,7 +163,7 @@ def create_trace_replay_test_code(
             if function_properties.is_classmethod:
                 test_body = test_class_method_body.format(
                     orig_function_name=function_name,
-                    file_name=file_name,
+                    file_path=file_path,
                     class_name_alias=class_name_alias,
                     class_name=class_name,
                     method_name=method_name,
@@ -170,7 +173,7 @@ def create_trace_replay_test_code(
             elif function_properties.is_staticmethod:
                 test_body = test_static_method_body.format(
                     orig_function_name=function_name,
-                    file_name=file_name,
+                    file_path=file_path,
                     class_name_alias=class_name_alias,
                     class_name=class_name,
                     method_name=method_name,
@@ -180,7 +183,7 @@ def create_trace_replay_test_code(
             else:
                 test_body = test_method_body.format(
                     orig_function_name=function_name,
-                    file_name=file_name,
+                    file_path=file_path,
                     class_name_alias=class_name_alias,
                     class_name=class_name,
                     method_name=method_name,
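
Filled in, the function-level template above expands into replay-test code along these lines; the traced module, function, and paths below are invented for illustration:

import pickle

from codeflash.benchmarking.replay_test import get_next_arg_and_return
from my_module import compute as my_module_compute  # hypothetical traced function

trace_file_path = r"/tmp/codeflash_benchmark.trace"  # hypothetical trace database

def test_my_module_compute():
    # Replay the exact arguments captured while the benchmark ran.
    for args_pkl, kwargs_pkl in get_next_arg_and_return(
        trace_file=trace_file_path, function_name="compute", file_path=r"/repo/src/my_module.py", num_to_get=100
    ):
        args = pickle.loads(args_pkl)
        kwargs = pickle.loads(kwargs_pkl)
        ret = my_module_compute(*args, **kwargs)
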
@@ -216,42 +219,41 @@ def generate_replay_test(trace_file_path: Path, output_dir: Path, test_framework
 
         # Get distinct benchmark names
         cursor.execute(
-            "SELECT DISTINCT benchmark_function_name, benchmark_file_name FROM benchmark_function_timings"
+            "SELECT DISTINCT benchmark_function_name, benchmark_file_path FROM benchmark_function_timings"
         )
         benchmarks = cursor.fetchall()
 
         # Generate a test for each benchmark
         for benchmark in benchmarks:
-            benchmark_function_name, benchmark_file_name = benchmark
+            benchmark_function_name, benchmark_file_path = benchmark
             # Get functions associated with this benchmark
             cursor.execute(
-                "SELECT DISTINCT function_name, class_name, module_name, file_name, benchmark_line_number FROM benchmark_function_timings "
-                "WHERE benchmark_function_name = ? AND benchmark_file_name = ?",
-                (benchmark_function_name, benchmark_file_name)
+                "SELECT DISTINCT function_name, class_name, module_name, file_path, benchmark_line_number FROM benchmark_function_timings "
+                "WHERE benchmark_function_name = ? AND benchmark_file_path = ?",
+                (benchmark_function_name, benchmark_file_path)
             )
 
             functions_data = []
             for func_row in cursor.fetchall():
-                function_name, class_name, module_name, file_name, benchmark_line_number = func_row
-
+                function_name, class_name, module_name, file_path, benchmark_line_number = func_row
                 # Add this function to our list
                 functions_data.append({
                     "function_name": function_name,
                     "class_name": class_name,
-                    "file_name": file_name,
+                    "file_path": file_path,
                     "module_name": module_name,
                     "benchmark_function_name": benchmark_function_name,
-                    "benchmark_file_name": benchmark_file_name,
+                    "benchmark_file_path": benchmark_file_path,
                     "benchmark_line_number": benchmark_line_number,
                     "function_properties": inspect_top_level_functions_or_methods(
-                        file_name=file_name,
+                        file_name=Path(file_path),
                         function_or_method_name=function_name,
                         class_name=class_name,
                     )
                 })
 
             if not functions_data:
-                logger.info(f"No functions found for benchmark {benchmark_function_name} in {benchmark_file_name}")
+                logger.info(f"No functions found for benchmark {benchmark_function_name} in {benchmark_file_path}")
                 continue
 
             # Generate the test code for this benchmark
@@ -265,17 +267,19 @@ def generate_replay_test(trace_file_path: Path, output_dir: Path, test_framework
 
             # Write to file if requested
             if output_dir:
+                name = Path(benchmark_file_path).name.split(".")[0][5:]  # remove "test_" from the name since we add it in later
                 output_file = get_test_file_path(
-                    test_dir=Path(output_dir), function_name=f"{benchmark_file_name}_{benchmark_function_name}", test_type="replay"
+                    test_dir=Path(output_dir), function_name=f"{name}_{benchmark_function_name}", test_type="replay"
                 )
                 # Write test code to file, parents = true
                 output_dir.mkdir(parents=True, exist_ok=True)
                 output_file.write_text(test_code, "utf-8")
                 count += 1
-                logger.info(f"Replay test for benchmark `{benchmark_function_name}` in {benchmark_file_name} written to {output_file}")
+                logger.info(f"Replay test for benchmark `{benchmark_function_name}` in {name} written to {output_file}")
 
         conn.close()
 
     except Exception as e:
         logger.info(f"Error generating replay tests: {e}")
+
     return count
0 commit comments

Comments
 (0)