File tree Expand file tree Collapse file tree 2 files changed +20
-17
lines changed Expand file tree Collapse file tree 2 files changed +20
-17
lines changed Original file line number Diff line number Diff line change 11__all__ = [
22 "_prepare_nccl_allgather_inputs" ,
33 "_unroll_nccl_allgather_recv" ,
4+ "_nccl_sync" ,
45 "initialize_nccl_comm" ,
56 "nccl_split" ,
67 "nccl_allgather" ,
1920import cupy as cp
2021import cupy .cuda .nccl as nccl
2122
22-
2323cupy_to_nccl_dtype = {
2424 "float32" : nccl .NCCL_FLOAT32 ,
2525 "float64" : nccl .NCCL_FLOAT64 ,
@@ -63,6 +63,13 @@ def _nccl_buf_size(buf, count=None):
6363 return count if count else buf .size
6464
6565
66+ def _nccl_sync ():
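67+ """Thin wrapper around CuPy's device synchronization, kept in this module so callers can import it behind a guard; returns immediately when no CUDA device is reported"""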
68+ if cp .cuda .runtime .getDeviceCount () == 0 :
69+ return
70+ cp .cuda .runtime .deviceSynchronize ()
71+
72+
6673def _prepare_nccl_allgather_inputs (send_buf , send_buf_shapes ) -> Tuple [cp .ndarray , cp .ndarray ]:
6774 r""" Prepare send_buf and recv_buf for NCCL allgather (nccl_allgather)
6875
Original file line number Diff line number Diff line change 33import time
44from typing import Callable , Optional , List
55from mpi4py import MPI
6- from pylops .utils import deps
7-
8- cupy_message = deps .cupy_import ("benchmark module" )
9- if cupy_message is None :
10- import cupy as cp
11- if cp .cuda .runtime .getDeviceCount () == 0 :
12- has_cupy = False
13- print (UserWarning ("CuPy is installed, but no CUDA-capable device is available." ))
14- else :
15- has_cupy = True
16- else :
17- has_cupy = False
186
7+ from pylops .utils import deps as pylops_deps # avoid namespace clashes with pylops_mpi.utils
8+ from pylops_mpi .utils import deps
9+
10+ cupy_message = pylops_deps .cupy_import ("the benchmark module" )
11+ nccl_message = deps .nccl_import ("the benchmark module" )
12+
13+ if nccl_message is None and cupy_message is None :
14+ from pylops_mpi .utils ._nccl import _nccl_sync
15+ else :
16+ def _nccl_sync ():
17+ pass
1918
2019# TODO (tharitt): later move to env file or something
2120ENABLE_BENCHMARK = True
@@ -65,10 +64,7 @@ def _parse_output_tree(markers: List[str]):
6564
6665def _sync ():
6766 """Synchronize the CUDA device (when one is usable) and then all MPI processes"""
68- if has_cupy :
69- # this is ok to call even if CUDA runtime is not initialized
70- cp .cuda .runtime .deviceSynchronize ()
71-
67+ _nccl_sync ()
7268 MPI .COMM_WORLD .Barrier ()
7369
7470
You can’t perform that action at this time.
0 commit comments