
Commit 2c8bf53

Merge pull request #148 from tharittk/complex-support

Complex-number Support for NCCL & Fredholm NCCL

2 parents 1fdf083 + 4ca10e4, commit 2c8bf53

14 files changed: +617 −377 lines

docs/source/gpu.rst

Lines changed: 4 additions & 4 deletions

@@ -153,8 +153,8 @@ In the following, we provide a list of modules (i.e., operators and solvers) whe
    * - :class:`pylops_mpi.optimization.basic.cgls`
      - ✅
    * - :class:`pylops_mpi.signalprocessing.Fredhoml1`
-     - Planned ⏳
-   * - ISTA Solver
-     - Planned ⏳
+     - ✅
    * - Complex Numeric Data Type for NCCL
-     - Planned ⏳
+     - ✅
+   * - ISTA Solver
+     - Planned ⏳

pylops_mpi/DistributedArray.py

Lines changed: 15 additions & 3 deletions

@@ -14,7 +14,7 @@
 nccl_message = deps.nccl_import("the DistributedArray module")

 if nccl_message is None and cupy_message is None:
-    from pylops_mpi.utils._nccl import nccl_allgather, nccl_allreduce, nccl_asarray, nccl_bcast, nccl_split, nccl_send, nccl_recv
+    from pylops_mpi.utils._nccl import nccl_allgather, nccl_allreduce, nccl_asarray, nccl_bcast, nccl_split, nccl_send, nccl_recv, _prepare_nccl_allgather_inputs, _unroll_nccl_allgather_recv
     from cupy.cuda.nccl import NcclCommunicator
 else:
     NcclCommunicator = Any

@@ -500,7 +500,13 @@ def _allgather(self, send_buf, recv_buf=None):
         """Allgather operation
         """
         if deps.nccl_enabled and self.base_comm_nccl:
-            return nccl_allgather(self.base_comm_nccl, send_buf, recv_buf)
+            if isinstance(send_buf, (tuple, list, int)):
+                return nccl_allgather(self.base_comm_nccl, send_buf, recv_buf)
+            else:
+                send_shapes = self.base_comm.allgather(send_buf.shape)
+                (padded_send, padded_recv) = _prepare_nccl_allgather_inputs(send_buf, send_shapes)
+                raw_recv = nccl_allgather(self.base_comm_nccl, padded_send, recv_buf if recv_buf else padded_recv)
+                return _unroll_nccl_allgather_recv(raw_recv, padded_send.shape, send_shapes)
         else:
             if recv_buf is None:
                 return self.base_comm.allgather(send_buf)

@@ -511,7 +517,13 @@ def _allgather_subcomm(self, send_buf, recv_buf=None):
         """Allgather operation with subcommunicator
         """
         if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
-            return nccl_allgather(self.sub_comm, send_buf, recv_buf)
+            if isinstance(send_buf, (tuple, list, int)):
+                return nccl_allgather(self.sub_comm, send_buf, recv_buf)
+            else:
+                send_shapes = self._allgather_subcomm(send_buf.shape)
+                (padded_send, padded_recv) = _prepare_nccl_allgather_inputs(send_buf, send_shapes)
+                raw_recv = nccl_allgather(self.sub_comm, padded_send, recv_buf if recv_buf else padded_recv)
+                return _unroll_nccl_allgather_recv(raw_recv, padded_send.shape, send_shapes)
         else:
             if recv_buf is None:
                 return self.sub_comm.allgather(send_buf)
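
Aside (not part of the diff): the new NCCL branch above pads every rank's send buffer to a common shape before nccl_allgather and strips the padding afterwards, because NCCL's allGather expects equally sized buffers on all devices. A minimal NumPy-only sketch of that pad / gather / unroll idea, with two simulated ranks standing in for the collective (all names below are illustrative, none of it is library code):

import numpy as np

def pad_to(shape, arr):
    # pad each dimension up to the agreed common shape; zeros are stripped later
    pad = [(0, s - l) for s, l in zip(shape, arr.shape)]
    return np.pad(arr, pad, mode="constant", constant_values=0)

# two "ranks" holding an unevenly partitioned array
local = [np.arange(6).reshape(2, 3), np.arange(3).reshape(1, 3)]
shapes = [a.shape for a in local]                # what base_comm.allgather(send_buf.shape) provides
padded_shape = tuple(map(max, zip(*shapes)))     # per-dimension maximum -> (2, 3)

padded = [pad_to(padded_shape, a) for a in local]
gathered = np.concatenate([p.ravel() for p in padded])   # stands in for the flat NCCL recv buffer

chunk = int(np.prod(padded_shape))
chunks = [
    gathered[i * chunk:(i + 1) * chunk].reshape(padded_shape)[tuple(slice(0, e) for e in shapes[i])]
    for i in range(len(local))
]
assert all(np.array_equal(c, a) for c, a in zip(chunks, local))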

pylops_mpi/signalprocessing/Fredholm1.py

Lines changed: 10 additions & 4 deletions

@@ -111,7 +111,10 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
         if x.partition not in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST]:
             raise ValueError(f"x should have partition={Partition.BROADCAST},{Partition.UNSAFE_BROADCAST}"
                              f"Got {x.partition} instead...")
-        y = DistributedArray(global_shape=self.shape[0], partition=x.partition,
+        y = DistributedArray(global_shape=self.shape[0],
+                             base_comm=x.base_comm,
+                             base_comm_nccl=x.base_comm_nccl,
+                             partition=x.partition,
                              engine=x.engine, dtype=self.dtype)
         x = x.local_array.reshape(self.dims).squeeze()
         x = x[self.islstart[self.rank]:self.islend[self.rank]]

@@ -125,15 +128,18 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
         for isl in range(self.nsls[self.rank]):
             y1[isl] = ncp.dot(self.G[isl], x[isl])
         # gather results
-        y[:] = np.vstack(self.base_comm.allgather(y1)).ravel()
+        y[:] = ncp.vstack(y._allgather(y1)).ravel()
         return y

     def _rmatvec(self, x: NDArray) -> NDArray:
         ncp = get_module(x.engine)
         if x.partition not in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST]:
             raise ValueError(f"x should have partition={Partition.BROADCAST},{Partition.UNSAFE_BROADCAST}"
                              f"Got {x.partition} instead...")
-        y = DistributedArray(global_shape=self.shape[1], partition=x.partition,
+        y = DistributedArray(global_shape=self.shape[1],
+                             base_comm=x.base_comm,
+                             base_comm_nccl=x.base_comm_nccl,
+                             partition=x.partition,
                              engine=x.engine, dtype=self.dtype)
         x = x.local_array.reshape(self.dimsd).squeeze()
         x = x[self.islstart[self.rank]:self.islend[self.rank]]

@@ -159,5 +165,5 @@ def _rmatvec(self, x: NDArray) -> NDArray:
             y1[isl] = ncp.dot(x[isl].T.conj(), self.G[isl]).T.conj()

         # gather results
-        y[:] = np.vstack(self.base_comm.allgather(y1)).ravel()
+        y[:] = ncp.vstack(y._allgather(y1)).ravel()
         return y
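
Aside (illustrative sketch, not part of the diff): the switch from np.vstack(self.base_comm.allgather(y1)) to ncp.vstack(y._allgather(y1)) keeps the gather on the backend the data already lives on. ncp is the module returned by get_module for the array engine (NumPy on CPU, CuPy on GPU), and y._allgather dispatches to NCCL when the DistributedArray carries an NCCL communicator, so GPU data is never staged through the host just to be stacked. Condensed, the pattern is (the import path for get_module is assumed to be PyLops' pylops.utils.backend):

from pylops.utils.backend import get_module

def gather_rows(y, y1):
    # y: DistributedArray holding the output; y1: locally computed block of rows
    ncp = get_module(y.engine)                    # numpy for "numpy", cupy for "cupy"
    return ncp.vstack(y._allgather(y1)).ravel()   # list of per-rank blocks -> flat global vector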

pylops_mpi/utils/_nccl.py

Lines changed: 115 additions & 37 deletions

@@ -1,21 +1,25 @@
 __all__ = [
+    "_prepare_nccl_allgather_inputs",
+    "_unroll_nccl_allgather_recv",
     "initialize_nccl_comm",
     "nccl_split",
     "nccl_allgather",
     "nccl_allreduce",
     "nccl_bcast",
     "nccl_asarray",
     "nccl_send",
-    "nccl_recv"
+    "nccl_recv",
 ]

 from enum import IntEnum
+from typing import Tuple
 from mpi4py import MPI
 import os
 import numpy as np
 import cupy as cp
 import cupy.cuda.nccl as nccl

+
 cupy_to_nccl_dtype = {
     "float32": nccl.NCCL_FLOAT32,
     "float64": nccl.NCCL_FLOAT64,

@@ -25,6 +29,9 @@
     "int8": nccl.NCCL_INT8,
     "uint32": nccl.NCCL_UINT32,
     "uint64": nccl.NCCL_UINT64,
+    # sending complex array as float with 2x size
+    "complex64": nccl.NCCL_FLOAT32,
+    "complex128": nccl.NCCL_FLOAT64,
 }

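
Aside (not part of the diff): the two new map entries rely on complex64/complex128 arrays being laid out as interleaved float32/float64 pairs, so a complex buffer can be shipped over NCCL as a real buffer with twice the element count. A quick plain-CuPy check of that layout assumption (no NCCL communicator needed):

import cupy as cp

x = (cp.arange(4) + 1j * cp.arange(4)).astype(cp.complex64)
as_real = x.view(cp.float32)                # same memory, twice as many elements
assert as_real.size == 2 * x.size
assert cp.allclose(as_real.view(cp.complex64), x)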

@@ -35,6 +42,106 @@ class NcclOp(IntEnum):
     MIN = nccl.NCCL_MIN


+def _nccl_buf_size(buf, count=None):
+    """ Get an appropriate buffer size according to the dtype of buf
+
+    Parameters
+    ----------
+    buf : :obj:`cupy.ndarray` or array-like
+        The data buffer from the local GPU to be sent.
+    count : :obj:`int`, optional
+        Number of elements to send from `buf`, if not sending every element of `buf`.
+
+    Returns
+    -------
+    :obj:`int`
+        An appropriate number of elements to send from `send_buf` for NCCL communication.
+    """
+    if buf.dtype in ['complex64', 'complex128']:
+        return 2 * count if count else 2 * buf.size
+    else:
+        return count if count else buf.size
+
+
+def _prepare_nccl_allgather_inputs(send_buf, send_buf_shapes) -> Tuple[cp.ndarray, cp.ndarray]:
+    r""" Prepare send_buf and recv_buf for NCCL allgather (nccl_allgather)
+
+    NCCL's allGather requires the sending buffer to have the same size for every device.
+    Therefore, padding is required when the array is not evenly partitioned across
+    all the ranks. The padding is applied such that each dimension of the sending buffer
+    is equal to the max size of that dimension across all ranks.
+
+    Similarly, each receiver buffer (recv_buf) is created with size equal to
+    :math:`n_{rank} \cdot send\_buf.size`
+
+    Parameters
+    ----------
+    send_buf : :obj:`cupy.ndarray` or array-like
+        The data buffer from the local GPU to be sent for allgather.
+    send_buf_shapes : :obj:`list`
+        A list of shapes for each GPU send_buf (used to calculate padding size)
+
+    Returns
+    -------
+    send_buf : :obj:`cupy.ndarray`
+        A buffer containing the data and padded elements to be sent by this rank.
+    recv_buf : :obj:`cupy.ndarray`
+        An empty, padded buffer to gather data from all GPUs.
+    """
+    sizes_each_dim = list(zip(*send_buf_shapes))
+    send_shape = tuple(map(max, sizes_each_dim))
+    pad_size = [
+        (0, s_shape - l_shape) for s_shape, l_shape in zip(send_shape, send_buf.shape)
+    ]
+
+    send_buf = cp.pad(
+        send_buf, pad_size, mode="constant", constant_values=0
+    )
+
+    # NCCL recommends using one MPI process per GPU, so the size of the receiving buffer can be inferred
+    ndev = len(send_buf_shapes)
+    recv_buf = cp.zeros(ndev * send_buf.size, dtype=send_buf.dtype)
+
+    return send_buf, recv_buf
+
+
+def _unroll_nccl_allgather_recv(recv_buf, padded_send_buf_shape, send_buf_shapes) -> list:
+    """Unroll recv_buf after NCCL allgather (nccl_allgather)
+
+    Remove the padded elements in recv_buf, extract the individual array from each device,
+    and return them as a list of arrays. Each GPU may send an array with a different shape,
+    so the return type has to be a list of arrays instead of a single concatenated array.
+
+    Parameters
+    ----------
+    recv_buf : :obj:`cupy.ndarray` or array-like
+        The data buffer returned from the nccl_allgather call
+    padded_send_buf_shape : :obj:`tuple`
+        The shape of send_buf after the padding used in nccl_allgather
+    send_buf_shapes : :obj:`list`
+        A list of original shapes for each GPU send_buf prior to padding
+
+    Returns
+    -------
+    chunks : :obj:`list`
+        A list of :obj:`cupy.ndarray` from each GPU with the padded elements removed
+    """
+    ndev = len(send_buf_shapes)
+    # extract an individual array from each device
+    chunk_size = np.prod(padded_send_buf_shape)
+    chunks = [
+        recv_buf[i * chunk_size:(i + 1) * chunk_size] for i in range(ndev)
+    ]
+
+    # Remove padding from each array: the padded value may appear somewhere
+    # in the middle of the flat array and thus the reshape and slicing for each dimension is required
+    for i in range(ndev):
+        slicing = tuple(slice(0, end) for end in send_buf_shapes[i])
+        chunks[i] = chunks[i].reshape(padded_send_buf_shape)[slicing]

+    return chunks
+
+
 def mpi_op_to_nccl(mpi_op) -> NcclOp:
     """ Map MPI reduction operation to NCCL equivalent

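Aside (a sketch, not part of the diff): the two helpers are designed to bracket nccl_allgather exactly as the refactored nccl_asarray further down does. Assuming an initialized NCCL communicator nccl_comm, a local CuPy array local_array on each rank, and local_shapes holding the per-rank shapes, the intended call sequence is:

# hypothetical call sequence; nccl_comm, local_array and local_shapes must already exist on each rank
send_buf, recv_buf = _prepare_nccl_allgather_inputs(local_array, local_shapes)
nccl_allgather(nccl_comm, send_buf, recv_buf)        # writes the gathered, still-padded data into recv_buf
chunks = _unroll_nccl_allgather_recv(recv_buf, send_buf.shape, local_shapes)
# chunks[i] is rank i's original, unpadded local array
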
@@ -155,7 +262,7 @@ def nccl_allgather(nccl_comm, send_buf, recv_buf=None) -> cp.ndarray:
     nccl_comm.allGather(
         send_buf.data.ptr,
         recv_buf.data.ptr,
-        send_buf.size,
+        _nccl_buf_size(send_buf),
         cupy_to_nccl_dtype[str(send_buf.dtype)],
         cp.cuda.Stream.null.ptr,
     )

@@ -193,7 +300,7 @@ def nccl_allreduce(nccl_comm, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM) ->
     nccl_comm.allReduce(
         send_buf.data.ptr,
         recv_buf.data.ptr,
-        send_buf.size,
+        _nccl_buf_size(send_buf),
         cupy_to_nccl_dtype[str(send_buf.dtype)],
         mpi_op_to_nccl(op),
         cp.cuda.Stream.null.ptr,

@@ -220,7 +327,7 @@ def nccl_bcast(nccl_comm, local_array, index, value) -> None:
         local_array[index] = value
     nccl_comm.bcast(
         local_array[index].data.ptr,
-        local_array[index].size,
+        _nccl_buf_size(local_array[index]),
         cupy_to_nccl_dtype[str(local_array[index].dtype)],
         0,
         cp.cuda.Stream.null.ptr,
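
Aside (not part of the diff): passing _nccl_buf_size(send_buf) together with the dtype looked up in cupy_to_nccl_dtype keeps the (count, dtype) pair describing the same number of bytes as before for complex inputs, since the doubled count of the matching real dtype covers exactly the complex buffer. A small check of that byte-count equivalence:

import cupy as cp
import numpy as np

x = cp.ones(5, dtype=cp.complex128)
# 5 complex128 elements (16 bytes each) == 10 float64 elements (8 bytes each)
assert 2 * x.size * np.dtype("float64").itemsize == x.nbytes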
@@ -247,41 +354,12 @@ def nccl_asarray(nccl_comm, local_array, local_shapes, axis) -> cp.ndarray:
     -------
     final_array : :obj:`cupy.ndarray`
         Global array gathered from all GPUs and concatenated along `axis`.
-
-    Notes
-    -----
-    NCCL's allGather requires the sending buffer to have the same size for every device.
-    Therefore, the padding is required when the array is not evenly partitioned across
-    all the ranks. The padding is applied such that the sending buffer has the size of
-    each dimension corresponding to the max possible size of that dimension.
     """
-    sizes_each_dim = list(zip(*local_shapes))
-
-    send_shape = tuple(map(max, sizes_each_dim))
-    pad_size = [
-        (0, s_shape - l_shape) for s_shape, l_shape in zip(send_shape, local_array.shape)
-    ]

-    send_buf = cp.pad(
-        local_array, pad_size, mode="constant", constant_values=0
-    )
-
-    # NCCL recommends to use one MPI Process per GPU and so size of receiving buffer can be inferred
-    ndev = len(local_shapes)
-    recv_buf = cp.zeros(ndev * send_buf.size, dtype=send_buf.dtype)
+    send_buf, recv_buf = _prepare_nccl_allgather_inputs(local_array, local_shapes)
     nccl_allgather(nccl_comm, send_buf, recv_buf)
+    chunks = _unroll_nccl_allgather_recv(recv_buf, send_buf.shape, local_shapes)

-    # extract an individual array from each device
-    chunk_size = np.prod(send_shape)
-    chunks = [
-        recv_buf[i * chunk_size:(i + 1) * chunk_size] for i in range(ndev)
-    ]
-
-    # Remove padding from each array: the padded value may appear somewhere
-    # in the middle of the flat array and thus the reshape and slicing for each dimension is required
-    for i in range(ndev):
-        slicing = tuple(slice(0, end) for end in local_shapes[i])
-        chunks[i] = chunks[i].reshape(send_shape)[slicing]
     # combine back to single global array
     return cp.concatenate(chunks, axis=axis)

@@ -302,7 +380,7 @@ def nccl_send(nccl_comm, send_buf, dest, count):
         Number of elements to send from `send_buf`.
     """
     nccl_comm.send(send_buf.data.ptr,
-                   count,
+                   _nccl_buf_size(send_buf, count),
                    cupy_to_nccl_dtype[str(send_buf.dtype)],
                    dest,
                    cp.cuda.Stream.null.ptr

@@ -325,7 +403,7 @@ def nccl_recv(nccl_comm, recv_buf, source, count=None):
         Number of elements to receive.
     """
     nccl_comm.recv(recv_buf.data.ptr,
-                   count,
+                   _nccl_buf_size(recv_buf, count),
                    cupy_to_nccl_dtype[str(recv_buf.dtype)],
                    source,
                    cp.cuda.Stream.null.ptr

tests_nccl/test_blockdiag_nccl.py

Lines changed: 4 additions & 4 deletions

@@ -18,15 +18,15 @@
 nccl_comm = initialize_nccl_comm()

 par1 = {'ny': 101, 'nx': 101, 'dtype': np.float64}
-# par1j = {'ny': 101, 'nx': 101, 'dtype': np.complex128}
+par1j = {'ny': 101, 'nx': 101, 'dtype': np.complex128}
 par2 = {'ny': 301, 'nx': 101, 'dtype': np.float64}
-# par2j = {'ny': 301, 'nx': 101, 'dtype': np.complex128}
+par2j = {'ny': 301, 'nx': 101, 'dtype': np.complex128}

 np.random.seed(42)


 @pytest.mark.mpi(min_size=2)
-@pytest.mark.parametrize("par", [(par1), (par2)])
+@pytest.mark.parametrize("par", [(par1), (par1j), (par2), (par2j)])
 def test_blockdiag_nccl(par):
     """Test the MPIBlockDiag with NCCL"""
     size = MPI.COMM_WORLD.Get_size()

@@ -71,7 +71,7 @@ def test_blockdiag_nccl(par):


 @pytest.mark.mpi(min_size=2)
-@pytest.mark.parametrize("par", [(par1), (par2)])
+@pytest.mark.parametrize("par", [(par1), (par1j), (par2), (par2j)])
 def test_stacked_blockdiag_nccl(par):
     """Tests for MPIStackedBlockDiag with NCCL"""
     size = MPI.COMM_WORLD.Get_size()
