Commit 2c67755
Parent: a80f00e

feat: finalized passing parameters to all methods in Distributed

File tree: 2 files changed, +142 -33 lines


pylops_mpi/Distributed.py

Lines changed: 126 additions & 28 deletions
@@ -1,4 +1,7 @@
+from typing import Any, NewType, Optional, Union
+
 from mpi4py import MPI
+from pylops.utils import NDArray
 from pylops.utils import deps as pylops_deps  # avoid namespace crashes with pylops_mpi.utils
 from pylops_mpi.utils._mpi import mpi_allreduce, mpi_allgather, mpi_bcast, mpi_send, mpi_recv, _prepare_allgather_inputs, _unroll_allgather_recv
 from pylops_mpi.utils import deps
@@ -10,6 +13,11 @@
     from pylops_mpi.utils._nccl import (
         nccl_allgather, nccl_allreduce, nccl_bcast, nccl_send, nccl_recv
     )
+    from cupy.cuda.nccl import NcclCommunicator
+else:
+    NcclCommunicator = Any
+
+NcclCommunicatorType = NewType("NcclCommunicator", NcclCommunicator)


 class DistributedMixIn:
@@ -23,10 +31,14 @@ class DistributedMixIn:
     MPI installation is not available).

     """
-    def _allreduce(self, base_comm, base_comm_nccl,
-                   send_buf, recv_buf=None,
+    def _allreduce(self,
+                   base_comm: MPI.Comm,
+                   base_comm_nccl: NcclCommunicatorType,
+                   send_buf: NDArray,
+                   recv_buf: Optional[NDArray] = None,
                    op: MPI.Op = MPI.SUM,
-                   engine="numpy"):
+                   engine: str = "numpy",
+                   ) -> NDArray:
         """Allreduce operation

         Parameters
@@ -58,10 +70,14 @@ def _allreduce(self, base_comm, base_comm_nccl,
             return mpi_allreduce(base_comm, send_buf,
                                  recv_buf, engine, op)

-    def _allreduce_subcomm(self, sub_comm, base_comm_nccl,
-                           send_buf, recv_buf=None,
+    def _allreduce_subcomm(self,
+                           sub_comm: MPI.Comm,
+                           base_comm_nccl: NcclCommunicatorType,
+                           send_buf: NDArray,
+                           recv_buf: Optional[NDArray] = None,
                            op: MPI.Op = MPI.SUM,
-                           engine="numpy"):
+                           engine: str = "numpy",
+                           ) -> NDArray:
         """Allreduce operation with subcommunicator

         Parameters
@@ -93,15 +109,19 @@ def _allreduce_subcomm(self, sub_comm, base_comm_nccl,
             return mpi_allreduce(sub_comm, send_buf,
                                  recv_buf, engine, op)

-    def _allgather(self, base_comm, base_comm_nccl,
-                   send_buf, recv_buf=None,
-                   engine="numpy"):
+    def _allgather(self,
+                   base_comm: MPI.Comm,
+                   base_comm_nccl: NcclCommunicatorType,
+                   send_buf: NDArray,
+                   recv_buf: Optional[NDArray] = None,
+                   engine: str = "numpy",
+                   ) -> NDArray:
         """Allgather operation

         Parameters
         ----------
-        sub_comm : :obj:`MPI.Comm`
-            MPI Subcommunicator.
+        base_comm : :obj:`MPI.Comm`
+            Base MPI Communicator.
         base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`
             NCCL Communicator.
         send_buf: :obj: `numpy.ndarray` or `cupy.ndarray`
@@ -131,41 +151,119 @@ def _allgather(self, base_comm, base_comm_nccl,
                 return base_comm.allgather(send_buf)
             return mpi_allgather(base_comm, send_buf, recv_buf, engine)

-    def _allgather_subcomm(self, send_buf, recv_buf=None):
+    def _allgather_subcomm(self,
+                           sub_comm: MPI.Comm,
+                           base_comm_nccl: NcclCommunicatorType,
+                           send_buf: NDArray,
+                           recv_buf: Optional[NDArray] = None,
+                           engine: str = "numpy",
+                           ) -> NDArray:
         """Allgather operation with subcommunicator
+
+        Parameters
+        ----------
+        sub_comm : :obj:`MPI.Comm`
+            MPI Subcommunicator.
+        base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`
+            NCCL Communicator.
+        send_buf: :obj: `numpy.ndarray` or `cupy.ndarray`
+            A buffer containing the data to be sent by this rank.
+        recv_buf : :obj: `numpy.ndarray` or `cupy.ndarray`, optional
+            The buffer to store the result of the gathering. If None,
+            a new buffer will be allocated with the appropriate shape.
+        engine : :obj:`str`, optional
+            Engine used to store array (``numpy`` or ``cupy``)
+
+        Returns
+        -------
+        recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
+            A buffer containing the gathered data from all ranks.
+
         """
-        if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
+        if deps.nccl_enabled and base_comm_nccl is not None:
             if isinstance(send_buf, (tuple, list, int)):
-                return nccl_allgather(self.sub_comm, send_buf, recv_buf)
+                return nccl_allgather(sub_comm, send_buf, recv_buf)
             else:
-                send_shapes = self._allgather_subcomm(send_buf.shape)
+                send_shapes = sub_comm._allgather_subcomm(send_buf.shape)
                 (padded_send, padded_recv) = _prepare_allgather_inputs(send_buf, send_shapes, engine="cupy")
-                raw_recv = nccl_allgather(self.sub_comm, padded_send, recv_buf if recv_buf else padded_recv)
+                raw_recv = nccl_allgather(sub_comm, padded_send, recv_buf if recv_buf else padded_recv)
                 return _unroll_allgather_recv(raw_recv, padded_send.shape, send_shapes)
         else:
-            return mpi_allgather(self.sub_comm, send_buf, recv_buf, self.engine)
+            return mpi_allgather(sub_comm, send_buf, recv_buf, engine)

-    def _bcast(self, local_array, index, value):
+    def _bcast(self,
+               base_comm: MPI.Comm,
+               base_comm_nccl: NcclCommunicatorType,
+               rank : int,
+               local_array: NDArray,
+               index: int,
+               value: Union[int, NDArray],
+               engine: str = "numpy",
+               ) -> None:
         """BCast operation
+
+        Parameters
+        ----------
+        base_comm : :obj:`MPI.Comm`
+            Base MPI Communicator.
+        base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`
+            NCCL Communicator.
+        rank : :obj:`int`
+            Rank.
+        local_array : :obj:`numpy.ndarray`
+            Local array to be broadcasted.
+        index : :obj:`int` or :obj:`slice`
+            Represents the index positions where a value needs to be assigned.
+        value : :obj:`int` or :obj:`numpy.ndarray`
+            Represents the value that will be assigned to the local array at
+            the specified index positions.
+        engine : :obj:`str`, optional
+            Engine used to store array (``numpy`` or ``cupy``)
+
         """
-        if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
-            nccl_bcast(self.base_comm_nccl, local_array, index, value)
+        if deps.nccl_enabled and base_comm_nccl is not None:
+            nccl_bcast(base_comm_nccl, local_array, index, value)
         else:
-            # self.local_array[index] = self.base_comm.bcast(value)
-            mpi_bcast(self.base_comm, self.rank, self.local_array, index, value,
-                      engine=self.engine)
+            mpi_bcast(base_comm, rank, local_array, index, value,
+                      engine=engine)

-    def _send(self, send_buf, dest, count=None, tag=0):
+    def _send(self,
+              base_comm: MPI.Comm,
+              base_comm_nccl: NcclCommunicatorType,
+              send_buf: NDArray,
+              dest: int,
+              count: Optional[int] = None,
+              tag: int = 0,
+              engine: str = "numpy",
+              ) -> None:
         """Send operation
+
+        Parameters
+        ----------
+        base_comm : :obj:`MPI.Comm`
+            Base MPI Communicator.
+        base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`
+            NCCL Communicator.
+        send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
+            The array containing data to send.
+        dest: :obj:`int`
+            The rank of the destination.
+        count : :obj:`int`
+            Number of elements to send from `send_buf`.
+        tag : :obj:`int`
+            Tag of the message to be sent.
+        engine : :obj:`str`, optional
+            Engine used to store array (``numpy`` or ``cupy``)
+
         """
-        if deps.nccl_enabled and self.base_comm_nccl:
+        if deps.nccl_enabled and base_comm_nccl is not None:
             if count is None:
                 count = send_buf.size
-            nccl_send(self.base_comm_nccl, send_buf, dest, count)
+            nccl_send(base_comm_nccl, send_buf, dest, count)
         else:
-            mpi_send(self.base_comm,
+            mpi_send(base_comm,
                      send_buf, dest, count, tag=tag,
-                     engine=self.engine)
+                     engine=engine)

     def _recv(self, recv_buf=None, source=0, count=None, tag=0):
         """Receive operation

pylops_mpi/DistributedArray.py

Lines changed: 16 additions & 5 deletions
@@ -204,7 +204,9 @@ def __setitem__(self, index, value):
             the specified index positions.
         """
         if self.partition is Partition.BROADCAST:
-            self._bcast(self.local_array, index, value)
+            self._bcast(self.base_comm, self.base_comm_nccl,
+                        self.rank, self.local_array,
+                        index, value, engine=self.engine)
         else:
             self.local_array[index] = value

@@ -380,7 +382,10 @@ def asarray(self, masked: bool = False):
         else:
             # Gather all the local arrays and apply concatenation.
             if masked:
-                final_array = self._allgather_subcomm(self.local_array)
+                final_array = self._allgather_subcomm(self.sub_comm,
+                                                      self.base_comm_nccl,
+                                                      self.local_array,
+                                                      engine=self.engine)
             else:
                 final_array = self._allgather(self.base_comm,
                                               self.base_comm_nccl,
@@ -481,7 +486,9 @@ def _nccl_local_shapes(self, masked: bool):
         """
         # gather tuple of shapes from every rank within the communicator and copy from GPU to CPU
         if masked:
-            all_tuples = self._allgather_subcomm(self.local_shape).get()
+            all_tuples = self._allgather_subcomm(self.sub_comm,
+                                                 self.base_comm_nccl,
+                                                 self.local_shape).get()
         else:
             all_tuples = self._allgather(self.base_comm,
                                          self.base_comm_nccl,
@@ -799,7 +806,9 @@ def add_ghost_cells(self, cells_front: Optional[int] = None,
                                  f"{self.local_shape[self.axis]} < {cells_front}; "
                                  f"to achieve this use NUM_PROCESSES <= "
                                  f"{max(1, self.global_shape[self.axis] // cells_front)}")
-            self._send(send_buf, dest=self.rank + 1, tag=1)
+            self._send(self.base_comm, self.base_comm_nccl,
+                       send_buf, dest=self.rank + 1, tag=1,
+                       engine=self.engine)
         if cells_back is not None:
             total_cells_back = self.base_comm.allgather(cells_back) + [0]
             # Read cells_back which needs to be sent to rank - 1(cells_back for rank - 1)
@@ -814,7 +823,9 @@ def add_ghost_cells(self, cells_front: Optional[int] = None,
                                  f"{self.local_shape[self.axis]} < {cells_back}; "
                                  f"to achieve this use NUM_PROCESSES <= "
                                  f"{max(1, self.global_shape[self.axis] // cells_back)}")
-            self._send(send_buf, dest=self.rank - 1, tag=0)
+            self._send(self.base_comm, self.base_comm_nccl,
+                       send_buf, dest=self.rank - 1, tag=0,
+                       engine=self.engine)
         if self.rank != self.size - 1:
             recv_shape = list(recv_shapes[self.rank + 1])
             recv_shape[self.axis] = total_cells_back[self.rank]
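
The ghost-cell exchange above is the main consumer of the updated _send signature. A hedged usage sketch (also not part of this commit), assuming a standard mpiexec launch, exercises that path without calling any private method directly:

import numpy as np
from mpi4py import MPI
from pylops_mpi import DistributedArray

comm = MPI.COMM_WORLD
n_per_rank = 5
x = DistributedArray(global_shape=comm.Get_size() * n_per_rank, axis=0)
x[:] = comm.Get_rank() * np.ones(n_per_rank)

# Pad each local chunk with one cell from each neighbouring rank (where one
# exists); internally this now calls
# self._send(self.base_comm, self.base_comm_nccl, ..., engine=self.engine)
x_ghosted = x.add_ghost_cells(cells_front=1, cells_back=1)
print(comm.Get_rank(), x_ghosted)

Run with, e.g., mpiexec -n 4 python ghost_cells_example.py (hypothetical file name).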
