
Commit b8bcd29 (1 parent: a08924b)

feat: added _bcast to DistributedMixIn and added comms as input for all methods

File tree

1 file changed (+35 -19 lines)


pylops_mpi/Distributed.py

Lines changed: 35 additions & 19 deletions
@@ -1,14 +1,14 @@
 from mpi4py import MPI
 from pylops.utils import deps as pylops_deps  # avoid namespace crashes with pylops_mpi.utils
-from pylops_mpi.utils._mpi import mpi_allreduce, mpi_allgather, mpi_send, mpi_recv, _prepare_allgather_inputs, _unroll_allgather_recv
+from pylops_mpi.utils._mpi import mpi_allreduce, mpi_allgather, mpi_bcast, mpi_send, mpi_recv, _prepare_allgather_inputs, _unroll_allgather_recv
 from pylops_mpi.utils import deps

 cupy_message = pylops_deps.cupy_import("the DistributedArray module")
 nccl_message = deps.nccl_import("the DistributedArray module")

 if nccl_message is None and cupy_message is None:
     from pylops_mpi.utils._nccl import (
-        nccl_allgather, nccl_allreduce, nccl_send, nccl_recv
+        nccl_allgather, nccl_allreduce, nccl_bcast, nccl_send, nccl_recv
     )

@@ -22,39 +22,45 @@ class DistributedMixIn:
     MPI installation is available, the latter with CuPy arrays when a CUDA-Aware
     MPI installation is not available).
     """
-    def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
+    def _allreduce(self, base_comm, base_comm_nccl,
+                   send_buf, recv_buf=None, op: MPI.Op = MPI.SUM,
+                   engine="numpy"):
         """Allreduce operation
         """
-        if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
-            return nccl_allreduce(self.base_comm_nccl, send_buf, recv_buf, op)
+        if deps.nccl_enabled and base_comm_nccl is not None:
+            return nccl_allreduce(base_comm_nccl, send_buf, recv_buf, op)
         else:
-            return mpi_allreduce(self.base_comm, send_buf,
-                                 recv_buf, self.engine, op)
+            return mpi_allreduce(base_comm, send_buf,
+                                 recv_buf, engine, op)

-    def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
+    def _allreduce_subcomm(self, sub_comm, base_comm_nccl,
+                           send_buf, recv_buf=None, op: MPI.Op = MPI.SUM,
+                           engine="numpy"):
         """Allreduce operation with subcommunicator
         """
-        if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
-            return nccl_allreduce(self.sub_comm, send_buf, recv_buf, op)
+        if deps.nccl_enabled and base_comm_nccl is not None:
+            return nccl_allreduce(sub_comm, send_buf, recv_buf, op)
         else:
-            return mpi_allreduce(self.sub_comm, send_buf,
-                                 recv_buf, self.engine, op)
+            return mpi_allreduce(sub_comm, send_buf,
+                                 recv_buf, engine, op)

-    def _allgather(self, send_buf, recv_buf=None):
+    def _allgather(self, base_comm, base_comm_nccl,
+                   send_buf, recv_buf=None,
+                   engine="numpy"):
         """Allgather operation
         """
-        if deps.nccl_enabled and self.base_comm_nccl:
+        if deps.nccl_enabled and base_comm_nccl is not None:
             if isinstance(send_buf, (tuple, list, int)):
-                return nccl_allgather(self.base_comm_nccl, send_buf, recv_buf)
+                return nccl_allgather(base_comm_nccl, send_buf, recv_buf)
             else:
-                send_shapes = self.base_comm.allgather(send_buf.shape)
+                send_shapes = base_comm.allgather(send_buf.shape)
                 (padded_send, padded_recv) = _prepare_allgather_inputs(send_buf, send_shapes, engine="cupy")
-                raw_recv = nccl_allgather(self.base_comm_nccl, padded_send, recv_buf if recv_buf else padded_recv)
+                raw_recv = nccl_allgather(base_comm_nccl, padded_send, recv_buf if recv_buf else padded_recv)
                 return _unroll_allgather_recv(raw_recv, padded_send.shape, send_shapes)
         else:
             if isinstance(send_buf, (tuple, list, int)):
-                return self.base_comm.allgather(send_buf)
-            return mpi_allgather(self.base_comm, send_buf, recv_buf, self.engine)
+                return base_comm.allgather(send_buf)
+            return mpi_allgather(base_comm, send_buf, recv_buf, engine)

     def _allgather_subcomm(self, send_buf, recv_buf=None):
         """Allgather operation with subcommunicator
@@ -70,6 +76,16 @@ def _allgather_subcomm(self, send_buf, recv_buf=None):
         else:
             return mpi_allgather(self.sub_comm, send_buf, recv_buf, self.engine)

+    def _bcast(self, local_array, index, value):
+        """BCast operation
+        """
+        if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
+            nccl_bcast(self.base_comm_nccl, local_array, index, value)
+        else:
+            # self.local_array[index] = self.base_comm.bcast(value)
+            mpi_bcast(self.base_comm, self.rank, self.local_array, index, value,
+                      engine=self.engine)
+
     def _send(self, send_buf, dest, count=None, tag=0):
         """Send operation
         """
