Skip to content

Commit a04d1ff

Browse files
committed
benchmark with logging, improve text output formatting, fix _check_local_shapes
1 parent d760ba2 commit a04d1ff

File tree

3 files changed

+21
-118
lines changed

3 files changed

+21
-118
lines changed

pylops_mpi/DistributedArray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,7 @@ def _check_local_shapes(self, local_shapes):
452452
elif self.partition is Partition.SCATTER:
453453
local_shape = local_shapes[self.rank]
454454
# Check that the local shapes sum up to the global shape along the axis and that the other dimensions align with the global shape
455-
if self._allreduce(local_shape[self.axis]) != self.global_shape[self.axis] or \
455+
if self.base_comm.allreduce(local_shape[self.axis]) != self.global_shape[self.axis] or \
456456
not np.array_equal(np.delete(local_shape, self.axis), np.delete(self.global_shape, self.axis)):
457457
raise ValueError(f"Local shapes don't align with the global shape;"
458458
f"{local_shapes} != {self.global_shape}")

pylops_mpi/benchmark/kirchhoff_bench.py

Lines changed: 0 additions & 107 deletions
This file was deleted.

pylops_mpi/utils/benchmark.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import functools
2+
import logging
3+
import sys
24
import time
35
from typing import Callable, Optional
46
from mpi4py import MPI
@@ -7,27 +9,28 @@
79
# TODO (tharitt): later move to env file or something
810
ENABLE_BENCHMARK = True
911

12+
logging.basicConfig(level=logging.INFO, force=True)
1013
# Stack of active mark functions for nested support
1114
_mark_func_stack = []
1215
_markers = []
1316

1417

15-
def _parse_output_tree(markers):
18+
def _parse_output_tree(markers: list[str]):
1619
output = []
1720
stack = []
1821
i = 0
1922
while i < len(markers):
2023
label, time, level = markers[i]
2124
if label.startswith("[header]"):
22-
output.append(f"{"\t" * (level - 1)}{label}: total runtime: {time:6f}\n")
25+
output.append(f"{"\t" * (level - 1)}{label}: total runtime: {time:6f} s\n")
2326
else:
2427
if stack:
2528
prev_label, prev_time, prev_level = stack[-1]
2629
if prev_level == level:
27-
output.append(f"{"\t" * level}{prev_label}-->{label}: {time - prev_time: 6f}\n")
30+
output.append(f"{"\t" * level}{prev_label}-->{label}: {time - prev_time:6f} s\n")
2831
stack.pop()
2932

30-
# Push to the stack only if it is going deeper or the same level
33+
# Push to the stack only if it is going deeper or still at the same level
3134
if i + 1 < len(markers) - 1:
3235
_, _ , next_level = markers[i + 1]
3336
if next_level >= level:
@@ -86,10 +89,6 @@ def decorator(func):
8689
def wrapper(*args, **kwargs):
8790
rank = MPI.COMM_WORLD.Get_rank()
8891

89-
# Here we rely on the closure property of Python.
90-
# This marks will isolate from (shadow) the marks previously
91-
# defined in the function currently on top of the _mark_func_stack.
92-
9392
level = len(_mark_func_stack) + 1
9493
# The header is needed for later tree parsing. Here it is allocating its spot.
9594
# The tuple at this index will be replaced after the elapsed time is calculated.
@@ -114,11 +113,22 @@ def local_mark(label):
114113
# the top of the stack.
115114
_mark_func_stack.pop()
116115

117-
# finish all the calls
116+
# all the calls have finished
118117
if not _mark_func_stack:
119118
if rank == 0:
120119
output = _parse_output_tree(_markers)
121-
print("".join(output))
120+
logger = logging.getLogger()
121+
# remove all handlers currently attached to the root logger
122+
for h in logger.handlers[:]:
123+
logger.removeHandler(h)
124+
handler = logging.FileHandler(file_path, mode='w') if save_file else logging.StreamHandler(sys.stdout)
125+
handler.setLevel(logging.INFO)
126+
logger.addHandler(handler)
127+
logger.info("".join(output))
128+
logger.removeHandler(handler)
129+
if save_file:
130+
handler.close()
131+
122132
return result
123133
return wrapper
124134
if func is not None:

0 commit comments

Comments
 (0)