Commit 58f1305

point-to-point (send/recv) using NCCL. Tested with BlockDiag & FirstDerivative
1 parent 3848408 commit 58f1305
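
For orientation, here is a minimal end-to-end sketch of the path this commit enables: a DistributedArray carrying both an MPI and an NCCL communicator, applied to a first-derivative operator so that the ghost-cell exchange runs over NCCL send/recv. The bootstrap helper make_nccl_comm is hypothetical (it stands in for whatever NCCL initialization the application already uses), and MPIFirstDerivative with a dims argument is assumed to be the public export backed by pylops_mpi/basicoperators/FirstDerivative.py.

# Sketch only (not part of this commit): GPU-resident DistributedArray whose
# halo exchange goes through the new NCCL point-to-point path.
import cupy as cp
from mpi4py import MPI

import pylops_mpi
from pylops_mpi import DistributedArray

comm = MPI.COMM_WORLD
nccl_comm = make_nccl_comm(comm)      # hypothetical NCCL bootstrap helper

x = DistributedArray(global_shape=1000,
                     base_comm=comm,
                     base_comm_nccl=nccl_comm,   # new NCCL communicator handle
                     engine="cupy",              # local chunks live on the GPU
                     dtype="float64")
x[:] = cp.arange(x.local_shape[0], dtype="float64")

Dop = pylops_mpi.MPIFirstDerivative(dims=(1000,), dtype="float64")
y = Dop.matvec(x)    # add_ghost_cells() now calls _send/_recv, hence NCCL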

File tree: 6 files changed (+248, -38 lines)


pylops_mpi/DistributedArray.py

Lines changed: 59 additions & 25 deletions
@@ -14,7 +14,7 @@
 nccl_message = deps.nccl_import("the DistributedArray module")

 if nccl_message is None and cupy_message is None:
-    from pylops_mpi.utils._nccl import nccl_allgather, nccl_allreduce, nccl_asarray, nccl_bcast, nccl_split
+    from pylops_mpi.utils._nccl import nccl_allgather, nccl_allreduce, nccl_asarray, nccl_bcast, nccl_split, nccl_send, nccl_recv
     from cupy.cuda.nccl import NcclCommunicator
 else:
     NcclCommunicator = Any
@@ -503,6 +503,35 @@ def _allgather(self, send_buf, recv_buf=None):
         self.base_comm.Allgather(send_buf, recv_buf)
         return recv_buf

+    def _send(self, send_buf, dest, count=None, tag=None):
+        """ Send operation
+        """
+        if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
+            if count is None:
+                # assuming the whole array is sent
+                count = send_buf.size
+            nccl_send(self.base_comm_nccl, send_buf, dest, count)
+        else:
+            self.base_comm.Send(send_buf, dest, tag)
+
+    def _recv(self, recv_buf=None, source=0, count=None, tag=None):
+        """ Receive operation
+        """
+        # NCCL must be called with recv_buf. The size cannot be inferred from
+        # the other arguments, so the buffer cannot be allocated dynamically
+        if deps.nccl_enabled and getattr(self, "base_comm_nccl") and recv_buf is not None:
+            if count is None:
+                # assuming the data fills the whole buffer
+                count = recv_buf.size
+            nccl_recv(self.base_comm_nccl, recv_buf, source, count)
+            return recv_buf
+        else:
+            # MPI allows the receive buffer to be optional
+            if recv_buf is None:
+                return self.base_comm.recv(source=source, tag=tag)
+            self.base_comm.Recv(buf=recv_buf, source=source, tag=tag)
+            return recv_buf
+
     def __neg__(self):
         arr = DistributedArray(global_shape=self.global_shape,
                                base_comm=self.base_comm,
@@ -747,50 +776,55 @@ def add_ghost_cells(self, cells_front: Optional[int] = None,

         """
         ghosted_array = self.local_array.copy()
+        ncp = get_module(getattr(self, "engine"))
         if cells_front is not None:
-            # TODO: these are metadata (small size). Under current API, it will
-            # call nccl allgather, should we force it to always use MPI?
-            cells_fronts = self._allgather(cells_front)
-            if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
-                total_cells_front = cells_fronts.tolist() + [0]
-            else:
-                total_cells_front = cells_fronts + [0]
+            # cells_front values are small metadata, so explicitly use MPI here
+            total_cells_front = self.base_comm.allgather(cells_front) + [0]
             # Read cells_front which needs to be sent to rank + 1 (cells_front for rank + 1)
             cells_front = total_cells_front[self.rank + 1]
+            send_buf = ncp.take(self.local_array, ncp.arange(-cells_front, 0), axis=self.axis)
+            recv_shapes = self.local_shapes
             if self.rank != 0:
-                ghosted_array = np.concatenate([self.base_comm.recv(source=self.rank - 1, tag=1), ghosted_array],
-                                               axis=self.axis)
-            if self.rank != self.size - 1:
+                # From the receiver's perspective (rank), the recv buffer has the same shape as the
+                # sender's array (rank - 1) in every dimension except along axis=self.axis
+                recv_shape = list(recv_shapes[self.rank - 1])
+                recv_shape[self.axis] = total_cells_front[self.rank]
+                recv_buf = ncp.zeros(recv_shape, dtype=ghosted_array.dtype)
+                # Communication can be skipped when len(recv_buf) == 0; additionally, NCCL
+                # hangs on zero-sized buffers, so this optimization is effectively mandatory
+                if len(recv_buf) != 0:
+                    ghosted_array = ncp.concatenate([self._recv(recv_buf, source=self.rank - 1, tag=1), ghosted_array], axis=self.axis)
+            # The matching skip on the sender side mirrors the receiver logic above
+            if self.rank != self.size - 1 and len(send_buf) != 0:
                 if cells_front > self.local_shape[self.axis]:
                     raise ValueError(f"Local Shape at rank={self.rank} along axis={self.axis} "
                                      f"should be > {cells_front}: dim({self.axis}) "
                                      f"{self.local_shape[self.axis]} < {cells_front}; "
                                      f"to achieve this use NUM_PROCESSES <= "
                                      f"{max(1, self.global_shape[self.axis] // cells_front)}")
-                # TODO: this array maybe large. Currently it will always use MPI.
-                # Should we enable NCCL point-point here ?
-                self.base_comm.send(np.take(self.local_array, np.arange(-cells_front, 0), axis=self.axis),
-                                    dest=self.rank + 1, tag=1)
+                self._send(send_buf, dest=self.rank + 1, tag=1)
         if cells_back is not None:
-            cells_backs = self._allgather(cells_back)
-            if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
-                total_cells_back = cells_backs.tolist() + [0]
-            else:
-                total_cells_back = cells_backs + [0]
+            total_cells_back = self.base_comm.allgather(cells_back) + [0]
             # Read cells_back which needs to be sent to rank - 1 (cells_back for rank - 1)
             cells_back = total_cells_back[self.rank - 1]
-            if self.rank != 0:
+            send_buf = ncp.take(self.local_array, ncp.arange(cells_back), axis=self.axis)
+            # Same reasoning as for the cells_front exchange applies here
+            recv_shapes = self.local_shapes
+            if self.rank != 0 and len(send_buf) != 0:
                 if cells_back > self.local_shape[self.axis]:
                     raise ValueError(f"Local Shape at rank={self.rank} along axis={self.axis} "
                                      f"should be > {cells_back}: dim({self.axis}) "
                                      f"{self.local_shape[self.axis]} < {cells_back}; "
                                      f"to achieve this use NUM_PROCESSES <= "
                                      f"{max(1, self.global_shape[self.axis] // cells_back)}")
-                self.base_comm.send(np.take(self.local_array, np.arange(cells_back), axis=self.axis),
-                                    dest=self.rank - 1, tag=0)
+                self._send(send_buf, dest=self.rank - 1, tag=0)
            if self.rank != self.size - 1:
-                ghosted_array = np.append(ghosted_array, self.base_comm.recv(source=self.rank + 1, tag=0),
-                                          axis=self.axis)
+                recv_shape = list(recv_shapes[self.rank + 1])
+                recv_shape[self.axis] = total_cells_back[self.rank]
+                recv_buf = ncp.zeros(recv_shape, dtype=ghosted_array.dtype)
+                if len(recv_buf) != 0:
+                    ghosted_array = ncp.append(ghosted_array, self._recv(recv_buf, source=self.rank + 1, tag=0),
+                                               axis=self.axis)
         return ghosted_array

     def __repr__(self):
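
The two helpers above give a single call site for both backends. Below is a minimal two-rank illustration of the contract they impose: a pre-allocated, correctly typed receive buffer when NCCL is active, and an optional buffer when falling back to mpi4py. Here dist_arr is assumed to be any DistributedArray built as in the sketch near the top of this page.

# Illustration only: ship the last element of rank 0's chunk to rank 1
# through the new DistributedArray._send/_recv helpers.
from pylops.utils.backend import get_module

rank, size = dist_arr.rank, dist_arr.size
ncp = get_module(dist_arr.engine)               # numpy or cupy, matching the array

if rank == 0 and size > 1:
    halo = dist_arr.local_array[-1:]            # data for the right neighbour
    dist_arr._send(halo, dest=1, tag=7)         # count defaults to halo.size
elif rank == 1:
    # NCCL needs a pre-sized buffer; with plain MPI, recv_buf could be None and
    # _recv would fall back to base_comm.recv(...).
    recv_buf = ncp.zeros(1, dtype=dist_arr.local_array.dtype)
    halo = dist_arr._recv(recv_buf, source=0, tag=7)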

pylops_mpi/LinearOperator.py

Lines changed: 2 additions & 0 deletions
@@ -86,6 +86,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
         if self.Op:
             y = DistributedArray(global_shape=self.shape[0],
                                  base_comm=self.base_comm,
+                                 base_comm_nccl=x.base_comm_nccl,
                                  partition=x.partition,
                                  axis=x.axis,
                                  engine=x.engine,
@@ -123,6 +124,7 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
         if self.Op:
             y = DistributedArray(global_shape=self.shape[1],
                                  base_comm=self.base_comm,
+                                 base_comm_nccl=x.base_comm_nccl,
                                  partition=x.partition,
                                  axis=x.axis,
                                  engine=x.engine,
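
The two added keyword arguments simply forward the input's NCCL communicator to the freshly created output array. A short sketch of the invariant they preserve, assuming Op is any pylops_mpi LinearOperator wrapping a PyLops operator and x is a DistributedArray created with base_comm_nccl set:

# Before this change, y was created without base_comm_nccl, so any later
# collective or point-to-point on y silently dropped back to the MPI path.
y = Op.matvec(x)
assert y.base_comm_nccl is x.base_comm_nccl
assert y.engine == x.engine    # CuPy buffers stay on the GPU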

pylops_mpi/basicoperators/BlockDiag.py

Lines changed: 2 additions & 2 deletions
@@ -121,7 +121,7 @@ def __init__(self, ops: Sequence[LinearOperator],
     @reshaped(forward=True, stacking=True)
     def _matvec(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=self.shape[0], local_shapes=self.local_shapes_n,
+        y = DistributedArray(global_shape=self.shape[0], base_comm_nccl=x.base_comm_nccl, local_shapes=self.local_shapes_n,
                              mask=self.mask, engine=x.engine, dtype=self.dtype)
         y1 = []
         for iop, oper in enumerate(self.ops):
@@ -133,7 +133,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
     @reshaped(forward=False, stacking=True)
     def _rmatvec(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=self.shape[1], local_shapes=self.local_shapes_m,
+        y = DistributedArray(global_shape=self.shape[1], base_comm_nccl=x.base_comm_nccl, local_shapes=self.local_shapes_m,
                              mask=self.mask, engine=x.engine, dtype=self.dtype)
         y1 = []
         for iop, oper in enumerate(self.ops):
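
Since the commit message names BlockDiag as one of the operators used for testing, here is a hedged sketch of the intended usage, assuming the public export is MPIBlockDiag (as in current pylops-mpi releases) and reusing comm / nccl_comm from the sketch near the top of this page.

import numpy as np
import cupy as cp
from pylops import MatrixMult

import pylops_mpi
from pylops_mpi import DistributedArray

# One dense block per rank; MPIBlockDiag applies each rank's block to its local chunk.
n = 5
A = MatrixMult(cp.asarray(np.random.rand(n, n)), dtype="float64")
BDiag = pylops_mpi.MPIBlockDiag(ops=[A])

x = DistributedArray(global_shape=n * comm.Get_size(),
                     base_comm=comm,
                     base_comm_nccl=nccl_comm,
                     engine="cupy",
                     dtype="float64")
x[:] = cp.ones(n, dtype="float64")

# y inherits base_comm_nccl (see the _matvec change above), so downstream
# reductions and halo exchanges stay on the NCCL path.
y = BDiag.matvec(x)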

pylops_mpi/basicoperators/FirstDerivative.py

Lines changed: 10 additions & 10 deletions
@@ -129,19 +129,19 @@ def _register_multiplications(
     def _matvec(self, x: DistributedArray) -> DistributedArray:
         # If Partition.BROADCAST, then convert to Partition.SCATTER
         if x.partition is Partition.BROADCAST:
-            x = DistributedArray.to_dist(x=x.local_array)
+            x = DistributedArray.to_dist(x=x.local_array, base_comm_nccl=x.base_comm_nccl)
         return self._hmatvec(x)

     def _rmatvec(self, x: DistributedArray) -> DistributedArray:
         # If Partition.BROADCAST, then convert to Partition.SCATTER
         if x.partition is Partition.BROADCAST:
-            x = DistributedArray.to_dist(x=x.local_array)
+            x = DistributedArray.to_dist(x=x.local_array, base_comm_nccl=x.base_comm_nccl)
         return self._hrmatvec(x)

     @reshaped
     def _matvec_forward(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
+        y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
                              axis=x.axis, engine=x.engine, dtype=self.dtype)
         ghosted_x = x.add_ghost_cells(cells_back=1)
         y_forward = ghosted_x[1:] - ghosted_x[:-1]
@@ -153,7 +153,7 @@ def _matvec_forward(self, x: DistributedArray) -> DistributedArray:
     @reshaped
     def _rmatvec_forward(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
+        y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
                              axis=x.axis, engine=x.engine, dtype=self.dtype)
         y[:] = 0
         if self.rank == self.size - 1:
@@ -171,7 +171,7 @@ def _rmatvec_forward(self, x: DistributedArray) -> DistributedArray:
     @reshaped
     def _matvec_backward(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
+        y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
                              axis=x.axis, engine=x.engine, dtype=self.dtype)
         ghosted_x = x.add_ghost_cells(cells_front=1)
         y_backward = ghosted_x[1:] - ghosted_x[:-1]
@@ -183,7 +183,7 @@ def _matvec_backward(self, x: DistributedArray) -> DistributedArray:
     @reshaped
     def _rmatvec_backward(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
+        y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
                              axis=x.axis, engine=x.engine, dtype=self.dtype)
         y[:] = 0
         ghosted_x = x.add_ghost_cells(cells_back=1)
@@ -201,7 +201,7 @@ def _rmatvec_backward(self, x: DistributedArray) -> DistributedArray:
     @reshaped
     def _matvec_centered3(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
+        y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
                              axis=x.axis, engine=x.engine, dtype=self.dtype)
         ghosted_x = x.add_ghost_cells(cells_front=1, cells_back=1)
         y_centered = 0.5 * (ghosted_x[2:] - ghosted_x[:-2])
@@ -221,7 +221,7 @@ def _matvec_centered3(self, x: DistributedArray) -> DistributedArray:
     @reshaped
     def _rmatvec_centered3(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
+        y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
                              axis=x.axis, engine=x.engine, dtype=self.dtype)
         y[:] = 0

@@ -249,7 +249,7 @@ def _rmatvec_centered3(self, x: DistributedArray) -> DistributedArray:
     @reshaped
     def _matvec_centered5(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
+        y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
                              axis=x.axis, engine=x.engine, dtype=self.dtype)
         ghosted_x = x.add_ghost_cells(cells_front=2, cells_back=2)
         y_centered = (
@@ -276,7 +276,7 @@ def _matvec_centered5(self, x: DistributedArray) -> DistributedArray:
     @reshaped
     def _rmatvec_centered5(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
+        y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
                              axis=x.axis, engine=x.engine, dtype=self.dtype)
         y[:] = 0
         ghosted_x = x.add_ghost_cells(cells_back=4)
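
All stencil variants above rely on add_ghost_cells() so that the local finite differences agree with the serial result across partition boundaries. A small NumPy-only illustration of that halo logic for the forward difference (purely conceptual, no MPI or NCCL involved):

import numpy as np

x = np.arange(8.0)                     # the "global" array
chunks = np.split(x, 2)                # what two ranks would hold locally

# Rank 0's view after add_ghost_cells(cells_back=1): its own chunk plus the
# first element of rank 1's chunk, delivered via _recv (NCCL or MPI).
ghosted = np.concatenate([chunks[0], chunks[1][:1]])
y_local = ghosted[1:] - ghosted[:-1]   # same formula as _matvec_forward

# Matches the serial forward difference on rank 0's portion of the output.
assert np.allclose(y_local, np.diff(x)[:4])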

pylops_mpi/utils/_nccl.py

Lines changed: 57 additions & 1 deletion
@@ -4,7 +4,9 @@
     "nccl_allgather",
     "nccl_allreduce",
     "nccl_bcast",
-    "nccl_asarray"
+    "nccl_asarray",
+    "nccl_send",
+    "nccl_recv"
 ]

 from enum import IntEnum
@@ -286,3 +288,57 @@ def nccl_asarray(nccl_comm, local_array, local_shapes, axis) -> cp.ndarray:
         chunks[i] = chunks[i].reshape(send_shape)[slicing]
     # combine back to single global array
     return cp.concatenate(chunks, axis=axis)
+
+
+def nccl_send(nccl_comm, send_buf, dest, count):
+    """NCCL equivalent of MPI_Send. Sends a specified number of elements
+    from the buffer to a destination GPU device.
+
+    Parameters
+    ----------
+    nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator`
+        The NCCL communicator used for point-to-point communication.
+    send_buf : :obj:`cupy.ndarray`
+        The array containing data to send.
+    dest : :obj:`int`
+        The rank of the destination GPU device.
+    count : :obj:`int`
+        Number of elements to send from `send_buf`.
+
+    Returns
+    -------
+    None
+    """
+    nccl_comm.send(send_buf.data.ptr,
+                   count,
+                   cupy_to_nccl_dtype[str(send_buf.dtype)],
+                   dest,
+                   cp.cuda.Stream.null.ptr
+                   )
+
+
+def nccl_recv(nccl_comm, recv_buf, source, count=None):
+    """NCCL equivalent of MPI_Recv. Receives data from a source GPU device
+    into the given buffer.
+
+    Parameters
+    ----------
+    nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator`
+        The NCCL communicator used for point-to-point communication.
+    recv_buf : :obj:`cupy.ndarray`
+        The array to store the received data.
+    source : :obj:`int`
+        The rank of the source GPU device.
+    count : :obj:`int`, optional
+        Number of elements to receive.
+
+    Returns
+    -------
+    None
+    """
+    nccl_comm.recv(recv_buf.data.ptr,
+                   count,
+                   cupy_to_nccl_dtype[str(recv_buf.dtype)],
+                   source,
+                   cp.cuda.Stream.null.ptr
+                   )
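
For completeness, a hedged sketch of driving the two new primitives directly, bootstrapping the NCCL communicator with mpi4py and CuPy's public NCCL bindings (the unique-id broadcast is one common pattern, not something this commit prescribes; one GPU per rank is assumed):

import cupy as cp
from cupy.cuda import nccl
from mpi4py import MPI

from pylops_mpi.utils._nccl import nccl_send, nccl_recv

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()

cp.cuda.Device(rank).use()                        # one GPU per MPI rank (assumed)
uid = nccl.get_unique_id() if rank == 0 else None
uid = comm.bcast(uid, root=0)
nccl_comm = nccl.NcclCommunicator(size, uid, rank)

if rank == 0:
    send_buf = cp.arange(4, dtype=cp.float32)
    nccl_send(nccl_comm, send_buf, dest=1, count=send_buf.size)
elif rank == 1:
    # The receive buffer must be pre-allocated with a matching count and dtype.
    recv_buf = cp.zeros(4, dtype=cp.float32)
    nccl_recv(nccl_comm, recv_buf, source=0, count=recv_buf.size)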
