
Commit a2a23b0

Fixes from PR comments & add test for nccl_utils (PR#151)
1 parent 3eceb29 commit a2a23b0

File tree (3 files changed: +193 -89 lines changed)

  pylops_mpi/DistributedArray.py
  pylops_mpi/utils/_nccl.py
  tests_nccl/test_ncclutils_nccl.py


pylops_mpi/DistributedArray.py

Lines changed: 7 additions & 9 deletions
@@ -500,15 +500,13 @@ def _allgather(self, send_buf, recv_buf=None):
         """Allgather operation
         """
         if deps.nccl_enabled and self.base_comm_nccl:
-            if hasattr(send_buf, "shape"):
+            if isinstance(send_buf, (tuple, list, int)):
+                return nccl_allgather(self.base_comm_nccl, send_buf, recv_buf)
+            else:
                 send_shapes = self.base_comm.allgather(send_buf.shape)
                 (padded_send, padded_recv) = _prepare_nccl_allgather_inputs(send_buf, send_shapes)
-                # TODO: Should we ignore recv_buf completely in this case ?
                 raw_recv = nccl_allgather(self.base_comm_nccl, padded_send, recv_buf if recv_buf else padded_recv)
                 return _unroll_nccl_allgather_recv(raw_recv, padded_send.shape, send_shapes)
-            else:
-                # still works for a send_buf whose type is a tuple for _nccl_local_shapes
-                return nccl_allgather(self.base_comm_nccl, send_buf, recv_buf)
         else:
             if recv_buf is None:
                 return self.base_comm.allgather(send_buf)
@@ -519,13 +517,13 @@ def _allgather_subcomm(self, send_buf, recv_buf=None):
         """Allgather operation with subcommunicator
         """
         if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
-            if hasattr(send_buf, "shape"):
-                send_shapes = self.base_comm.allgather(send_buf.shape)
+            if isinstance(send_buf, (tuple, list, int)):
+                return nccl_allgather(self.sub_comm, send_buf, recv_buf)
+            else:
+                send_shapes = self._allgather_subcomm(send_buf.shape)
                 (padded_send, padded_recv) = _prepare_nccl_allgather_inputs(send_buf, send_shapes)
                 raw_recv = nccl_allgather(self.sub_comm, padded_send, recv_buf if recv_buf else padded_recv)
                 return _unroll_nccl_allgather_recv(raw_recv, padded_send.shape, send_shapes)
-            else:
-                return nccl_allgather(self.sub_comm, send_buf, recv_buf)
         else:
             if recv_buf is None:
                 return self.sub_comm.allgather(send_buf)

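In both methods the dispatch is now inverted: tuple, list, or int inputs (e.g. the local shapes exchanged by _nccl_local_shapes) go straight through nccl_allgather, while array inputs take the pad/gather/unroll path. A minimal standalone sketch of that dispatch is shown below; it assumes an initialized NCCL communicator, one MPI process per GPU, and a CuPy send buffer, and the helper name allgather_padded plus the use of MPI.COMM_WORLD for the shape exchange are inventions for this example, not part of the commit.

# Sketch only: mirrors the new branch structure of DistributedArray._allgather.
from mpi4py import MPI

from pylops_mpi.utils._nccl import (
    initialize_nccl_comm,
    nccl_allgather,
    _prepare_nccl_allgather_inputs,
    _unroll_nccl_allgather_recv,
)

nccl_comm = initialize_nccl_comm()


def allgather_padded(send_buf, recv_buf=None):
    # shapes and other small Python objects are gathered directly
    if isinstance(send_buf, (tuple, list, int)):
        return nccl_allgather(nccl_comm, send_buf, recv_buf)
    # arrays may be unevenly partitioned: pad, gather, then strip the padding
    send_shapes = MPI.COMM_WORLD.allgather(send_buf.shape)
    padded_send, padded_recv = _prepare_nccl_allgather_inputs(send_buf, send_shapes)
    raw_recv = nccl_allgather(nccl_comm, padded_send, recv_buf if recv_buf else padded_recv)
    return _unroll_nccl_allgather_recv(raw_recv, padded_send.shape, send_shapes)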
pylops_mpi/utils/_nccl.py

Lines changed: 84 additions & 80 deletions
@@ -1,4 +1,6 @@
 __all__ = [
+    "_prepare_nccl_allgather_inputs",
+    "_unroll_nccl_allgather_recv",
     "initialize_nccl_comm",
     "nccl_split",
     "nccl_allgather",
@@ -7,17 +9,17 @@
     "nccl_asarray",
     "nccl_send",
     "nccl_recv",
-    "_prepare_nccl_allgather_inputs",
-    "_unroll_nccl_allgather_recv"
 ]

 from enum import IntEnum
+from typing import Tuple
 from mpi4py import MPI
 import os
 import numpy as np
 import cupy as cp
 import cupy.cuda.nccl as nccl

+
 cupy_to_nccl_dtype = {
     "float32": nccl.NCCL_FLOAT32,
     "float64": nccl.NCCL_FLOAT64,
@@ -61,6 +63,85 @@ def _nccl_buf_size(buf, count=None):
     return count if count else buf.size


+def _prepare_nccl_allgather_inputs(send_buf, send_buf_shapes) -> Tuple[cp.ndarray, cp.ndarray]:
+    r"""Prepare send_buf and recv_buf for NCCL allgather (nccl_allgather)
+
+    NCCL's allGather requires the sending buffer to have the same size for every device.
+    Therefore, padding is required when the array is not evenly partitioned across
+    all the ranks. The padding is applied such that each dimension of the sending buffer
+    is equal to the max size of that dimension across all ranks.
+
+    Similarly, each receiver buffer (recv_buf) is created with size equal to :math:`n_{rank} \cdot send\_buf.size`
+
+    Parameters
+    ----------
+    send_buf : :obj:`cupy.ndarray` or array-like
+        The data buffer from the local GPU to be sent for allgather.
+    send_buf_shapes : :obj:`list`
+        A list of shapes for each GPU send_buf (used to calculate padding size)
+
+    Returns
+    -------
+    send_buf : :obj:`cupy.ndarray`
+        A buffer containing the data and padded elements to be sent by this rank.
+    recv_buf : :obj:`cupy.ndarray`
+        An empty, padded buffer to gather data from all GPUs.
+    """
+    sizes_each_dim = list(zip(*send_buf_shapes))
+    send_shape = tuple(map(max, sizes_each_dim))
+    pad_size = [
+        (0, s_shape - l_shape) for s_shape, l_shape in zip(send_shape, send_buf.shape)
+    ]
+
+    send_buf = cp.pad(
+        send_buf, pad_size, mode="constant", constant_values=0
+    )
+
+    # NCCL recommends one MPI process per GPU, so the size of the receiving buffer can be inferred
+    ndev = len(send_buf_shapes)
+    recv_buf = cp.zeros(ndev * send_buf.size, dtype=send_buf.dtype)
+
+    return send_buf, recv_buf
+
+
+def _unroll_nccl_allgather_recv(recv_buf, padded_send_buf_shape, send_buf_shapes) -> list:
+    """Unroll recv_buf after NCCL allgather (nccl_allgather)
+
+    Remove the padded elements in recv_buf, extract an individual array from each device and return them as a list of arrays.
+    Each GPU may send an array with a different shape, so the return type has to be a list of arrays
+    instead of a single concatenated array.
+
+    Parameters
+    ----------
+    recv_buf : :obj:`cupy.ndarray` or array-like
+        The data buffer returned from the nccl_allgather call
+    padded_send_buf_shape : :obj:`tuple`
+        The shape of send_buf after the padding used in nccl_allgather
+    send_buf_shapes : :obj:`list`
+        A list of original shapes for each GPU send_buf prior to padding
+
+    Returns
+    -------
+    chunks : :obj:`list`
+        A list of :obj:`cupy.ndarray` from each GPU with the padded elements removed
+    """
+
+    ndev = len(send_buf_shapes)
+    # extract an individual array from each device
+    chunk_size = np.prod(padded_send_buf_shape)
+    chunks = [
+        recv_buf[i * chunk_size:(i + 1) * chunk_size] for i in range(ndev)
+    ]
+
+    # Remove padding from each array: the padded values may appear somewhere
+    # in the middle of the flat array, so a reshape and per-dimension slicing is required
+    for i in range(ndev):
+        slicing = tuple(slice(0, end) for end in send_buf_shapes[i])
+        chunks[i] = chunks[i].reshape(padded_send_buf_shape)[slicing]
+
+    return chunks
+
+
 def mpi_op_to_nccl(mpi_op) -> NcclOp:
     """ Map MPI reduction operation to NCCL equivalent

@@ -253,83 +334,6 @@ def nccl_bcast(nccl_comm, local_array, index, value) -> None:
     )


-def _prepare_nccl_allgather_inputs(send_buf, send_buf_shapes) -> tuple[cp.ndarray, cp.ndarray]:
-    """ Preparing the send_buf and recv_buf for the NCCL allgather (nccl_allgather)
-
-    NCCL's allGather requires the sending buffer to have the same size for every device.
-    Therefore, the padding is required when the array is not evenly partitioned across
-    all the ranks. The padding is applied such that the sending buffer has the size of
-    each dimension corresponding to the max possible size of that dimension.
-
-    Receiver buff (recv_buf) will have the size n_rank * send_buf.size
-
-    Parameters
-    ----------
-    send_buf : :obj:`cupy.ndarray` or array-like
-        The data buffer from the local GPU to be sent for allgather.
-    send_buf_shapes: :obj:`list`
-        A list of shapes for each GPU send_buf (used to calculate padding size)
-
-    Returns
-    -------
-    tuple[send_buf, recv_buf]: :obj:`tuple`
-        A tuple of (send_buf, recv_buf) will an appropriate size, shape and dtype for NCCL allgather
-
-    """
-    sizes_each_dim = list(zip(*send_buf_shapes))
-    send_shape = tuple(map(max, sizes_each_dim))
-    pad_size = [
-        (0, s_shape - l_shape) for s_shape, l_shape in zip(send_shape, send_buf.shape)
-    ]
-
-    send_buf = cp.pad(
-        send_buf, pad_size, mode="constant", constant_values=0
-    )
-
-    # NCCL recommends to use one MPI Process per GPU and so size of receiving buffer can be inferred
-    ndev = len(send_buf_shapes)
-    recv_buf = cp.zeros(ndev * send_buf.size, dtype=send_buf.dtype)
-
-    return (send_buf, recv_buf)
-
-
-def _unroll_nccl_allgather_recv(recv_buf, padded_send_buf_shape, send_buf_shapes) -> list:
-    """ Remove the padded elements in recv_buff, extract an individual array from each device and return them as a list of arrays
-
-    Each GPU may send array with a different shape, so the return type has to be a list of array
-    instead of the concatenated array.
-
-    Parameters
-    ----------
-    recv_buf: :obj:`cupy.ndarray` or array-like
-        The data buffer returned from nccl_allgather call
-    padded_send_buf_shape: :obj:`tuple`:int
-        The size of send_buf after padding used in nccl_allgather
-    send_buf_shapes: :obj:`list`
-        A list of original shapes for each GPU send_buf prior to padding
-
-    Returns
-    -------
-    chunks: :obj:`list`
-        A list of `cupy.ndarray` from each GPU with the padded element removed
-    """
-
-    ndev = len(send_buf_shapes)
-    # extract an individual array from each device
-    chunk_size = np.prod(padded_send_buf_shape)
-    chunks = [
-        recv_buf[i * chunk_size:(i + 1) * chunk_size] for i in range(ndev)
-    ]
-
-    # Remove padding from each array: the padded value may appear somewhere
-    # in the middle of the flat array and thus the reshape and slicing for each dimension is required
-    for i in range(ndev):
-        slicing = tuple(slice(0, end) for end in send_buf_shapes[i])
-        chunks[i] = chunks[i].reshape(padded_send_buf_shape)[slicing]
-
-    return chunks
-
-
 def nccl_asarray(nccl_comm, local_array, local_shapes, axis) -> cp.ndarray:
     """Global view of the array

@@ -352,7 +356,7 @@ def nccl_asarray(nccl_comm, local_array, local_shapes, axis) -> cp.ndarray:
         Global array gathered from all GPUs and concatenated along `axis`.
     """

-    (send_buf, recv_buf) = _prepare_nccl_allgather_inputs(local_array, local_shapes)
+    send_buf, recv_buf = _prepare_nccl_allgather_inputs(local_array, local_shapes)
     nccl_allgather(nccl_comm, send_buf, recv_buf)
     chunks = _unroll_nccl_allgather_recv(recv_buf, send_buf.shape, local_shapes)

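Both nccl_asarray above and DistributedArray._allgather now share this pad/gather/unroll round trip. The arithmetic can be checked without a GPU: the NumPy-only sketch below replays the same steps for three ranks with uneven 1-D partitions. All names and shapes here are invented for illustration; the real helpers operate on CuPy arrays.

# Illustrative walk-through of the padding/unroll arithmetic used by
# _prepare_nccl_allgather_inputs / _unroll_nccl_allgather_recv.
import numpy as np

send_buf_shapes = [(3,), (3,), (4,)]                  # rank 2 holds one extra element
send_shape = tuple(map(max, zip(*send_buf_shapes)))   # -> (4,): the padded send shape

# every rank pads its local array up to send_shape before allGather
local = np.arange(3, dtype=np.float64)                # pretend this is rank 0's buffer
padded = np.pad(local, [(0, send_shape[0] - local.shape[0])])
assert padded.shape == send_shape

# after allGather the flat receive buffer holds ndev padded chunks ...
ndev = len(send_buf_shapes)
recv_flat = np.concatenate([np.full(send_shape, r, dtype=np.float64) for r in range(ndev)])

# ... which are reshaped and sliced back to their original shapes
chunk_size = int(np.prod(send_shape))
chunks = [
    recv_flat[i * chunk_size:(i + 1) * chunk_size]
    .reshape(send_shape)[tuple(slice(0, e) for e in send_buf_shapes[i])]
    for i in range(ndev)
]
print([c.shape for c in chunks])                      # [(3,), (3,), (4,)]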
tests_nccl/test_ncclutils_nccl.py

Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
+"""Test basic NCCL functionalities in _nccl
+Designed to run with n GPUs (with 1 MPI process per GPU)
+$ mpiexec -n 10 pytest test_ncclutils_nccl.py --with-mpi
+"""
+from mpi4py import MPI
+import numpy as np
+import cupy as cp
+from numpy.testing import assert_allclose
+import pytest
+
+from pylops_mpi.utils._nccl import initialize_nccl_comm, nccl_allgather, _prepare_nccl_allgather_inputs, _unroll_nccl_allgather_recv
+
+np.random.seed(42)
+
+nccl_comm = initialize_nccl_comm()
+
+par1 = {'n': 3, 'dtype': np.float64}
+
+
+@pytest.mark.mpi(min_size=2)
+@pytest.mark.parametrize("par", [(par1), ])
+def test_allgather_samesize(par):
+    """Test nccl_allgather with arrays of same size"""
+    size = MPI.COMM_WORLD.Get_size()
+    rank = MPI.COMM_WORLD.Get_rank()
+
+    # Local array
+    local_array = rank * cp.ones(par['n'], dtype=par['dtype'])
+
+    # Gathered array
+    gathered_array = nccl_allgather(nccl_comm, local_array)
+
+    # Compare with global array created in rank0
+    if rank == 0:
+        global_array = np.ones(par['n'] * size, dtype=par['dtype'])
+        for irank in range(size):
+            global_array[irank * par["n"]: (irank + 1) * par["n"]] = irank
+
+        assert_allclose(
+            gathered_array.get(),
+            global_array,
+            rtol=1e-14,
+        )
+
+
+@pytest.mark.mpi(min_size=2)
+@pytest.mark.parametrize("par", [(par1), ])
+def test_allgather_samesize_withrecbuf(par):
+    """Test nccl_allgather with arrays of same size and rec_buf"""
+    size = MPI.COMM_WORLD.Get_size()
+    rank = MPI.COMM_WORLD.Get_rank()
+
+    # Local array
+    local_array = rank * cp.ones(par['n'], dtype=par['dtype'])
+
+    # Gathered array
+    gathered_array = cp.zeros(par['n'] * size, dtype=par['dtype'])
+    gathered_array = nccl_allgather(nccl_comm, local_array, recv_buf=gathered_array)
+
+    # Compare with global array created in rank0
+    if rank == 0:
+        global_array = np.ones(par['n'] * size, dtype=par['dtype'])
+        for irank in range(size):
+            global_array[irank * par["n"]: (irank + 1) * par["n"]] = irank
+
+        assert_allclose(
+            gathered_array.get(),
+            global_array,
+            rtol=1e-14,
+        )
+
+
+@pytest.mark.mpi(min_size=2)
+@pytest.mark.parametrize("par", [(par1), ])
+def test_allgather_differentsize_withrecbuf(par):
+    """Test nccl_allgather with arrays of different size and rec_buf"""
+    size = MPI.COMM_WORLD.Get_size()
+    rank = MPI.COMM_WORLD.Get_rank()
+
+    # Local array
+    n = par['n'] + (1 if rank == size - 1 else 0)
+    local_array = rank * cp.ones(n, dtype=par['dtype'])
+
+    # Gathered array
+    send_shapes = MPI.COMM_WORLD.allgather(local_array.shape)
+    (send_buf, recv_buf) = _prepare_nccl_allgather_inputs(local_array, send_shapes)
+    recv_buf = nccl_allgather(nccl_comm, send_buf, recv_buf)
+    chunks = _unroll_nccl_allgather_recv(recv_buf, send_buf.shape, send_shapes)
+    gathered_array = cp.concatenate(chunks)
+
+    # Compare with global array created in rank0
+    if rank == 0:
+        global_array = np.ones(par['n'] * size + 1, dtype=par['dtype'])
+        for irank in range(size - 1):
+            global_array[irank * par["n"]: (irank + 1) * par["n"]] = irank
+        global_array[(size - 1) * par["n"]:] = size - 1
+
+        assert_allclose(
+            gathered_array.get(),
+            global_array,
+            rtol=1e-14,
+        )

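The third test exercises the uneven-partition path: every rank sends par['n'] elements except the last, which sends one extra. A quick NumPy-only check of the reference array it builds on rank 0, assuming for illustration size = 3 and par['n'] = 3:

# Hypothetical sanity check of the expected result in
# test_allgather_differentsize_withrecbuf (no GPU needed).
import numpy as np

size, n = 3, 3
global_array = np.ones(n * size + 1, dtype=np.float64)
for irank in range(size - 1):
    global_array[irank * n:(irank + 1) * n] = irank
global_array[(size - 1) * n:] = size - 1
print(global_array)   # [0. 0. 0. 1. 1. 1. 2. 2. 2. 2.]

Given the @pytest.mark.mpi(min_size=2) markers, any process count of at least two should work, e.g. mpiexec -n 2 pytest tests_nccl/test_ncclutils_nccl.py --with-mpi, with one GPU visible per MPI process.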