Skip to content

Commit 3eceb29

Browse files
committed
helper functions for nccl_allgather, simplify Fredholm1
1 parent 95e32bf commit 3eceb29

File tree

3 files changed

+97
-43
lines changed

3 files changed

+97
-43
lines changed

pylops_mpi/DistributedArray.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
nccl_message = deps.nccl_import("the DistributedArray module")
1515

1616
if nccl_message is None and cupy_message is None:
17-
from pylops_mpi.utils._nccl import nccl_allgather, nccl_allreduce, nccl_asarray, nccl_bcast, nccl_split, nccl_send, nccl_recv
17+
from pylops_mpi.utils._nccl import nccl_allgather, nccl_allreduce, nccl_asarray, nccl_bcast, nccl_split, nccl_send, nccl_recv, _prepare_nccl_allgather_inputs, _unroll_nccl_allgather_recv
1818
from cupy.cuda.nccl import NcclCommunicator
1919
else:
2020
NcclCommunicator = Any
@@ -500,7 +500,15 @@ def _allgather(self, send_buf, recv_buf=None):
500500
"""Allgather operation
501501
"""
502502
if deps.nccl_enabled and self.base_comm_nccl:
503-
return nccl_allgather(self.base_comm_nccl, send_buf, recv_buf)
503+
if hasattr(send_buf, "shape"):
504+
send_shapes = self.base_comm.allgather(send_buf.shape)
505+
(padded_send, padded_recv) = _prepare_nccl_allgather_inputs(send_buf, send_shapes)
506+
# TODO: Should we ignore recv_buf completely in this case ?
507+
raw_recv = nccl_allgather(self.base_comm_nccl, padded_send, recv_buf if recv_buf else padded_recv)
508+
return _unroll_nccl_allgather_recv(raw_recv, padded_send.shape, send_shapes)
509+
else:
510+
# still works for a send_buf whose type is a tuple for _nccl_local_shapes
511+
return nccl_allgather(self.base_comm_nccl, send_buf, recv_buf)
504512
else:
505513
if recv_buf is None:
506514
return self.base_comm.allgather(send_buf)
@@ -511,7 +519,13 @@ def _allgather_subcomm(self, send_buf, recv_buf=None):
511519
"""Allgather operation with subcommunicator
512520
"""
513521
if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
514-
return nccl_allgather(self.sub_comm, send_buf, recv_buf)
522+
if hasattr(send_buf, "shape"):
523+
send_shapes = self.base_comm.allgather(send_buf.shape)
524+
(padded_send, padded_recv) = _prepare_nccl_allgather_inputs(send_buf, send_shapes)
525+
raw_recv = nccl_allgather(self.sub_comm, padded_send, recv_buf if recv_buf else padded_recv)
526+
return _unroll_nccl_allgather_recv(raw_recv, padded_send.shape, send_shapes)
527+
else:
528+
return nccl_allgather(self.sub_comm, send_buf, recv_buf)
515529
else:
516530
if recv_buf is None:
517531
return self.sub_comm.allgather(send_buf)

pylops_mpi/signalprocessing/Fredholm1.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -165,15 +165,5 @@ def _rmatvec(self, x: NDArray) -> NDArray:
165165
y1[isl] = ncp.dot(x[isl].T.conj(), self.G[isl]).T.conj()
166166

167167
# gather results
168-
recv = y._allgather(y1)
169-
# TODO: current of _allgather will call non-buffered MPI-AllGather (sub-optimal for CuPy+MPI)
170-
# which returns a list (not flatten) and does not require unrolling
171-
if self.usematmul and isinstance(recv, ncp.ndarray) :
172-
# unrolling
173-
chunk_size = self.ny * self.nz
174-
num_partition = (len(recv) + chunk_size - 1) // chunk_size
175-
recv = ncp.vstack([recv[i * chunk_size: (i + 1) * chunk_size].reshape(self.nz, self.ny).T for i in range(num_partition)])
176-
else:
177-
recv = ncp.vstack(recv)
178-
y[:] = recv.ravel()
168+
y[:] = ncp.vstack(y._allgather(y1)).ravel()
179169
return y

pylops_mpi/utils/_nccl.py

Lines changed: 79 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
"nccl_bcast",
77
"nccl_asarray",
88
"nccl_send",
9-
"nccl_recv"
9+
"nccl_recv",
10+
"_prepare_nccl_allgather_inputs",
11+
"_unroll_nccl_allgather_recv"
1012
]
1113

1214
from enum import IntEnum
@@ -251,61 +253,109 @@ def nccl_bcast(nccl_comm, local_array, index, value) -> None:
251253
)
252254

253255

