Skip to content

Commit 5e4e7aa

Browse files
committed
Fix Sphinx complaints by moving the CuPy sync to _nccl for the protected import mechanism
1 parent c0991fd commit 5e4e7aa

File tree

2 files changed

+20
-17
lines changed

2 files changed

+20
-17
lines changed

pylops_mpi/utils/_nccl.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
__all__ = [
22
"_prepare_nccl_allgather_inputs",
33
"_unroll_nccl_allgather_recv",
4+
"_nccl_sync",
45
"initialize_nccl_comm",
56
"nccl_split",
67
"nccl_allgather",
@@ -19,7 +20,6 @@
1920
import cupy as cp
2021
import cupy.cuda.nccl as nccl
2122

22-
2323
cupy_to_nccl_dtype = {
2424
"float32": nccl.NCCL_FLOAT32,
2525
"float64": nccl.NCCL_FLOAT64,
@@ -63,6 +63,13 @@ def _nccl_buf_size(buf, count=None):
6363
return count if count else buf.size
6464

6565

66+
def _nccl_sync():
67+
"""A thin wrapper of CuPy's synchronization for protected import"""
68+
if cp.cuda.runtime.getDeviceCount() == 0:
69+
return
70+
cp.cuda.runtime.deviceSynchronize()
71+
72+
6673
def _prepare_nccl_allgather_inputs(send_buf, send_buf_shapes) -> Tuple[cp.ndarray, cp.ndarray]:
6774
r""" Prepare send_buf and recv_buf for NCCL allgather (nccl_allgather)
6875

pylops_mpi/utils/benchmark.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,18 @@
33
import time
44
from typing import Callable, Optional, List
55
from mpi4py import MPI
6-
from pylops.utils import deps
7-
8-
cupy_message = deps.cupy_import("benchmark module")
9-
if cupy_message is None:
10-
import cupy as cp
11-
if cp.cuda.runtime.getDeviceCount() == 0:
12-
has_cupy = False
13-
print(UserWarning("CuPy is installed, but no CUDA-capable device is available."))
14-
else:
15-
has_cupy = True
16-
else:
17-
has_cupy = False
186

7+
from pylops.utils import deps as pylops_deps # avoid namespace clashes with pylops_mpi.utils
8+
from pylops_mpi.utils import deps
9+
10+
cupy_message = pylops_deps.cupy_import("the benchmark module")
11+
nccl_message = deps.nccl_import("the benchmark module")
12+
13+
if nccl_message is None and cupy_message is None:
14+
from pylops_mpi.utils._nccl import _nccl_sync
15+
else:
16+
def _nccl_sync():
17+
pass
1918

2019
# TODO (tharitt): later move to env file or something
2120
ENABLE_BENCHMARK = True
@@ -65,10 +64,7 @@ def _parse_output_tree(markers: List[str]):
6564

6665
def _sync():
6766
"""Synchronize all MPI processes or CUDA Devices"""
68-
if has_cupy:
69-
# this is ok to call even if CUDA runtime is not initialized
70-
cp.cuda.runtime.deviceSynchronize()
71-
67+
_nccl_sync()
7268
MPI.COMM_WORLD.Barrier()
7369

7470

0 commit comments

Comments
 (0)