Skip to content

Commit 6e66898

Browse files
committed
Fixed benchmark.py, allows to decorate without params, handle nested calls properly
1 parent 7c31951 commit 6e66898

File tree

1 file changed

+74
-40
lines changed

1 file changed

+74
-40
lines changed

pylops_mpi/utils/benchmark.py

Lines changed: 74 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,65 +1,99 @@
11
import functools
22
import time
3+
from typing import Callable, Optional
4+
from mpi4py import MPI
5+
36

47
# TODO (tharitt): later move to env file or something
58
ENABLE_BENCHMARK = True
69

710

8-
# This function is to be instrumented throughout the targeted function
11+
# This function allows users to measure time arbitary lines of the function
912
def mark(label):
10-
if _current_mark_func is not None:
11-
_current_mark_func(label)
13+
if not _mark_func_stack:
14+
raise RuntimeError("mark() called outside of a benchmarked region")
15+
_mark_func_stack[-1](label)
1216

1317

14-
# Global hook - this will be re-assigned (points to)
15-
# the function defined in benchmark wrapper
16-
_current_mark_func = None
18+
# Stack of active mark functions for nested support
19+
_mark_func_stack = []
1720

1821

19-
def benchmark(func):
22+
def benchmark(func: Optional[Callable] = None,
23+
description="",
24+
save_file=False,
25+
file_path='benchmark.log'
26+
):
2027
"""A wrapper for code injection for time measurement.
2128
22-
This wrapper allows users to put a call to mark()
23-
anywhere inside the wrapped function. The function mark()
24-
is defined in the global scope to be a placeholder for the targeted
25-
function to import. This wrapper will make it points to local_mark() defined
26-
in this function. Therefore, the wrapped function will be able call
27-
local_mark(). All the context for local_mark() like mark list can be
28-
hidden from users and thus provide clean interface.
29+
This wrapper measure the start-to-end time of the wrapped function when
30+
decorated without any argument.
31+
It also allows users to put a call to mark() anywhere inside the wrapped function
32+
for fine-grain time benchmark. This wrapper defines the local_mark() and push it
33+
the the _mark_func_stack for isolation in case of nested call.
34+
The user-facing mark() will always call the function at the top of the _mark_func_stack.
2935
3036
Parameters
3137
----------
3238
func : :obj:`callable`, optional
33-
Function to be decorated.
39+
Function to be decorated. Defaults to ``None``.
40+
description : :obj:`str`, optional
41+
Description for the output text. Defaults to ``''``.
42+
save_file : :obj:`bool`, optional
43+
Flag for saving file to a disk. Otherwise, the result will output to stdout. Defaults to ``False``
44+
file_path : :obj:`str`, optional
45+
File path for saving the output. Defaults to ``benchmark.log``
46+
3447
"""
3548

3649
# Zero-overhead
3750
if not ENABLE_BENCHMARK:
3851
return func
3952

4053
@functools.wraps(func)
41-
def wrapper(*args, **kwargs):
42-
marks = []
43-
44-
# currently this simply record the user-define label and record time
45-
def local_mark(label):
46-
marks.append((label, time.perf_counter()))
47-
48-
global _current_mark_func
49-
_current_mark_func = local_mark
50-
51-
# the mark() called in wrapped function will now call local_mark
52-
result = func(*args, **kwargs)
53-
# clean up to original state
54-
_current_mark_func = None
55-
56-
# TODO (tharitt): maybe changing to saving results to file instead
57-
if marks:
58-
prev_label, prev_t = marks[0]
59-
print(f"[BENCH] {prev_label}: 0.000000s")
60-
for label, t in marks[1:]:
61-
print(f"[BENCH] {label}: {t - prev_t:.6f}s since '{prev_label}'")
62-
prev_label, prev_t = label, t
63-
return result
64-
65-
return wrapper
54+
def decorator(func):
55+
def wrapper(*args, **kwargs):
56+
rank = MPI.COMM_WORLD.Get_rank()
57+
58+
# Here we rely on the closure property of Python.
59+
# This marks will isolate from (shadow) the marks previously
60+
# defined in the function currently on top of the _mark_func_stack.
61+
marks = []
62+
63+
def local_mark(label):
64+
marks.append((label, time.perf_counter()))
65+
_mark_func_stack.append(local_mark)
66+
67+
start_time = time.perf_counter()
68+
# the mark() called in wrapped function will now call local_mark
69+
result = func(*args, **kwargs)
70+
end_time = time.perf_counter()
71+
72+
_mark_func_stack.pop()
73+
74+
output = []
75+
# start-to-end time
76+
elapsed = end_time - start_time
77+
78+
# TODO (tharitt): Both MPI + NCCL collective calls have implicit synchronization inside the stream
79+
# So, output only from rank=0 should suffice. We can add per-rank output later on if makes sense.
80+
if rank == 0:
81+
level = len(_mark_func_stack)
82+
output.append(f"{'---' * level}{description or func.__name__} {elapsed:6f} seconds (rank = {rank})\n")
83+
if marks:
84+
prev_label, prev_t = marks[0]
85+
for label, t in marks[1:]:
86+
output.append(f"{'---' * level}[{prev_label} ---> {label}]: {t - prev_t:.6f}s \n")
87+
prev_label, prev_t = label, t
88+
89+
if save_file:
90+
with open(file_path, "a") as f:
91+
f.write("".join(output))
92+
else:
93+
print("".join(output))
94+
return result
95+
return wrapper
96+
if func is not None:
97+
return decorator(func)
98+
99+
return decorator

0 commit comments

Comments
 (0)