Skip to content

Commit d703b13

Browse files
committed
Replay tests are now grouped by benchmark file; each benchmark test file will create one replay test file.
1 parent 9764c25 commit d703b13

File tree

4 files changed

+86
-58
lines changed

4 files changed

+86
-58
lines changed

code_to_optimize/tests/pytest/benchmarks_test/test_benchmark_bubble_sort.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@ def test_sort2():
1515
def test_class_sort(benchmark):
1616
obj = Sorter(list(reversed(range(100))))
1717
result1 = benchmark(obj.sorter, 2)
18+
19+
def test_class_sort2(benchmark):
1820
result2 = benchmark(Sorter.sort_class, list(reversed(range(100))))
21+
22+
def test_class_sort3(benchmark):
1923
result3 = benchmark(Sorter.sort_static, list(reversed(range(100))))
24+
25+
def test_class_sort4(benchmark):
2026
result4 = benchmark(Sorter, [1,2,3])

codeflash/benchmarking/replay_test.py

Lines changed: 47 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,21 @@
1616

1717

1818
def get_next_arg_and_return(
19-
trace_file: str, function_name: str, file_path: str, class_name: str | None = None, num_to_get: int = 256
19+
trace_file: str, benchmark_function_name:str, function_name: str, file_path: str, class_name: str | None = None, num_to_get: int = 256
2020
) -> Generator[Any]:
2121
db = sqlite3.connect(trace_file)
2222
cur = db.cursor()
2323
limit = num_to_get
2424

2525
if class_name is not None:
2626
cursor = cur.execute(
27-
"SELECT * FROM benchmark_function_timings WHERE function_name = ? AND file_path = ? AND class_name = ? LIMIT ?",
28-
(function_name, file_path, class_name, limit),
27+
"SELECT * FROM benchmark_function_timings WHERE benchmark_function_name = ? AND function_name = ? AND file_path = ? AND class_name = ? LIMIT ?",
28+
(benchmark_function_name, function_name, file_path, class_name, limit),
2929
)
3030
else:
3131
cursor = cur.execute(
32-
"SELECT * FROM benchmark_function_timings WHERE function_name = ? AND file_path = ? AND class_name = '' LIMIT ?",
33-
(function_name, file_path, limit),
32+
"SELECT * FROM benchmark_function_timings WHERE benchmark_function_name = ? AND function_name = ? AND file_path = ? AND class_name = '' LIMIT ?",
33+
(benchmark_function_name, function_name, file_path, limit),
3434
)
3535

