Skip to content

Commit 3dbecfe

Browse files
committed
first version benhmark wrapper
1 parent 19e873a commit 3dbecfe

File tree

1 file changed

+62
-0
lines changed

1 file changed

+62
-0
lines changed

pylops_mpi/utils/benchmark.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import functools
2+
import time
3+
4+
# TODO (tharitt): later move to env file or something
5+
ENABLE_BENCHMARK = True
6+
7+
# This function is to be instrumented throughout the targeted function
8+
def mark(label):
9+
if _current_mark_func is not None:
10+
_current_mark_func(label)
11+
12+
# Global hook - this will be re-assigned (points to)
13+
# the function defined in benchmark wrapper
14+
_current_mark_func = None
15+
16+
def benchmark(func):
17+
"""A wrapper for code injection for time measurement.
18+
19+
This wrapper allows users to put a call to mark()
20+
anywhere inside the wrapped function. The function mark()
21+
is defined in the global scope to be a placeholder for the targeted
22+
function to import. This wrapper will make it points to local_mark() defined
23+
in this function. Therefore, the wrapped function will be able call
24+
local_mark(). All the context for local_mark() like mark list can be
25+
hidden from users and thus provide clean interface.
26+
27+
Parameters
28+
----------
29+
func : :obj:`callable`, optional
30+
Function to be decorated.
31+
"""
32+
33+
# Zero-overhead
34+
if not ENABLE_BENCHMARK:
35+
return func
36+
37+
@functools.wraps(func)
38+
def wrapper(*args, **kwargs):
39+
marks = []
40+
41+
# currently this simply record the user-define label and record time
42+
def local_mark(label):
43+
marks.append((label, time.perf_counter()))
44+
45+
global _current_mark_func
46+
_current_mark_func = local_mark
47+
48+
# the mark() called in wrapped function will now call local_mark
49+
result = func(*args, **kwargs)
50+
# clean up to original state
51+
_current_mark_func = None
52+
53+
# TODO (tharitt): maybe changing to saving results to file instead
54+
if marks:
55+
prev_label, prev_t = marks[0]
56+
print(f"[BENCH] {prev_label}: 0.000000s")
57+
for label, t in marks[1:]:
58+
print(f"[BENCH] {label}: {t - prev_t:.6f}s since '{prev_label}'")
59+
prev_label, prev_t = label, t
60+
return result
61+
62+
return wrapper

0 commit comments

Comments
 (0)