254-
def nccl_asarray(nccl_comm, local_array, local_shapes, axis) -> cp.ndarray:
255-
"""Global view of the array
256+
def _prepare_nccl_allgather_inputs(send_buf, send_buf_shapes) -> tuple[cp.ndarray, cp.ndarray]:
257+
""" Preparing the send_buf and recv_buf for the NCCL allgather (nccl_allgather)
256258
257-
Gather all local GPU arrays into a single global array via NCCL all-gather.
259+
NCCL's allGather requires the sending buffer to have the same size for every device.
260+
Therefore, the padding is required when the array is not evenly partitioned across
261+
all the ranks. The padding is applied such that the sending buffer has the size of
262+
each dimension corresponding to the max possible size of that dimension.
263+
264+
Receiver buffer (recv_buf) will have the size n_rank * send_buf.size
258265
259266
Parameters
260267
----------
261-
nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator`
262-
The NCCL communicator used for collective communication.
263-
local_array : :obj:`cupy.ndarray`
264-
The local array on the current GPU.
265-
local_shapes : :obj:`list`
266-
A list of shapes for each GPU local array (used to trim padding).
267-
axis : :obj:`int`
268-
The axis along which to concatenate the gathered arrays.
268+
send_buf : :obj:`cupy.ndarray` or array-like
269+
The data buffer from the local GPU to be sent for allgather.
270+
send_buf_shapes: :obj:`list`
271+
A list of shapes for each GPU send_buf (used to calculate padding size)
269272
270273
Returns
271274
-------
272-
final_array : :obj:`cupy.ndarray`
273-
Global array gathered from all GPUs and concatenated along `axis`.
275+
tuple[send_buf, recv_buf]: :obj:`tuple`
276+
A tuple of (send_buf, recv_buf) with an appropriate size, shape and dtype for NCCL allgather
274277
275-
Notes
276-
-----
277-
NCCL's allGather requires the sending buffer to have the same size for every device.
278-
Therefore, the padding is required when the array is not evenly partitioned across
279-
all the ranks. The padding is applied such that the sending buffer has the size of
280-
each dimension corresponding to the max possible size of that dimension.
281278
"""
282-
sizes_each_dim = list(zip(*local_shapes))
283-
279+
sizes_each_dim = list(zip(*send_buf_shapes))
284280
send_shape = tuple(map(max, sizes_each_dim))
285281
pad_size = [
286-
(0, s_shape - l_shape) for s_shape, l_shape in zip(send_shape, local_array.shape)
282+
(0, s_shape - l_shape) for s_shape, l_shape in zip(send_shape, send_buf.shape)
287283
]
288284

289285
send_buf = cp.pad(
290-
local_array, pad_size, mode="constant", constant_values=0
286+
send_buf, pad_size, mode="constant", constant_values=0
291287
)
292288

293289
# NCCL recommends to use one MPI Process per GPU and so size of receiving buffer can be inferred
294-
ndev = len(local_shapes)
290+
ndev = len(send_buf_shapes)
295291
recv_buf = cp.zeros(ndev * send_buf.size, dtype=send_buf.dtype)
296-
nccl_allgather(nccl_comm, send_buf, recv_buf)
297292

293+
return (send_buf, recv_buf)
294+
295+
296+
def _unroll_nccl_allgather_recv(recv_buf, padded_send_buf_shape, send_buf_shapes) -> list:
297+
""" Remove the padded elements in recv_buff, extract an individual array from each device and return them as a list of arrays
298+
299+
Each GPU may send an array with a different shape, so the return type has to be a list of arrays
300+
instead of the concatenated array.
301+
302+
Parameters
303+
----------
304+
recv_buf: :obj:`cupy.ndarray` or array-like
305+
The data buffer returned from nccl_allgather call
306+
padded_send_buf_shape: :obj:`tuple` of :obj:`int`
307+
The size of send_buf after padding used in nccl_allgather
308+
send_buf_shapes: :obj:`list`
309+
A list of original shapes for each GPU send_buf prior to padding
310+
311+
Returns
312+
-------
313+
chunks: :obj:`list`
314+
A list of `cupy.ndarray` from each GPU with the padded elements removed
315+
"""
316+
317+
ndev = len(send_buf_shapes)
298318
# extract an individual array from each device
299-
chunk_size = np.prod(send_shape)
319+
chunk_size = np.prod(padded_send_buf_shape)
300320
chunks = [
301321
recv_buf[i * chunk_size:(i + 1) * chunk_size] for i in range(ndev)
302322
]
303323

304324
# Remove padding from each array: the padded value may appear somewhere
305325
# in the middle of the flat array and thus the reshape and slicing for each dimension is required
306326
for i in range(ndev):
307-
slicing = tuple(slice(0, end) for end in local_shapes[i])
308-
chunks[i] = chunks[i].reshape(send_shape)[slicing]
327+
slicing = tuple(slice(0, end) for end in send_buf_shapes[i])
328+
chunks[i] = chunks[i].reshape(padded_send_buf_shape)[slicing]
329+
330+
return chunks
331+
332+
333+
def nccl_asarray(nccl_comm, local_array, local_shapes, axis) -> cp.ndarray:
334+
"""Global view of the array
335+
336+
Gather all local GPU arrays into a single global array via NCCL all-gather.
337+
338+
Parameters
339+
----------
340+
nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator`
341+
The NCCL communicator used for collective communication.
342+
local_array : :obj:`cupy.ndarray`
343+
The local array on the current GPU.
344+
local_shapes : :obj:`list`
345+
A list of shapes for each GPU local array (used to trim padding).
346+
axis : :obj:`int`
347+
The axis along which to concatenate the gathered arrays.
348+
349+
Returns
350+
-------
351+
final_array : :obj:`cupy.ndarray`
352+
Global array gathered from all GPUs and concatenated along `axis`.
353+
"""
354+
355+
(send_buf, recv_buf) = _prepare_nccl_allgather_inputs(local_array, local_shapes)
356+
nccl_allgather(nccl_comm, send_buf, recv_buf)
357+
chunks = _unroll_nccl_allgather_recv(recv_buf, send_buf.shape, local_shapes)
358+
309359
# combine back to single global array
310360
return cp.concatenate(chunks, axis=axis)
311361

0 commit comments

Comments
 (0)