3636
while (val := cursor.fetchone()) is not None:
@@ -61,6 +61,7 @@ def create_trace_replay_test_code(
6161
"""
6262
assert test_framework in ["pytest", "unittest"]
6363

64+
# Create Imports
6465
imports = f"""import dill as pickle
6566
{"import unittest" if test_framework == "unittest" else ""}
6667
from codeflash.benchmarking.replay_test import get_next_arg_and_return
@@ -82,16 +83,15 @@ def create_trace_replay_test_code(
8283

8384
imports += "\n".join(function_imports)
8485

85-
functions_to_optimize = [func.get("function_name") for func in functions_data
86-
if func.get("function_name") != "__init__"]
86+
functions_to_optimize = sorted({func.get("function_name") for func in functions_data
87+
if func.get("function_name") != "__init__"})
8788
metadata = f"""functions = {functions_to_optimize}
8889
trace_file_path = r"{trace_file}"
8990
"""
90-
9191
# Templates for different types of tests
9292
test_function_body = textwrap.dedent(
9393
"""\
94-
for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, function_name="{orig_function_name}", file_path=r"{file_path}", num_to_get={max_run_count}):
94+
for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", num_to_get={max_run_count}):
9595
args = pickle.loads(args_pkl)
9696
kwargs = pickle.loads(kwargs_pkl)
9797
ret = {function_name}(*args, **kwargs)
@@ -100,7 +100,7 @@ def create_trace_replay_test_code(
100100

101101
test_method_body = textwrap.dedent(
102102
"""\
103-
for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):
103+
for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):
104104
args = pickle.loads(args_pkl)
105105
kwargs = pickle.loads(kwargs_pkl){filter_variables}
106106
function_name = "{orig_function_name}"
@@ -115,7 +115,7 @@ def create_trace_replay_test_code(
115115

116116
test_class_method_body = textwrap.dedent(
117117
"""\
118-
for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):
118+
for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):
119119
args = pickle.loads(args_pkl)
120120
kwargs = pickle.loads(kwargs_pkl){filter_variables}
121121
if not args:
@@ -125,13 +125,15 @@ def create_trace_replay_test_code(
125125
)
126126
test_static_method_body = textwrap.dedent(
127127
"""\
128-
for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):
128+
for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):
129129
args = pickle.loads(args_pkl)
130130
kwargs = pickle.loads(kwargs_pkl){filter_variables}
131131
ret = {class_name_alias}{method_name}(*args, **kwargs)
132132
"""
133133
)
134134

135+
# Create main body
136+
135137
if test_framework == "unittest":
136138
self = "self"
137139
test_template = "\nclass TestTracedFunctions(unittest.TestCase):\n"
@@ -140,17 +142,20 @@ def create_trace_replay_test_code(
140142
self = ""
141143

142144
for func in functions_data:
145+
143146
module_name = func.get("module_name")
144147
function_name = func.get("function_name")
145148
class_name = func.get("class_name")
146149
file_path = func.get("file_path")
150+
benchmark_function_name = func.get("benchmark_function_name")
147151
function_properties = func.get("function_properties")
148152
if not class_name:
149153
alias = get_function_alias(module_name, function_name)
150154
test_body = test_function_body.format(
155+
benchmark_function_name=benchmark_function_name,
156+
orig_function_name=function_name,
151157
function_name=alias,
152158
file_path=file_path,
153-
orig_function_name=function_name,
154159
max_run_count=max_run_count,
155160
)
156161
else:
@@ -162,6 +167,7 @@ def create_trace_replay_test_code(
162167
method_name = "." + function_name if function_name != "__init__" else ""
163168
if function_properties.is_classmethod:
164169
test_body = test_class_method_body.format(
170+
benchmark_function_name=benchmark_function_name,
165171
orig_function_name=function_name,
166172
file_path=file_path,
167173
class_name_alias=class_name_alias,
@@ -172,6 +178,7 @@ def create_trace_replay_test_code(
172178
)
173179
elif function_properties.is_staticmethod:
174180
test_body = test_static_method_body.format(
181+
benchmark_function_name=benchmark_function_name,
175182
orig_function_name=function_name,
176183
file_path=file_path,
177184
class_name_alias=class_name_alias,
@@ -182,6 +189,7 @@ def create_trace_replay_test_code(
182189
)
183190
else:
184191
test_body = test_method_body.format(
192+
benchmark_function_name=benchmark_function_name,
185193
orig_function_name=function_name,
186194
file_path=file_path,
187195
class_name_alias=class_name_alias,
@@ -217,25 +225,25 @@ def generate_replay_test(trace_file_path: Path, output_dir: Path, test_framework
217225
conn = sqlite3.connect(trace_file_path.as_posix())
218226
cursor = conn.cursor()
219227

220-
# Get distinct benchmark names
228+
# Get distinct benchmark file paths
221229
cursor.execute(
222-
"SELECT DISTINCT benchmark_function_name, benchmark_file_path FROM benchmark_function_timings"
230+
"SELECT DISTINCT benchmark_file_path FROM benchmark_function_timings"
223231
)
224-
benchmarks = cursor.fetchall()
232+
benchmark_files = cursor.fetchall()
225233

226-
# Generate a test for each benchmark
227-
for benchmark in benchmarks:
228-
benchmark_function_name, benchmark_file_path = benchmark
229-
# Get functions associated with this benchmark
234+
# Generate a test for each benchmark file
235+
for benchmark_file in benchmark_files:
236+
benchmark_file_path = benchmark_file[0]
237+
# Get all benchmarks and functions associated with this file path
230238
cursor.execute(
231-
"SELECT DISTINCT function_name, class_name, module_name, file_path, benchmark_line_number FROM benchmark_function_timings "
232-
"WHERE benchmark_function_name = ? AND benchmark_file_path = ?",
233-
(benchmark_function_name, benchmark_file_path)
239+
"SELECT DISTINCT benchmark_function_name, function_name, class_name, module_name, file_path, benchmark_line_number FROM benchmark_function_timings "
240+
"WHERE benchmark_file_path = ?",
241+
(benchmark_file_path,)
234242
)
235243

236244
functions_data = []
237-
for func_row in cursor.fetchall():
238-
function_name, class_name, module_name, file_path, benchmark_line_number = func_row
245+
for row in cursor.fetchall():
246+
benchmark_function_name, function_name, class_name, module_name, file_path, benchmark_line_number = row
239247
# Add this function to our list
240248
functions_data.append({
241249
"function_name": function_name,
@@ -246,16 +254,15 @@ def generate_replay_test(trace_file_path: Path, output_dir: Path, test_framework
246254
"benchmark_file_path": benchmark_file_path,
247255
"benchmark_line_number": benchmark_line_number,
248256
"function_properties": inspect_top_level_functions_or_methods(
249-
file_name=Path(file_path),
250-
function_or_method_name=function_name,
251-
class_name=class_name,
252-
)
257+
file_name=Path(file_path),
258+
function_or_method_name=function_name,
259+
class_name=class_name,
260+
)
253261
})
254262

255263
if not functions_data:
256-
logger.info(f"No functions found for benchmark {benchmark_function_name} in {benchmark_file_path}")
264+
logger.info(f"No benchmark test functions found in {benchmark_file_path}")
257265
continue
258-
259266
# Generate the test code for this benchmark
260267
test_code = create_trace_replay_test_code(
261268
trace_file=trace_file_path.as_posix(),
@@ -265,17 +272,15 @@ def generate_replay_test(trace_file_path: Path, output_dir: Path, test_framework
265272
)
266273
test_code = isort.code(test_code)
267274

268-
# Write to file if requested
269-
if output_dir:
270-
name = Path(benchmark_file_path).name.split(".")[0][5:] # remove "test_" from the name since we add it in later
271-
output_file = get_test_file_path(
272-
test_dir=Path(output_dir), function_name=f"{name}_{benchmark_function_name}", test_type="replay"
273-
)
274-
# Write test code to file, parents = true
275-
output_dir.mkdir(parents=True, exist_ok=True)
276-
output_file.write_text(test_code, "utf-8")
277-
count += 1
278-
logger.info(f"Replay test for benchmark `{benchmark_function_name}` in {name} written to {output_file}")
275+
name = Path(benchmark_file_path).name.split(".")[0][5:] # remove "test_" from the name since we add it in later
276+
output_file = get_test_file_path(
277+
test_dir=Path(output_dir), function_name=f"{name}", test_type="replay"
278+
)
279+
# Write test code to file, parents = true
280+
output_dir.mkdir(parents=True, exist_ok=True)
281+
output_file.write_text(test_code, "utf-8")
282+
count += 1
283+
logger.info(f"Replay test for benchmark file `{benchmark_file_path}` in {name} written to {output_file}")
279284

280285
conn.close()
281286

codeflash/models/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,7 @@ def group_by_benchmarks(self, benchmark_keys:list[BenchmarkKey], benchmark_repla
483483
test_results_by_benchmark = defaultdict(TestResults)
484484
benchmark_module_path = {}
485485
for benchmark_key in benchmark_keys:
486-
benchmark_module_path[benchmark_key] = module_name_from_file_path(benchmark_replay_test_dir.resolve() / f"test_{Path(benchmark_key.file_path).name.split('.')[0][5:]}_{benchmark_key.function_name}__replay_test_", project_root)
486+
benchmark_module_path[benchmark_key] = module_name_from_file_path(benchmark_replay_test_dir.resolve() / f"test_{Path(benchmark_key.file_path).name.split('.')[0][5:]}__replay_test_", project_root)
487487
for test_result in self.test_results:
488488
if (test_result.test_type == TestType.REPLAY_TEST):
489489
for benchmark_key, module_path in benchmark_module_path.items():

0 commit comments

Comments
 (0)