 from mpi4py import MPI
 from pylops.utils import deps as pylops_deps  # avoid namespace crashes with pylops_mpi.utils
-from pylops_mpi.utils._mpi import mpi_allreduce, mpi_allgather, mpi_send, mpi_recv, _prepare_allgather_inputs, _unroll_allgather_recv
+from pylops_mpi.utils._mpi import mpi_allreduce, mpi_allgather, mpi_bcast, mpi_send, mpi_recv, _prepare_allgather_inputs, _unroll_allgather_recv
 from pylops_mpi.utils import deps

 cupy_message = pylops_deps.cupy_import("the DistributedArray module")
 nccl_message = deps.nccl_import("the DistributedArray module")

 if nccl_message is None and cupy_message is None:
     from pylops_mpi.utils._nccl import (
-        nccl_allgather, nccl_allreduce, nccl_send, nccl_recv
+        nccl_allgather, nccl_allreduce, nccl_bcast, nccl_send, nccl_recv
     )


@@ -22,39 +22,45 @@ class DistributedMixIn:
     MPI installation is available, the latter with CuPy arrays when a CUDA-Aware
     MPI installation is not available).
     """
-    def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
+    def _allreduce(self, base_comm, base_comm_nccl,
+                   send_buf, recv_buf=None, op: MPI.Op = MPI.SUM,
+                   engine="numpy"):
         """Allreduce operation
         """
-        if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
-            return nccl_allreduce(self.base_comm_nccl, send_buf, recv_buf, op)
+        if deps.nccl_enabled and base_comm_nccl is not None:
+            return nccl_allreduce(base_comm_nccl, send_buf, recv_buf, op)
         else:
-            return mpi_allreduce(self.base_comm, send_buf,
-                                 recv_buf, self.engine, op)
+            return mpi_allreduce(base_comm, send_buf,
+                                 recv_buf, engine, op)

-    def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
+    def _allreduce_subcomm(self, sub_comm, base_comm_nccl,
+                           send_buf, recv_buf=None, op: MPI.Op = MPI.SUM,
+                           engine="numpy"):
         """Allreduce operation with subcommunicator
         """
-        if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
-            return nccl_allreduce(self.sub_comm, send_buf, recv_buf, op)
+        if deps.nccl_enabled and base_comm_nccl is not None:
+            return nccl_allreduce(sub_comm, send_buf, recv_buf, op)
         else:
-            return mpi_allreduce(self.sub_comm, send_buf,
-                                 recv_buf, self.engine, op)
+            return mpi_allreduce(sub_comm, send_buf,
+                                 recv_buf, engine, op)

-    def _allgather(self, send_buf, recv_buf=None):
+    def _allgather(self, base_comm, base_comm_nccl,
+                   send_buf, recv_buf=None,
+                   engine="numpy"):
         """Allgather operation
         """
-        if deps.nccl_enabled and self.base_comm_nccl:
+        if deps.nccl_enabled and base_comm_nccl is not None:
             if isinstance(send_buf, (tuple, list, int)):
-                return nccl_allgather(self.base_comm_nccl, send_buf, recv_buf)
+                return nccl_allgather(base_comm_nccl, send_buf, recv_buf)
             else:
-                send_shapes = self.base_comm.allgather(send_buf.shape)
+                send_shapes = base_comm.allgather(send_buf.shape)
                 (padded_send, padded_recv) = _prepare_allgather_inputs(send_buf, send_shapes, engine="cupy")
-                raw_recv = nccl_allgather(self.base_comm_nccl, padded_send, recv_buf if recv_buf else padded_recv)
+                raw_recv = nccl_allgather(base_comm_nccl, padded_send, recv_buf if recv_buf else padded_recv)
                 return _unroll_allgather_recv(raw_recv, padded_send.shape, send_shapes)
         else:
             if isinstance(send_buf, (tuple, list, int)):
-                return self.base_comm.allgather(send_buf)
-            return mpi_allgather(self.base_comm, send_buf, recv_buf, self.engine)
+                return base_comm.allgather(send_buf)
+            return mpi_allgather(base_comm, send_buf, recv_buf, engine)

     def _allgather_subcomm(self, send_buf, recv_buf=None):
         """Allgather operation with subcommunicator
@@ -70,6 +76,16 @@ def _allgather_subcomm(self, send_buf, recv_buf=None):
         else:
             return mpi_allgather(self.sub_comm, send_buf, recv_buf, self.engine)

+    def _bcast(self, local_array, index, value):
+        """BCast operation
+        """
+        if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
+            nccl_bcast(self.base_comm_nccl, local_array, index, value)
+        else:
+            # self.local_array[index] = self.base_comm.bcast(value)
+            mpi_bcast(self.base_comm, self.rank, self.local_array, index, value,
+                      engine=self.engine)
+
     def _send(self, send_buf, dest, count=None, tag=0):
         """Send operation
         """