8 changes: 4 additions & 4 deletions docs/source/gpu.rst
@@ -153,8 +153,8 @@ In the following, we provide a list of modules (i.e., operators and solvers) whe
    * - :class:`pylops_mpi.optimization.basic.cgls`
      - ✅
    * - :class:`pylops_mpi.signalprocessing.Fredholm1`
-     - Planned ⏳
-   * - ISTA Solver
-     - Planned ⏳
+     - ✅
    * - Complex Numeric Data Type for NCCL
-     - Planned ⏳
+     - ✅
+   * - ISTA Solver
+     - Planned ⏳
24 changes: 20 additions & 4 deletions pylops_mpi/signalprocessing/Fredholm1.py
@@ -111,7 +111,10 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
         if x.partition not in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST]:
             raise ValueError(f"x should have partition={Partition.BROADCAST},{Partition.UNSAFE_BROADCAST}"
                              f"Got {x.partition} instead...")
-        y = DistributedArray(global_shape=self.shape[0], partition=x.partition,
+        y = DistributedArray(global_shape=self.shape[0],
+                             base_comm=x.base_comm,
+                             base_comm_nccl=x.base_comm_nccl,
+                             partition=x.partition,
                              engine=x.engine, dtype=self.dtype)
         x = x.local_array.reshape(self.dims).squeeze()
         x = x[self.islstart[self.rank]:self.islend[self.rank]]
@@ -125,15 +128,18 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
             for isl in range(self.nsls[self.rank]):
                 y1[isl] = ncp.dot(self.G[isl], x[isl])
         # gather results
-        y[:] = np.vstack(self.base_comm.allgather(y1)).ravel()
+        y[:] = ncp.vstack(y._allgather(y1)).ravel()
         return y

     def _rmatvec(self, x: NDArray) -> NDArray:
         ncp = get_module(x.engine)
         if x.partition not in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST]:
             raise ValueError(f"x should have partition={Partition.BROADCAST},{Partition.UNSAFE_BROADCAST}"
                              f"Got {x.partition} instead...")
-        y = DistributedArray(global_shape=self.shape[1], partition=x.partition,
+        y = DistributedArray(global_shape=self.shape[1],
+                             base_comm=x.base_comm,
+                             base_comm_nccl=x.base_comm_nccl,
+                             partition=x.partition,
                              engine=x.engine, dtype=self.dtype)
         x = x.local_array.reshape(self.dimsd).squeeze()
         x = x[self.islstart[self.rank]:self.islend[self.rank]]
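Note on the two hunks above: the output DistributedArray `y` now receives the same `base_comm` and `base_comm_nccl` as the input `x`, and the gather switches from `self.base_comm.allgather(...)` to `y._allgather(...)`, so the collective can be routed through NCCL whenever a NCCL communicator is attached. The sketch below only illustrates the dispatch this relies on, under the assumption (consistent with the TODO note further down in this diff) that `_allgather` uses NCCL when available and otherwise falls back to the non-buffered MPI allgather; `allgather_dispatch` is a made-up name, not library API.

# Illustrative sketch only, not the library's actual implementation.
# Assumption: `base_comm` is an mpi4py communicator, `base_comm_nccl` a cupy NCCL
# communicator or None; `nccl_allgather` exists in pylops_mpi.utils._nccl (next file).
def allgather_dispatch(base_comm, base_comm_nccl, local_vals):
    if base_comm_nccl is not None:
        from pylops_mpi.utils._nccl import nccl_allgather
        # NCCL path: returns a single flattened cupy.ndarray
        return nccl_allgather(base_comm_nccl, local_vals)
    # MPI path: non-buffered allgather, returns a list with one entry per rank
    return base_comm.allgather(local_vals)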
@@ -159,5 +165,15 @@ def _rmatvec(self, x: NDArray) -> NDArray:
                 y1[isl] = ncp.dot(x[isl].T.conj(), self.G[isl]).T.conj()

         # gather results
-        y[:] = np.vstack(self.base_comm.allgather(y1)).ravel()
+        recv = y._allgather(y1)
+        # TODO: the current implementation of _allgather calls non-buffered MPI Allgather (sub-optimal
+        # for CuPy+MPI), which returns a list (not a flattened array) and therefore needs no unrolling
+        if self.usematmul and isinstance(recv, ncp.ndarray):
+            # unrolling
+            chunk_size = self.ny * self.nz
+            num_partition = (len(recv) + chunk_size - 1) // chunk_size
+            recv = ncp.vstack([recv[i * chunk_size: (i + 1) * chunk_size].reshape(self.nz, self.ny).T for i in range(num_partition)])
+        else:
+            recv = ncp.vstack(recv)
+        y[:] = recv.ravel()
         return y
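A note on the unrolling branch above: when `_allgather` goes through NCCL it returns one flattened CuPy array rather than a per-rank list, so the code cuts that buffer into `ny * nz`-element chunks, reinterprets each as an `(nz, ny)` block, and transposes it back to `(ny, nz)` before stacking. A minimal NumPy sketch of the same arithmetic follows (toy sizes assumed for illustration; the real code operates on CuPy arrays).

import numpy as np

ny, nz, nblocks = 3, 2, 2                              # toy sizes, illustration only
blocks = [np.arange(ny * nz).reshape(nz, ny) + 100 * b for b in range(nblocks)]

# a buffered (NCCL-style) allgather hands back one flat array: the blocks concatenated
flat = np.concatenate([b.ravel() for b in blocks])

# unrolling, as in _rmatvec above: cut the buffer into ny*nz chunks,
# view each as (nz, ny) and transpose it back to (ny, nz) before stacking
chunk = ny * nz
nparts = (flat.size + chunk - 1) // chunk
rebuilt = np.vstack([flat[i * chunk:(i + 1) * chunk].reshape(nz, ny).T
                     for i in range(nparts)])
assert rebuilt.shape == (nblocks * ny, nz)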
34 changes: 29 additions & 5 deletions pylops_mpi/utils/_nccl.py
@@ -25,6 +25,9 @@
     "int8": nccl.NCCL_INT8,
     "uint32": nccl.NCCL_UINT32,
     "uint64": nccl.NCCL_UINT64,
+    # sending complex array as float with 2x size
+    "complex64": nccl.NCCL_FLOAT32,
+    "complex128": nccl.NCCL_FLOAT64,
 }

@@ -35,6 +38,27 @@ class NcclOp(IntEnum):
     MIN = nccl.NCCL_MIN


+def _nccl_buf_size(buf, count=None):
+    """ Get an appropriate buffer size (in number of elements) according to the dtype of buf
+
+    Parameters
+    ----------
+    buf : :obj:`cupy.ndarray` or array-like
+        The data buffer from the local GPU to be sent.
+    count : :obj:`int`, optional
+        Number of elements to send from `buf`, if not sending every element of `buf`.
+
+    Returns
+    -------
+    :obj:`int`
+        An appropriate number of elements to send from `buf` for NCCL communication.
+    """
+    if buf.dtype in ['complex64', 'complex128']:
+        return 2 * count if count else 2 * buf.size
+    else:
+        return count if count else buf.size
+
+
 def mpi_op_to_nccl(mpi_op) -> NcclOp:
     """ Map MPI reduction operation to NCCL equivalent

@@ -155,7 +179,7 @@ def nccl_allgather(nccl_comm, send_buf, recv_buf=None) -> cp.ndarray:
     nccl_comm.allGather(
         send_buf.data.ptr,
         recv_buf.data.ptr,
-        send_buf.size,
+        _nccl_buf_size(send_buf),
         cupy_to_nccl_dtype[str(send_buf.dtype)],
         cp.cuda.Stream.null.ptr,
     )
@@ -193,7 +217,7 @@ def nccl_allreduce(nccl_comm, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM) ->
     nccl_comm.allReduce(
         send_buf.data.ptr,
         recv_buf.data.ptr,
-        send_buf.size,
+        _nccl_buf_size(send_buf),
         cupy_to_nccl_dtype[str(send_buf.dtype)],
         mpi_op_to_nccl(op),
         cp.cuda.Stream.null.ptr,
@@ -220,7 +244,7 @@ def nccl_bcast(nccl_comm, local_array, index, value) -> None:
         local_array[index] = value
     nccl_comm.bcast(
         local_array[index].data.ptr,
-        local_array[index].size,
+        _nccl_buf_size(local_array[index]),
         cupy_to_nccl_dtype[str(local_array[index].dtype)],
         0,
         cp.cuda.Stream.null.ptr,
@@ -302,7 +326,7 @@ def nccl_send(nccl_comm, send_buf, dest, count):
         Number of elements to send from `send_buf`.
     """
     nccl_comm.send(send_buf.data.ptr,
-                   count,
+                   _nccl_buf_size(send_buf, count),
                    cupy_to_nccl_dtype[str(send_buf.dtype)],
                    dest,
                    cp.cuda.Stream.null.ptr
@@ -325,7 +349,7 @@ def nccl_recv(nccl_comm, recv_buf, source, count=None):
         Number of elements to receive.
     """
     nccl_comm.recv(recv_buf.data.ptr,
-                   count,
+                   _nccl_buf_size(recv_buf, count),
                    cupy_to_nccl_dtype[str(recv_buf.dtype)],
                    source,
                    cp.cuda.Stream.null.ptr
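The `complex64`/`complex128` entries in `cupy_to_nccl_dtype` and the doubling in `_nccl_buf_size` above rely on the fact that a complex array is bit-identical to a real array of twice the length holding interleaved (real, imag) pairs, so NCCL can move it as plain floats. A minimal NumPy sketch of that equivalence (the same holds for the CuPy buffers NCCL actually transfers):

import numpy as np

a = np.array([1 + 2j, 3 - 4j], dtype=np.complex64)
f = a.view(np.float32)              # zero-copy view: interleaved (real, imag) pairs

print(f)                            # [ 1.  2.  3. -4.]
assert f.size == 2 * a.size         # element count doubles, as in _nccl_buf_size
assert a.tobytes() == f.tobytes()   # underlying bytes are identical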
8 changes: 4 additions & 4 deletions tests_nccl/test_blockdiag_nccl.py
@@ -18,15 +18,15 @@
 nccl_comm = initialize_nccl_comm()

 par1 = {'ny': 101, 'nx': 101, 'dtype': np.float64}
-# par1j = {'ny': 101, 'nx': 101, 'dtype': np.complex128}
+par1j = {'ny': 101, 'nx': 101, 'dtype': np.complex128}
 par2 = {'ny': 301, 'nx': 101, 'dtype': np.float64}
-# par2j = {'ny': 301, 'nx': 101, 'dtype': np.complex128}
+par2j = {'ny': 301, 'nx': 101, 'dtype': np.complex128}

 np.random.seed(42)


 @pytest.mark.mpi(min_size=2)
-@pytest.mark.parametrize("par", [(par1), (par2)])
+@pytest.mark.parametrize("par", [(par1), (par1j), (par2), (par2j)])
 def test_blockdiag_nccl(par):
     """Test the MPIBlockDiag with NCCL"""
     size = MPI.COMM_WORLD.Get_size()
@@ -71,7 +71,7 @@ def test_blockdiag_nccl(par):


 @pytest.mark.mpi(min_size=2)
-@pytest.mark.parametrize("par", [(par1), (par2)])
+@pytest.mark.parametrize("par", [(par1), (par1j), (par2), (par2j)])
 def test_stacked_blockdiag_nccl(par):
     """Tests for MPIStackedBlockDiag with NCCL"""
     size = MPI.COMM_WORLD.Get_size()
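The test bodies themselves are not shown in this diff; the previously commented-out complex parameter sets (`par1j`, `par2j`) are simply re-enabled in the parametrize lists. The sketch below only illustrates how such a parametrized complex dtype is typically turned into test data; `make_local_data` is a hypothetical helper, not part of the test suite.

import numpy as np

def make_local_data(par, seed=42):
    # hypothetical helper: build real or complex test data from a `par` dictionary
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(par['nx'])
    if np.issubdtype(par['dtype'], np.complexfloating):
        x = x + 1j * rng.standard_normal(par['nx'])
    return x.astype(par['dtype'])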