Skip to content

Commit dbe1f30

Browse files
committed
MixIn for allgather.
1 parent ab97e3d commit dbe1f30

File tree

5 files changed

+163
-161
lines changed

5 files changed

+163
-161
lines changed

pylops_mpi/Distributed.py

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
from typing import Any, NewType
1+
from typing import Any, NewType, Tuple
22

33
from mpi4py import MPI
44
from pylops.utils import deps as pylops_deps # avoid namespace crashes with pylops_mpi.utils
5-
from pylops_mpi.utils._mpi import mpi_allreduce, mpi_send
5+
from pylops_mpi.utils._mpi import mpi_allreduce, mpi_allgather, mpi_send, mpi_recv, _prepare_allgather_inputs, _unroll_allgather_recv
66
from pylops_mpi.utils import deps
77

88
cupy_message = pylops_deps.cupy_import("the DistributedArray module")
@@ -11,11 +11,9 @@
1111
if nccl_message is None and cupy_message is None:
1212
from pylops_mpi.utils._nccl import (
1313
nccl_allgather, nccl_allreduce,
14-
nccl_asarray, nccl_bcast, nccl_split, nccl_send, nccl_recv,
15-
_prepare_nccl_allgather_inputs, _unroll_nccl_allgather_recv
14+
nccl_asarray, nccl_bcast, nccl_split, nccl_send, nccl_recv
1615
)
1716

