
Commit 1f3fcff

Support recursive functions, and the @benchmark / @pytest.mark.benchmark ways of using benchmark. Created tests for all of them.
1 parent c82a3a3 commit 1f3fcff

File tree

8 files changed: +194 -57 lines changed


code_to_optimize/bubble_sort_codeflash_trace.py

Lines changed: 18 additions & 0 deletions

@@ -9,6 +9,24 @@ def sorter(arr):
                 arr[j + 1] = temp
     return arr
 
+@codeflash_trace
+def recursive_bubble_sort(arr, n=None):
+    # Initialize n if not provided
+    if n is None:
+        n = len(arr)
+
+    # Base case: if n is 1, the array is already sorted
+    if n == 1:
+        return arr
+
+    # One pass of bubble sort - move the largest element to the end
+    for i in range(n - 1):
+        if arr[i] > arr[i + 1]:
+            arr[i], arr[i + 1] = arr[i + 1], arr[i]
+
+    # Recursively sort the remaining n-1 elements
+    return recursive_bubble_sort(arr, n - 1)
+
 class Sorter:
     @codeflash_trace
     def __init__(self, arr):
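Since each call performs one pass and recurses with n - 1, sorting a list of length n needs roughly n stack frames. A quick sanity check (a sketch, not part of the commit) that the new 500-element test input fits comfortably under CPython's default limit:

import sys

# recursive_bubble_sort recurses about once per element, so a 500-element
# input uses ~500 frames, well under the default limit printed below.
print(sys.getrecursionlimit())  # 1000 by default in CPython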
code_to_optimize/tests/pytest/benchmarks_test/test_recursive_example.py

Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
+from code_to_optimize.bubble_sort_codeflash_trace import recursive_bubble_sort
+
+
+def test_recursive_sort(benchmark):
+    result = benchmark(recursive_bubble_sort, list(reversed(range(500))))
+    assert result == list(range(500))
code_to_optimize/tests/pytest/benchmarks_test_decorator/test_benchmark_decorator.py

Lines changed: 11 additions & 0 deletions

@@ -0,0 +1,11 @@
+import pytest
+from code_to_optimize.bubble_sort_codeflash_trace import sorter
+
+
+def test_benchmark_sort(benchmark):
+    @benchmark
+    def do_sort():
+        sorter(list(reversed(range(500))))
+
+
+@pytest.mark.benchmark(group="benchmark_decorator")
+def test_pytest_mark(benchmark):
+    benchmark(sorter, list(reversed(range(500))))

codeflash/benchmarking/codeflash_trace.py

Lines changed: 55 additions & 24 deletions
@@ -3,6 +3,7 @@
 import pickle
 import sqlite3
 import sys
+import threading
 import time
 from typing import Callable
 
@@ -18,6 +19,8 @@ def __init__(self) -> None:
         self.pickle_count_limit = 1000
         self._connection = None
         self._trace_path = None
+        self._thread_local = threading.local()
+        self._thread_local.active_functions = set()
 
     def setup(self, trace_path: str) -> None:
         """Set up the database connection for direct writing.
@@ -98,23 +101,29 @@ def __call__(self, func: Callable) -> Callable:
             The wrapped function
 
         """
+        func_id = (func.__module__, func.__name__)
         @functools.wraps(func)
         def wrapper(*args, **kwargs):
+            # Initialize thread-local active functions set if it doesn't exist
+            if not hasattr(self._thread_local, "active_functions"):
+                self._thread_local.active_functions = set()
+            # If it's in a recursive function, just return the result
+            if func_id in self._thread_local.active_functions:
+                return func(*args, **kwargs)
+            # Track active functions so we can detect recursive functions
+            self._thread_local.active_functions.add(func_id)
             # Measure execution time
             start_time = time.thread_time_ns()
             result = func(*args, **kwargs)
             end_time = time.thread_time_ns()
             # Calculate execution time
             execution_time = end_time - start_time
-
             self.function_call_count += 1
 
-            # Measure overhead
-            original_recursion_limit = sys.getrecursionlimit()
             # Check if currently in pytest benchmark fixture
             if os.environ.get("CODEFLASH_BENCHMARKING", "False") == "False":
+                self._thread_local.active_functions.remove(func_id)
                 return result
-
             # Get benchmark info from environment
             benchmark_function_name = os.environ.get("CODEFLASH_BENCHMARK_FUNCTION_NAME", "")
             benchmark_module_path = os.environ.get("CODEFLASH_BENCHMARK_MODULE_PATH", "")
@@ -125,32 +134,54 @@ def wrapper(*args, **kwargs):
             if "." in qualname:
                 class_name = qualname.split(".")[0]
 
-            if self.function_call_count <= self.pickle_count_limit:
+            # Limit pickle count so memory does not explode
+            if self.function_call_count > self.pickle_count_limit:
+                print("Pickle limit reached")
+                self._thread_local.active_functions.remove(func_id)
+                overhead_time = time.thread_time_ns() - end_time
+                self.function_calls_data.append(
+                    (func.__name__, class_name, func.__module__, func.__code__.co_filename,
+                     benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
+                     overhead_time, None, None)
+                )
+                return result
+
+            try:
+                original_recursion_limit = sys.getrecursionlimit()
+                sys.setrecursionlimit(10000)
+                # args = dict(args.items())
+                # if class_name and func.__name__ == "__init__" and "self" in args:
+                #     del args["self"]
+                # Pickle the arguments
+                pickled_args = pickle.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
+                pickled_kwargs = pickle.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)
+                sys.setrecursionlimit(original_recursion_limit)
+            except (TypeError, pickle.PicklingError, AttributeError, RecursionError, OSError):
+                # Retry with dill if pickle fails. It's slower but more comprehensive
                 try:
-                    sys.setrecursionlimit(1000000)
-                    args = dict(args.items())
-                    if class_name and func.__name__ == "__init__" and "self" in args:
-                        del args["self"]
-                    # Pickle the arguments
-                    pickled_args = pickle.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
-                    pickled_kwargs = pickle.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)
+                    pickled_args = dill.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
+                    pickled_kwargs = dill.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)
                     sys.setrecursionlimit(original_recursion_limit)
-                except (TypeError, pickle.PicklingError, AttributeError, RecursionError, OSError):
-                    # we retry with dill if pickle fails. It's slower but more comprehensive
-                    try:
-                        pickled_args = dill.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
-                        pickled_kwargs = dill.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)
-                        sys.setrecursionlimit(original_recursion_limit)
-
-                    except (TypeError, dill.PicklingError, AttributeError, RecursionError, OSError) as e:
-                        print(f"Error pickling arguments for function {func.__name__}: {e}")
-                        return result
 
+                except (TypeError, dill.PicklingError, AttributeError, RecursionError, OSError) as e:
+                    print(f"Error pickling arguments for function {func.__name__}: {e}")
+                    # Add to the list of function calls without pickled args. Used for timing info only
+                    self._thread_local.active_functions.remove(func_id)
+                    overhead_time = time.thread_time_ns() - end_time
+                    self.function_calls_data.append(
+                        (func.__name__, class_name, func.__module__, func.__code__.co_filename,
+                         benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
+                         overhead_time, None, None)
+                    )
+                    return result
+
+            # Flush to database every 1000 calls
             if len(self.function_calls_data) > 1000:
                 self.write_function_timings()
-            # Calculate overhead time
-            overhead_time = time.thread_time_ns() - end_time
 
+            # Add to the list of function calls with pickled args, to be used for replay tests
+            self._thread_local.active_functions.remove(func_id)
+            overhead_time = time.thread_time_ns() - end_time
             self.function_calls_data.append(
                 (func.__name__, class_name, func.__module__, func.__code__.co_filename,
                  benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
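The guard added above is a standard thread-local re-entrancy pattern: each wrapped call records its (module, name) key on entry and removes it on exit, so a nested call that finds its own key already present skips instrumentation, and a recursive function is timed only at its outermost frame. A minimal standalone sketch of that pattern, separate from this commit (trace_once and _local are illustrative names, not CodeflashTrace API):

import functools
import threading

_local = threading.local()

def trace_once(func):
    # One key per function, shared by all calls on the same thread.
    func_id = (func.__module__, func.__qualname__)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        active = getattr(_local, "active", None)
        if active is None:
            active = _local.active = set()
        if func_id in active:
            # Recursive re-entry: run the function without instrumenting it again.
            return func(*args, **kwargs)
        active.add(func_id)
        try:
            # Only this outermost frame gets timed and recorded.
            return func(*args, **kwargs)
        finally:
            active.remove(func_id)

    return wrapper

The try/finally variant clears the key even when the wrapped function raises; the commit instead removes the key explicitly on each return path, which leaves the key behind if the function throws.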

codeflash/benchmarking/plugin/plugin.py

Lines changed: 41 additions & 30 deletions
@@ -175,6 +175,7 @@ def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]:
             benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func)
             # Subtract overhead from total time
             overhead = overhead_by_benchmark.get(benchmark_key, 0)
+            print("benchmark_func:", benchmark_func, "Total time:", time_ns, "Overhead:", overhead, "Result:", time_ns - overhead)
             result[benchmark_key] = time_ns - overhead
 
     finally:
@@ -210,61 +211,71 @@ def pytest_plugin_registered(plugin, manager):
             manager.unregister(plugin)
 
     @staticmethod
+    def pytest_configure(config):
+        """Register the benchmark marker."""
+        config.addinivalue_line(
+            "markers",
+            "benchmark: mark test as a benchmark that should be run with codeflash tracing"
+        )
+
+    @staticmethod
     def pytest_collection_modifyitems(config, items):
         # Skip tests that don't have the benchmark fixture
         if not config.getoption("--codeflash-trace"):
             return
 
         skip_no_benchmark = pytest.mark.skip(reason="Test requires benchmark fixture")
         for item in items:
-            if hasattr(item, "fixturenames") and "benchmark" in item.fixturenames:
-                continue
-            item.add_marker(skip_no_benchmark)
+            # Check for direct benchmark fixture usage
+            has_fixture = hasattr(item, "fixturenames") and "benchmark" in item.fixturenames
+
+            # Check for @pytest.mark.benchmark marker
+            has_marker = False
+            if hasattr(item, "get_closest_marker"):
+                marker = item.get_closest_marker("benchmark")
+                if marker is not None:
+                    has_marker = True
+
+            # Skip if neither fixture nor marker is present
+            if not (has_fixture or has_marker):
+                item.add_marker(skip_no_benchmark)
 
     # Benchmark fixture
     class Benchmark:
         def __init__(self, request):
             self.request = request
 
         def __call__(self, func, *args, **kwargs):
-            """Handle behaviour for the benchmark fixture in pytest.
-
-            For example,
-
-            def test_something(benchmark):
-                benchmark(sorter, [3,2,1])
-
-            Args:
-                func: The function to benchmark (e.g. sorter)
-                args: The arguments to pass to the function (e.g. [3,2,1])
-                kwargs: The keyword arguments to pass to the function
-
-            Returns:
-                The return value of the function
-
-            """
-            benchmark_module_path = module_name_from_file_path(Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root))
+            """Handle both direct function calls and decorator usage."""
+            if args or kwargs:
+                # Used as benchmark(func, *args, **kwargs)
+                return self._run_benchmark(func, *args, **kwargs)
+            # Used as @benchmark decorator
+            def wrapped_func(*args, **kwargs):
+                return func(*args, **kwargs)
+            result = self._run_benchmark(func)
+            return wrapped_func
+
+        def _run_benchmark(self, func, *args, **kwargs):
+            """Actual benchmark implementation."""
+            benchmark_module_path = module_name_from_file_path(Path(str(self.request.node.fspath)),
+                                                               Path(codeflash_benchmark_plugin.project_root))
             benchmark_function_name = self.request.node.name
-            line_number = int(str(sys._getframe(1).f_lineno))  # 1 frame up in the call stack
-
-            # Set env vars so codeflash decorator can identify what benchmark its being run in
+            line_number = int(str(sys._getframe(2).f_lineno))  # 2 frames up in the call stack
+            # Set env vars
             os.environ["CODEFLASH_BENCHMARK_FUNCTION_NAME"] = benchmark_function_name
             os.environ["CODEFLASH_BENCHMARK_MODULE_PATH"] = benchmark_module_path
             os.environ["CODEFLASH_BENCHMARK_LINE_NUMBER"] = str(line_number)
             os.environ["CODEFLASH_BENCHMARKING"] = "True"
-
-            # Run the function
-            start = time.perf_counter_ns()
+            # Run the function
+            start = time.thread_time_ns()
             result = func(*args, **kwargs)
-            end = time.perf_counter_ns()
-
+            end = time.thread_time_ns()
             # Reset the environment variable
             os.environ["CODEFLASH_BENCHMARKING"] = "False"
 
             # Write function calls
             codeflash_trace.write_function_timings()
-            # Reset function call count after a benchmark is run
+            # Reset function call count
            codeflash_trace.function_call_count = 0
             # Add to the benchmark timings buffer
             codeflash_benchmark_plugin.benchmark_timings.append(
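One subtlety in this refactor: the line-number lookup changes from sys._getframe(1) to sys._getframe(2) because the user's call site is now two frames away (test body -> __call__ -> _run_benchmark) rather than one. A tiny illustration of the frame arithmetic (function names hypothetical, not from the commit):

import sys

def call_layer():    # stands in for Benchmark.__call__
    return run_layer()

def run_layer():     # stands in for Benchmark._run_benchmark
    # frame 0 = run_layer, frame 1 = call_layer, frame 2 = the original call site
    return sys._getframe(2).f_lineno

print(call_layer())  # prints the line number of this print() call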

codeflash/benchmarking/pytest_new_process_trace_benchmarks.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@
     codeflash_benchmark_plugin.setup(trace_file, project_root)
     codeflash_trace.setup(trace_file)
     exitcode = pytest.main(
-        [benchmarks_root, "--codeflash-trace", "-p", "no:benchmark", "-s", "-o", "addopts="], plugins=[codeflash_benchmark_plugin]
+        [benchmarks_root, "--codeflash-trace", "-p", "no:benchmark", "-p", "no:codspeed", "-p", "no:cov", "-s", "-o", "addopts="], plugins=[codeflash_benchmark_plugin]
     )  # Errors will be printed to stdout, not stderr
 
 except Exception as e:
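pytest's -p no:NAME option blocks a plugin from loading by its registered name. Disabling pytest-benchmark and pytest-codspeed keeps their own benchmark fixtures from shadowing the codeflash tracing fixture, and disabling coverage avoids its overhead during the timing run. Assuming those plugin names, the equivalent shell invocation would look roughly like:

python -m pytest path/to/benchmarks --codeflash-trace -p no:benchmark -p no:codspeed -p no:cov -s -o addopts=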

codeflash/benchmarking/replay_test.py

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ def get_next_arg_and_return(
     )
 
     while (val := cursor.fetchone()) is not None:
-        yield val[9], val[10]  # args and kwargs are at indices 7 and 8
+        yield val[9], val[10]  # pickled_args, pickled_kwargs
 
 
 def get_function_alias(module: str, function_name: str) -> str:
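The corrected comment matches the row layout: columns 9 and 10 hold the pickled positional and keyword arguments. A consumer of this generator would decode them roughly like this (a sketch; the call arguments are elided as in the excerpt above, and it assumes the values were written by the standard-pickle path shown earlier):

import pickle

for pickled_args, pickled_kwargs in get_next_arg_and_return(...):
    args = pickle.loads(pickled_args)      # the original *args tuple
    kwargs = pickle.loads(pickled_kwargs)  # the original **kwargs dict
    # replay the traced call: target_func(*args, **kwargs)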

tests/test_trace_benchmarks.py

Lines changed: 61 additions & 1 deletion
@@ -31,7 +31,7 @@ def test_trace_benchmarks():
         function_calls = cursor.fetchall()
 
         # Assert the length of function calls
-        assert len(function_calls) == 7, f"Expected 6 function calls, but got {len(function_calls)}"
+        assert len(function_calls) == 8, f"Expected 8 function calls, but got {len(function_calls)}"
 
         bubble_sort_path = (project_root / "bubble_sort_codeflash_trace.py").as_posix()
         process_and_bubble_sort_path = (project_root / "process_and_bubble_sort_codeflash_trace.py").as_posix()
@@ -64,6 +64,10 @@ def test_trace_benchmarks():
             ("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace",
              f"{bubble_sort_path}",
              "test_no_func", "tests.pytest.benchmarks_test.test_process_and_sort_example", 8),
+
+            ("recursive_bubble_sort", "", "code_to_optimize.bubble_sort_codeflash_trace",
+             f"{bubble_sort_path}",
+             "test_recursive_sort", "tests.pytest.benchmarks_test.test_recursive_example", 5),
         ]
         for idx, (actual, expected) in enumerate(zip(function_calls, expected_calls)):
             assert actual[0] == expected[0], f"Mismatch at index {idx} for function_name"
@@ -222,6 +226,62 @@ def test_trace_multithreaded_benchmark() -> None:
         # Close connection
         conn.close()
 
+    finally:
+        # cleanup
+        output_file.unlink(missing_ok=True)
+
+def test_trace_benchmark_decorator() -> None:
+    project_root = Path(__file__).parent.parent / "code_to_optimize"
+    benchmarks_root = project_root / "tests" / "pytest" / "benchmarks_test_decorator"
+    tests_root = project_root / "tests"
+    output_file = (benchmarks_root / Path("test_trace_benchmarks.trace")).resolve()
+    trace_benchmarks_pytest(benchmarks_root, tests_root, project_root, output_file)
+    assert output_file.exists()
+    try:
+        # check contents of trace file
+        # connect to database
+        conn = sqlite3.connect(output_file.as_posix())
+        cursor = conn.cursor()
+
+        # Get all records
+        cursor.execute(
+            "SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name")
+        function_calls = cursor.fetchall()
+
+        # Assert the length of function calls
+        assert len(function_calls) == 2, f"Expected 2 function calls, but got {len(function_calls)}"
+        function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
+        total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
+        function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
+        assert "code_to_optimize.bubble_sort_codeflash_trace.sorter" in function_to_results
+
+        test_name, total_time, function_time, percent = function_to_results["code_to_optimize.bubble_sort_codeflash_trace.sorter"][0]
+        assert total_time > 0.0
+        assert function_time > 0.0
+        assert percent > 0.0
+
+        bubble_sort_path = (project_root / "bubble_sort_codeflash_trace.py").as_posix()
+        # Expected function calls
+        expected_calls = [
+            ("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace",
+             f"{bubble_sort_path}",
+             "test_benchmark_sort", "tests.pytest.benchmarks_test_decorator.test_benchmark_decorator", 5),
+            ("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace",
+             f"{bubble_sort_path}",
+             "test_pytest_mark", "tests.pytest.benchmarks_test_decorator.test_benchmark_decorator", 11),
+        ]
+        for idx, (actual, expected) in enumerate(zip(function_calls, expected_calls)):
+            assert actual[0] == expected[0], f"Mismatch at index {idx} for function_name"
+            assert actual[1] == expected[1], f"Mismatch at index {idx} for class_name"
+            assert actual[2] == expected[2], f"Mismatch at index {idx} for module_name"
+            assert Path(actual[3]).name == Path(expected[3]).name, f"Mismatch at index {idx} for file_path"
+            assert actual[4] == expected[4], f"Mismatch at index {idx} for benchmark_function_name"
+            assert actual[5] == expected[5], f"Mismatch at index {idx} for benchmark_module_path"
+            assert actual[6] == expected[6], f"Mismatch at index {idx} for benchmark_line_number"
+        # Close connection
+        conn.close()
+
     finally:
         # cleanup
         output_file.unlink(missing_ok=True)
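The new test can be run in isolation in the usual pytest way (assuming a checkout where these paths exist):

python -m pytest tests/test_trace_benchmarks.py::test_trace_benchmark_decorator -s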
