Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 21 additions & 21 deletions codeflash/benchmarking/plugin/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(self) -> None:
self.project_root = None
self.benchmark_timings = []

def setup(self, trace_path:str, project_root:str) -> None:
def setup(self, trace_path: str, project_root: str) -> None:
try:
# Open connection
self.project_root = project_root
Expand All @@ -35,7 +35,7 @@ def setup(self, trace_path:str, project_root:str) -> None:
"benchmark_time_ns INTEGER)"
)
self._connection.commit()
self.close() # Reopen only at the end of pytest session
self.close() # Reopen only at the end of pytest session
except Exception as e:
print(f"Database setup error: {e}")
if self._connection:
Expand All @@ -55,14 +55,15 @@ def write_benchmark_timings(self) -> None:
# Insert data into the benchmark_timings table
cur.executemany(
"INSERT INTO benchmark_timings (benchmark_module_path, benchmark_function_name, benchmark_line_number, benchmark_time_ns) VALUES (?, ?, ?, ?)",
self.benchmark_timings
self.benchmark_timings,
)
self._connection.commit()
self.benchmark_timings = [] # Clear the benchmark timings list
self.benchmark_timings = [] # Clear the benchmark timings list
except Exception as e:
print(f"Error writing to benchmark timings database: {e}")
self._connection.rollback()
raise

def close(self) -> None:
if self._connection:
self._connection.close()
Expand Down Expand Up @@ -196,12 +197,7 @@ def pytest_sessionfinish(self, session, exitstatus):

@staticmethod
def pytest_addoption(parser):
parser.addoption(
"--codeflash-trace",
action="store_true",
default=False,
help="Enable CodeFlash tracing"
)
parser.addoption("--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing")

@staticmethod
def pytest_plugin_registered(plugin, manager):
Expand All @@ -213,9 +209,9 @@ def pytest_plugin_registered(plugin, manager):
def pytest_configure(config):
"""Register the benchmark marker."""
config.addinivalue_line(
"markers",
"benchmark: mark test as a benchmark that should be run with codeflash tracing"
"markers", "benchmark: mark test as a benchmark that should be run with codeflash tracing"
)

@staticmethod
def pytest_collection_modifyitems(config, items):
# Skip tests that don't have the benchmark fixture
Expand All @@ -224,19 +220,18 @@ def pytest_collection_modifyitems(config, items):

skip_no_benchmark = pytest.mark.skip(reason="Test requires benchmark fixture")
for item in items:
# Check for direct benchmark fixture usage
has_fixture = hasattr(item, "fixturenames") and "benchmark" in item.fixturenames
if hasattr(item, "fixturenames"):
if "benchmark" in item.fixturenames:
continue # Skip current iteration if benchmark fixture is present

# Check for @pytest.mark.benchmark marker
has_marker = False
if hasattr(item, "get_closest_marker"):
marker = item.get_closest_marker("benchmark")
if marker is not None:
has_marker = True
continue # Skip current iteration if benchmark marker is present

# Skip if neither fixture nor marker is present
if not (has_fixture or has_marker):
item.add_marker(skip_no_benchmark)
item.add_marker(skip_no_benchmark)

# Benchmark fixture
class Benchmark:
Expand All @@ -248,16 +243,19 @@ def __call__(self, func, *args, **kwargs):
if args or kwargs:
# Used as benchmark(func, *args, **kwargs)
return self._run_benchmark(func, *args, **kwargs)

# Used as @benchmark decorator
def wrapped_func(*args, **kwargs):
return func(*args, **kwargs)

result = self._run_benchmark(func)
return wrapped_func

def _run_benchmark(self, func, *args, **kwargs):
"""Actual benchmark implementation."""
benchmark_module_path = module_name_from_file_path(Path(str(self.request.node.fspath)),
Path(codeflash_benchmark_plugin.project_root))
benchmark_module_path = module_name_from_file_path(
Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root)
)
benchmark_function_name = self.request.node.name
line_number = int(str(sys._getframe(2).f_lineno)) # 2 frames up in the call stack
# Set env vars
Expand All @@ -278,7 +276,8 @@ def _run_benchmark(self, func, *args, **kwargs):
codeflash_trace.function_call_count = 0
# Add to the benchmark timings buffer
codeflash_benchmark_plugin.benchmark_timings.append(
(benchmark_module_path, benchmark_function_name, line_number, end - start))
(benchmark_module_path, benchmark_function_name, line_number, end - start)
)

return result

Expand All @@ -290,4 +289,5 @@ def benchmark(request):

return CodeFlashBenchmarkPlugin.Benchmark(request)


codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin()
Loading