18-
1917
class DistributedMixIn:
2018
r"""Distributed Mixin class
2119
@@ -44,6 +42,36 @@ def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
4442
return mpi_allreduce(self.sub_comm, send_buf,
4543
recv_buf, self.engine, op)
4644

45+
def _allgather(self, send_buf, recv_buf=None):
    """Allgather operation

    Gather ``send_buf`` from all ranks. NCCL is used when
    ``deps.nccl_enabled`` is set and ``self.base_comm_nccl`` is truthy;
    otherwise the MPI communicator ``self.base_comm`` is used.

    Parameters
    ----------
    send_buf : :obj:`tuple`, :obj:`list`, :obj:`int` or array
        Data contributed by this rank. Tuples, lists and ints are
        forwarded unchanged to the underlying allgather. On the NCCL
        path any other input must expose ``.shape`` and is padded to a
        common shape before being gathered.
    recv_buf : array, optional
        Buffer receiving the gathered data. When omitted on the NCCL
        array path, the padded buffer created by
        ``_prepare_allgather_inputs`` is used.

    Returns
    -------
    recv
        Result of the underlying allgather call. For arrays gathered
        through NCCL this is the output of ``_unroll_allgather_recv``.
    """
    is_plain_object = isinstance(send_buf, (tuple, list, int))

    if deps.nccl_enabled and self.base_comm_nccl:
        if is_plain_object:
            return nccl_allgather(self.base_comm_nccl, send_buf, recv_buf)
        # Shapes are exchanged through the MPI communicator so that every
        # rank can pad its buffer to the same shape.
        send_shapes = self.base_comm.allgather(send_buf.shape)
        padded_send, padded_recv = _prepare_allgather_inputs(
            send_buf, send_shapes, engine="cupy"
        )
        # Compare with None explicitly: the truth value of an array is
        # ambiguous (raises for more than one element) and a one-element
        # buffer holding 0 would otherwise be silently discarded.
        if recv_buf is None:
            recv_buf = padded_recv
        raw_recv = nccl_allgather(self.base_comm_nccl, padded_send, recv_buf)
        return _unroll_allgather_recv(raw_recv, padded_send.shape, send_shapes)

    if is_plain_object:
        return self.base_comm.allgather(send_buf)
    return mpi_allgather(self.base_comm, send_buf, recv_buf, self.engine)
60+
61+
def _allgather_subcomm(self, send_buf, recv_buf=None):
    """Allgather operation with subcommunicator

    Gather ``send_buf`` from all ranks of ``self.sub_comm``. NCCL is used
    when ``deps.nccl_enabled`` is set and ``self.base_comm_nccl`` is
    truthy; otherwise ``self.sub_comm`` is used as an MPI communicator.

    Parameters
    ----------
    send_buf : :obj:`tuple`, :obj:`list`, :obj:`int` or array
        Data contributed by this rank. Tuples, lists and ints are
        forwarded unchanged to the underlying allgather. On the NCCL
        path any other input must expose ``.shape`` and is padded to a
        common shape before being gathered.
    recv_buf : array, optional
        Buffer receiving the gathered data. When omitted on the NCCL
        array path, the padded buffer created by
        ``_prepare_allgather_inputs`` is used.

    Returns
    -------
    recv
        Result of the underlying allgather call. For arrays gathered
        through NCCL this is the output of ``_unroll_allgather_recv``.
    """
    is_plain_object = isinstance(send_buf, (tuple, list, int))

    if deps.nccl_enabled and self.base_comm_nccl:
        if is_plain_object:
            return nccl_allgather(self.sub_comm, send_buf, recv_buf)
        # The shape is a tuple, so this recursive call takes the
        # plain-object branch above and gathers the shapes of all ranks.
        send_shapes = self._allgather_subcomm(send_buf.shape)
        padded_send, padded_recv = _prepare_allgather_inputs(
            send_buf, send_shapes, engine="cupy"
        )
        # Compare with None explicitly: the truth value of an array is
        # ambiguous (raises for more than one element) and a one-element
        # buffer holding 0 would otherwise be silently discarded.
        if recv_buf is None:
            recv_buf = padded_recv
        raw_recv = nccl_allgather(self.sub_comm, padded_send, recv_buf)
        return _unroll_allgather_recv(raw_recv, padded_send.shape, send_shapes)

    if is_plain_object:
        # Same handling as in _allgather: plain Python objects have no
        # ``.shape`` and cannot go through the buffered mpi_allgather path.
        return self.sub_comm.allgather(send_buf)
    return mpi_allgather(self.sub_comm, send_buf, recv_buf, self.engine)
74+
4775
def _send(self, send_buf, dest, count=None, tag=0):
4876
"""Send operation
4977
"""
@@ -55,3 +83,20 @@ def _send(self, send_buf, dest, count=None, tag=0):
5583
mpi_send(self.base_comm,
5684
send_buf, dest, count, tag=tag,
5785
engine=self.engine)
86+
87+
def _recv(self, recv_buf=None, source=0, count=None, tag=0):
    """Receive operation

    Receive data from rank ``source``, through NCCL when
    ``deps.nccl_enabled`` is set and ``self.base_comm_nccl`` is truthy,
    and through ``mpi_recv`` on ``self.base_comm`` otherwise.

    Parameters
    ----------
    recv_buf : array, optional
        Buffer to receive into. Mandatory on the NCCL path.
    source : :obj:`int`, optional
        Rank to receive from.
    count : :obj:`int`, optional
        Number of elements to receive. On the NCCL path it defaults to
        ``recv_buf.size``.
    tag : :obj:`int`, optional
        Message tag. Only forwarded on the MPI path.

    Returns
    -------
    recv_buf
        The filled ``recv_buf`` on the NCCL path, or whatever
        ``mpi_recv`` returns on the MPI path.

    Raises
    ------
    ValueError
        If NCCL is used and ``recv_buf`` is not provided.
    """
    use_nccl = deps.nccl_enabled and self.base_comm_nccl
    if not use_nccl:
        return mpi_recv(self.base_comm,
                        recv_buf, source, count, tag=tag,
                        engine=self.engine)

    if recv_buf is None:
        raise ValueError("recv_buf must be supplied when using NCCL")
    n_elements = recv_buf.size if count is None else count
    nccl_recv(self.base_comm_nccl, recv_buf, source, n_elements)
    return recv_buf
101+
102+

pylops_mpi/DistributedArray.py

Lines changed: 1 addition & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,13 @@
99
from pylops.utils import deps as pylops_deps # avoid namespace crashes with pylops_mpi.utils
1010
from pylops.utils._internal import _value_or_sized_to_tuple
1111
from pylops.utils.backend import get_array_module, get_module, get_module_name
12-
from pylops_mpi.utils._mpi import mpi_allreduce, mpi_send
1312
from pylops_mpi.utils import deps
1413

1514
cupy_message = pylops_deps.cupy_import("the DistributedArray module")
1615
nccl_message = deps.nccl_import("the DistributedArray module")
1716

1817
if nccl_message is None and cupy_message is None:
19-
from pylops_mpi.utils._nccl import nccl_allgather, nccl_allreduce, nccl_asarray, nccl_bcast, nccl_split, nccl_send, nccl_recv, _prepare_nccl_allgather_inputs, _unroll_nccl_allgather_recv
18+
from pylops_mpi.utils._nccl import nccl_asarray, nccl_bcast, nccl_split
2019
from cupy.cuda.nccl import NcclCommunicator
2120
else:
2221
NcclCommunicator = Any
@@ -474,67 +473,6 @@ def _check_mask(self, dist_array):
474473
if not np.array_equal(self.mask, dist_array.mask):
475474
raise ValueError("Mask of both the arrays must be same")
476475

477-
def _allgather(self, send_buf, recv_buf=None):
478-
"""Allgather operation
479-
"""
480-
if deps.nccl_enabled and self.base_comm_nccl:
481-
if isinstance(send_buf, (tuple, list, int)):
482-
return nccl_allgather(self.base_comm_nccl, send_buf, recv_buf)
483-
else:
484-
send_shapes = self.base_comm.allgather(send_buf.shape)
485-
(padded_send, padded_recv) = _prepare_nccl_allgather_inputs(send_buf, send_shapes)
486-
raw_recv = nccl_allgather(self.base_comm_nccl, padded_send, recv_buf if recv_buf else padded_recv)
487-
return _unroll_nccl_allgather_recv(raw_recv, padded_send.shape, send_shapes)
488-
else:
489-
if recv_buf is None:
490-
return self.base_comm.allgather(send_buf)
491-
self.base_comm.Allgather(send_buf, recv_buf)
492-
return recv_buf
493-
494-
def _allgather_subcomm(self, send_buf, recv_buf=None):
495-
"""Allgather operation with subcommunicator
496-
"""
497-
if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
498-
if isinstance(send_buf, (tuple, list, int)):
499-
return nccl_allgather(self.sub_comm, send_buf, recv_buf)
500-
else:
501-
send_shapes = self._allgather_subcomm(send_buf.shape)
502-
(padded_send, padded_recv) = _prepare_nccl_allgather_inputs(send_buf, send_shapes)
503-
raw_recv = nccl_allgather(self.sub_comm, padded_send, recv_buf if recv_buf else padded_recv)
504-
return _unroll_nccl_allgather_recv(raw_recv, padded_send.shape, send_shapes)
505-
else:
506-
if recv_buf is None:
507-
return self.sub_comm.allgather(send_buf)
508-
self.sub_comm.Allgather(send_buf, recv_buf)
509-
510-
def _recv(self, recv_buf=None, source=0, count=None, tag=0):
511-
"""Receive operation
512-
"""
513-
if deps.nccl_enabled and self.base_comm_nccl:
514-
if recv_buf is None:
515-
raise ValueError("recv_buf must be supplied when using NCCL")
516-
if count is None:
517-
count = recv_buf.size
518-
nccl_recv(self.base_comm_nccl, recv_buf, source, count)
519-
return recv_buf
520-
else:
521-
# NumPy + MPI will benefit from buffered communication regardless of MPI installation
522-
if deps.cuda_aware_mpi_enabled or self.engine == "numpy":
523-
ncp = get_module(self.engine)
524-
if recv_buf is None:
525-
if count is None:
526-
raise ValueError("Must provide either recv_buf or count for MPI receive")
527-
# Default to int32 works currently because add_ghost_cells() is called
528-
# with recv_buf and is not affected by this branch. The int32 is for when
529-
# dimension or shape-related integers are send/recv
530-
recv_buf = ncp.zeros(count, dtype=ncp.int32)
531-
mpi_type = MPI._typedict[recv_buf.dtype.char]
532-
self.base_comm.Recv([recv_buf, recv_buf.size, mpi_type], source=source, tag=tag)
533-
else:
534-
# Uses CuPy without CUDA-aware MPI
535-
recv_buf = self.base_comm.recv(source=source, tag=tag)
536-
return recv_buf
537-
538476
def _nccl_local_shapes(self, masked: bool):
539477
"""Get the the list of shapes of every GPU in the communicator
540478
"""

pylops_mpi/utils/_mpi.py

Lines changed: 105 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,100 @@
11
__all__ = [
2-
# "mpi_allgather",
2+
"mpi_allgather",
33
"mpi_allreduce",
44
# "mpi_bcast",
55
# "mpi_asarray",
66
"mpi_send",
7-
# "mpi_recv",
7+
"mpi_recv",
8+
"_prepare_allgather_inputs",
9+
"_unroll_allgather_recv"
810
]
911

10-
from typing import Optional
12+
from typing import Optional, Tuple
1113

1214
import numpy as np
1315
from mpi4py import MPI
1416
from pylops.utils.backend import get_module
1517
from pylops_mpi.utils import deps
1618

19+
# TODO: return type annotation for both cupy and numpy
20+
def _prepare_allgather_inputs(send_buf, send_buf_shapes, engine):
    r"""Prepare send_buf and recv_buf for buffered allgather

    Buffered Allgather (MPI and NCCL) requires the sending buffer to have
    the same size on every rank. Padding is therefore required when the
    array is not evenly partitioned across all the ranks: each dimension
    of the sending buffer is zero-padded at its end up to the largest
    size of that dimension across all ranks.

    The receiving buffer is created as a flat, zero-filled array with
    :math:`n_{rank} \cdot send\_buf.size` elements, where
    ``send_buf.size`` is the size after padding.

    Parameters
    ----------
    send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
        The local data buffer to be sent for allgather.
    send_buf_shapes : :obj:`list`
        Shapes of ``send_buf`` on every rank (used to compute the padding).
    engine : :obj:`str`
        Engine used to store array (``numpy`` or ``cupy``)

    Returns
    -------
    send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
        The local data, padded to the common shape.
    recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
        Flat zero-filled buffer sized to gather the padded data of all ranks.
    """
    ncp = get_module(engine)

    # Largest extent of every dimension over all ranks
    padded_shape = tuple(max(dim_sizes) for dim_sizes in zip(*send_buf_shapes))

    # Pad only at the end of each dimension
    pad_width = [
        (0, target - current)
        for target, current in zip(padded_shape, send_buf.shape)
    ]
    padded_send = ncp.pad(send_buf, pad_width, mode="constant", constant_values=0)

    n_ranks = len(send_buf_shapes)
    recv_buf = ncp.zeros(n_ranks * padded_send.size, dtype=padded_send.dtype)
    return padded_send, recv_buf
61+
62+
63+
def _unroll_allgather_recv(recv_buf, padded_send_buf_shape, send_buf_shapes) -> list:
64+
r"""Unrolll recv_buf after Buffered Allgather (MPI and NCCL)
65+
66+
Remove the padded elements in recv_buff, extract an individual array from each device and return them as a list of arrays
67+
Each GPU may send array with a different shape, so the return type has to be a list of array
68+
instead of the concatenated array.
69+
70+
Parameters
71+
----------
72+
recv_buf: :obj:`cupy.ndarray` or array-like
73+
The data buffer returned from nccl_allgather call
74+
padded_send_buf_shape: :obj:`tuple`:int
75+
The size of send_buf after padding used in nccl_allgather
76+
send_buf_shapes: :obj:`list`
77+
A list of original shapes for each GPU send_buf prior to padding
78+
79+
Returns
80+
-------
81+
chunks: :obj:`list`
82+
A list of `cupy.ndarray` from each GPU with the padded element removed
83+
"""
84+
ndev = len(send_buf_shapes)
85+
# extract an individual array from each device
86+
chunk_size = np.prod(padded_send_buf_shape)
87+
chunks = [
88+
recv_buf[i * chunk_size:(i + 1) * chunk_size] for i in range(ndev)
89+
]
90+
91+
# Remove padding from each array: the padded value may appear somewhere
92+
# in the middle of the flat array and thus the reshape and slicing for each dimension is required
93+
for i in range(ndev):
94+
slicing = tuple(slice(0, end) for end in send_buf_shapes[i])
95+
chunks[i] = chunks[i].reshape(padded_send_buf_shape)[slicing]
96+
97+
return chunks
1798

1899
def mpi_allreduce(base_comm: MPI.Comm,
19100
send_buf, recv_buf=None,
@@ -57,7 +138,27 @@ def mpi_allreduce(base_comm: MPI.Comm,
57138
# For MIN and MAX which require recv_buf
58139
base_comm.Allreduce(send_buf, recv_buf, op)
59140
return recv_buf
60-
141+
142+
143+
def mpi_allgather(base_comm: MPI.Comm,
                  send_buf, recv_buf=None,
                  engine: Optional[str] = "numpy",
                  ) -> np.ndarray:
    """MPI Allgather

    Gather ``send_buf`` from all ranks of ``base_comm``.

    When ``deps.cuda_aware_mpi_enabled`` is set or ``engine`` is
    ``"numpy"``, the buffered ``Allgather`` is used: the shapes of all
    ranks are exchanged, ``send_buf`` is padded to a common shape, and
    the received data is split back into one unpadded array per rank.
    Otherwise (CuPy without CUDA-aware MPI) the object-based
    ``allgather`` is used, unless ``recv_buf`` is supplied.

    Parameters
    ----------
    base_comm : :obj:`mpi4py.MPI.Comm`
        MPI communicator to gather over.
    send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
        Data contributed by this rank. Must expose ``.shape`` on the
        buffered path.
    recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`, optional
        Buffer receiving the gathered data. On the buffered path it is
        presumably expected to hold the padded data of all ranks
        (to be confirmed against callers).
    engine : :obj:`str`, optional
        Engine used to store array (``numpy`` or ``cupy``)

    Returns
    -------
    recv
        On the buffered path, the list produced by
        ``_unroll_allgather_recv``. Otherwise the result of
        ``base_comm.allgather`` or, when ``recv_buf`` is supplied,
        ``recv_buf`` itself.
    """
    # NOTE(review): the ``np.ndarray`` return annotation is kept unchanged for
    # compatibility, although the buffered path returns a list of arrays.
    if deps.cuda_aware_mpi_enabled or engine == "numpy":
        send_shapes = base_comm.allgather(send_buf.shape)
        padded_send, padded_recv = _prepare_allgather_inputs(
            send_buf, send_shapes, engine=engine
        )
        # Compare with None explicitly: the truth value of an array is
        # ambiguous (raises for more than one element) and a one-element
        # buffer holding 0 would otherwise be silently discarded.
        recv_buffer_to_use = padded_recv if recv_buf is None else recv_buf
        base_comm.Allgather(padded_send, recv_buffer_to_use)
        return _unroll_allgather_recv(recv_buffer_to_use, padded_send.shape, send_shapes)

    # CuPy with non-CUDA-aware MPI
    if recv_buf is None:
        return base_comm.allgather(send_buf)
    base_comm.Allgather(send_buf, recv_buf)
    return recv_buf
161+
61162

62163
def mpi_send(base_comm: MPI.Comm,
63164
send_buf, dest, count, tag=0,

0 commit comments

Comments
 (0)