Skip to content

Commit 21eec69

Browse files
committed
fix synchronization in GPU-CPU Host
1 parent 8f12b78 commit 21eec69

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

pylops_mpi/utils/benchmark.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,14 @@
33
import time
44
from typing import Callable, Optional, List
55
from mpi4py import MPI
6+
from pylops.utils import deps
7+
8+
cupy_message = deps.cupy_import("benchmark module")
9+
if cupy_message is None:
10+
import cupy as cp
11+
has_cupy = True
12+
else:
13+
has_cupy = False
614

715

816
# TODO (tharitt): later move to env file or something
@@ -48,6 +56,16 @@ def _parse_output_tree(markers: List[str]):
4856
return output
4957

5058

59+
def _sync():
    """Synchronize the CUDA device (when CuPy is available) and all MPI processes.

    If the module-level ``has_cupy`` flag is set, the current CUDA stream is
    synchronized first, followed by a device-wide synchronization. A barrier
    on ``MPI.COMM_WORLD`` is then issued regardless of whether CuPy is present.
    """
    if has_cupy:
        current_stream = cp.cuda.get_current_stream()
        current_stream.synchronize()
        # this is ok to call even if CUDA runtime is not initialized
        cp.cuda.runtime.deviceSynchronize()

    MPI.COMM_WORLD.Barrier()
67+
68+
5169
def mark(label: str):
5270
"""This function allows users to measure the time of arbitrary lines of the function
5371
@@ -108,9 +126,11 @@ def local_mark(label):
108126

109127
_mark_func_stack.append(local_mark)
110128

129+
_sync()
111130
start_time = time.perf_counter()
112131
# the mark() called in wrapped function will now call local_mark
113132
result = func(*args, **kwargs)
133+
_sync()
114134
end_time = time.perf_counter()
115135

116136
elapsed = end_time - start_time

0 commit comments

Comments
 (0)