Skip to content

Commit bf6acc9

Browse files
committed
include benchmark context if applicable
1 parent 240f507 commit bf6acc9

File tree

3 files changed

+17
-3
lines changed

3 files changed

+17
-3
lines changed

codeflash/benchmarking/pytest_new_process_trace_benchmarks.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@
3030
"-s",
3131
"-o",
3232
"addopts=",
33-
"-W",
34-
"ignore::pytest.PytestAssertRewriteWarning",
3533
],
3634
plugins=[codeflash_benchmark_plugin],
3735
) # Errors will be printed to stdout, not stderr

codeflash/benchmarking/replay_test.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import re
34
import sqlite3
45
import textwrap
56
from pathlib import Path
@@ -14,6 +15,8 @@
1415
if TYPE_CHECKING:
1516
from collections.abc import Generator
1617

18+
benchmark_context_cleaner = re.compile(r"[^a-zA-Z0-9_]+")
19+
1720

1821
def get_next_arg_and_return(
1922
trace_file: str,
@@ -46,6 +49,16 @@ def get_function_alias(module: str, function_name: str) -> str:
4649
return "_".join(module.split(".")) + "_" + function_name
4750

4851

52+
def get_unique_test_name(module: str, function_name: str, benchmark_name: str, class_name: str | None = None) -> str:
53+
clean_benchmark = benchmark_context_cleaner.sub("_", benchmark_name).strip("_")
54+
55+
base_alias = get_function_alias(module, function_name)
56+
if class_name:
57+
class_alias = get_function_alias(module, class_name)
58+
return f"{class_alias}_{function_name}_{clean_benchmark}"
59+
return f"{base_alias}_{clean_benchmark}"
60+
61+
4962
def create_trace_replay_test_code(
5063
trace_file: str,
5164
functions_data: list[dict[str, Any]],
@@ -209,7 +222,8 @@ def create_trace_replay_test_code(
209222
formatted_test_body = textwrap.indent(test_body, " " if test_framework == "unittest" else " ")
210223

211224
test_template += " " if test_framework == "unittest" else ""
212-
test_template += f"def test_{alias}({self}):\n{formatted_test_body}\n"
225+
unique_test_name = get_unique_test_name(module_name, function_name, benchmark_function_name, class_name)
226+
test_template += f"def test_{unique_test_name}({self}):\n{formatted_test_body}\n"
213227

214228
return imports + "\n" + metadata + "\n" + test_template
215229

@@ -294,3 +308,4 @@ def generate_replay_test(
294308
logger.info(f"Error generating replay tests: {e}")
295309

296310
return count
311+

codeflash/models/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,7 @@ def group_by_benchmarks(
497497
benchmark_replay_test_dir.resolve()
498498
/ f"test_{benchmark_key.module_path.replace('.', '_')}__replay_test_",
499499
project_root,
500+
traverse_up=True,
500501
)
501502
for test_result in self.test_results:
502503
if test_result.test_type == TestType.REPLAY_TEST:

0 commit comments

Comments
 (0)