Skip to content

Commit d760ba2

Browse files
committed
Handled nested benchmark text output with tree parsing
1 parent 2d5b019 commit d760ba2

File tree

1 file changed

+56
-28
lines changed

1 file changed

+56
-28
lines changed

pylops_mpi/utils/benchmark.py

Lines changed: 56 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,49 @@
77
# TODO (tharitt): later move to env file or something
88
ENABLE_BENCHMARK = True
99

10+
# Stack of active mark functions for nested support
11+
_mark_func_stack = []
12+
_markers = []
13+
14+
15+
def _parse_output_tree(markers):
16+
output = []
17+
stack = []
18+
i = 0
19+
while i < len(markers):
20+
label, time, level = markers[i]
21+
if label.startswith("[header]"):
22+
output.append(f"{"\t" * (level - 1)}{label}: total runtime: {time:6f}\n")
23+
else:
24+
if stack:
25+
prev_label, prev_time, prev_level = stack[-1]
26+
if prev_level == level:
27+
output.append(f"{"\t" * level}{prev_label}-->{label}: {time - prev_time: 6f}\n")
28+
stack.pop()
29+
30+
# Push to the stack only if it is going deeper or the same level
31+
if i + 1 < len(markers) - 1:
32+
_, _ , next_level = markers[i + 1]
33+
if next_level >= level:
34+
stack.append(markers[i])
35+
i += 1
36+
return output
37+
1038

11-
# This function allows users to measure time arbitary lines of the function
1239
def mark(label):
40+
"""This function allows users to measure time arbitary lines of the function
41+
42+
Parameters
43+
----------
44+
label: :obj:`str`
45+
A label of the mark. This signifies both 1) the end of the
46+
previous mark 2) the beginning of the new mark
47+
"""
1348
if not _mark_func_stack:
1449
raise RuntimeError("mark() called outside of a benchmarked region")
1550
_mark_func_stack[-1](label)
1651

1752

18-
# Stack of active mark functions for nested support
19-
_mark_func_stack = []
20-
21-
2253
def benchmark(func: Optional[Callable] = None,
2354
description="",
2455
save_file=False,
@@ -28,9 +59,9 @@ def benchmark(func: Optional[Callable] = None,
2859
2960
This wrapper measure the start-to-end time of the wrapped function when
3061
decorated without any argument.
31-
It also allows users to put a call to mark() anywhere inside the wrapped function
62+
It also allows users to put a call to mark() anywhere inside the wrapped function
3263
for fine-grain time benchmark. This wrapper defines the local_mark() and push it
33-
the the _mark_func_stack for isolation in case of nested call.
64+
the the _mark_func_stack for isolation in case of nested call.
3465
The user-facing mark() will always call the function at the top of the _mark_func_stack.
3566
3667
Parameters
@@ -58,38 +89,35 @@ def wrapper(*args, **kwargs):
5889
# Here we rely on the closure property of Python.
5990
# This marks will isolate from (shadow) the marks previously
6091
# defined in the function currently on top of the _mark_func_stack.
61-
marks = []
92+
93+
level = len(_mark_func_stack) + 1
94+
# The header is needed for later tree parsing. Here it is allocating its spot.
95+
# the tuple at this index will be replaced after elapsed time is calculated.
96+
_markers.append((f"[header]{description or func.__name__}", None, level))
97+
header_index = len(_markers) - 1
6298

6399
def local_mark(label):
64-
marks.append((label, time.perf_counter()))
100+
_markers.append((label, time.perf_counter(), level))
101+
65102
_mark_func_stack.append(local_mark)
66103

67104
start_time = time.perf_counter()
68105
# the mark() called in wrapped function will now call local_mark
69106
result = func(*args, **kwargs)
70107
end_time = time.perf_counter()
71108

72-
_mark_func_stack.pop()
73-
74-
output = []
75-
# start-to-end time
76109
elapsed = end_time - start_time
110+
_markers[header_index] = (f"[header]{description or func.__name__}", elapsed, level)
111+
112+
# In case of nesting, the wrapped callee must pop its closure from stack so that
113+
# when the callee returns, the wrapped caller operates on its closure (and its level label), which now becomes
114+
# the top of the stack.
115+
_mark_func_stack.pop()
77116

78-
# TODO (tharitt): Both MPI + NCCL collective calls have implicit synchronization inside the stream
79-
# So, output only from rank=0 should suffice. We can add per-rank output later on if makes sense.
80-
if rank == 0:
81-
level = len(_mark_func_stack)
82-
output.append(f"{'---' * level}{description or func.__name__} {elapsed:6f} seconds (rank = {rank})\n")
83-
if marks:
84-
prev_label, prev_t = marks[0]
85-
for label, t in marks[1:]:
86-
output.append(f"{'---' * level}[{prev_label} ---> {label}]: {t - prev_t:.6f}s \n")
87-
prev_label, prev_t = label, t
88-
89-
if save_file:
90-
with open(file_path, "a") as f:
91-
f.write("".join(output))
92-
else:
117+
# finish all the calls
118+
if not _mark_func_stack:
119+
if rank == 0:
120+
output = _parse_output_tree(_markers)
93121
print("".join(output))
94122
return result
95123
return wrapper

0 commit comments

Comments
 (0)