Skip to content

Commit 83f7a8b

Browse files
committed
minor: fix flake8
1 parent e0fd716 commit 83f7a8b

File tree

3 files changed

+41
-10
lines changed

3 files changed

+41
-10
lines changed

pylops_mpi/Distributed.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -265,17 +265,45 @@ def _send(self,
265265
send_buf, dest, count, tag=tag,
266266
engine=engine)
267267

268-
def _recv(self, recv_buf=None, source=0, count=None, tag=0):
268+
def _recv(self,
269+
base_comm: MPI.Comm,
270+
base_comm_nccl: NcclCommunicatorType,
271+
recv_buf=None, source=0, count=None, tag=0,
272+
engine: str = "numpy",
273+
) -> NDArray:
269274
"""Receive operation
275+
276+
Parameters
277+
----------
278+
base_comm : :obj:`MPI.Comm`
279+
Base MPI Communicator.
280+
base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`
281+
NCCL Communicator.
282+
recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`, optional
283+
The buffered array to receive data.
284+
source : :obj:`int`
285+
The rank of the sending CPU/GPU device.
286+
count : :obj:`int`
287+
Number of elements to receive.
288+
tag : :obj:`int`
289+
Tag of the message to be sent.
290+
engine : :obj:`str`, optional
291+
Engine used to store array (``numpy`` or ``cupy``)
292+
293+
Returns
294+
-------
295+
recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
296+
The buffer containing the received data.
297+
270298
"""
271-
if deps.nccl_enabled and self.base_comm_nccl:
299+
if deps.nccl_enabled and base_comm_nccl is not None:
272300
if recv_buf is None:
273301
raise ValueError("recv_buf must be supplied when using NCCL")
274302
if count is None:
275303
count = recv_buf.size
276-
nccl_recv(self.base_comm_nccl, recv_buf, source, count)
304+
nccl_recv(base_comm_nccl, recv_buf, source, count)
277305
return recv_buf
278306
else:
279-
return mpi_recv(self.base_comm,
307+
return mpi_recv(base_comm,
280308
recv_buf, source, count, tag=tag,
281-
engine=self.engine)
309+
engine=engine)

pylops_mpi/DistributedArray.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -797,7 +797,9 @@ def add_ghost_cells(self, cells_front: Optional[int] = None,
797797
# Transfer of ghost cells can be skipped if len(recv_buf) = 0
798798
# Additionally, NCCL will hang if the buffer size is 0 so this optimization is somewhat mandatory
799799
if len(recv_buf) != 0:
800-
ghosted_array = ncp.concatenate([self._recv(recv_buf, source=self.rank - 1, tag=1), ghosted_array], axis=self.axis)
800+
ghosted_array = ncp.concatenate([self._recv(self.base_comm, self.base_comm_nccl,
801+
recv_buf, source=self.rank - 1, tag=1,
802+
engine=self.engine), ghosted_array], axis=self.axis)
801803
# The skip in sender is to match with what described in receiver
802804
if self.rank != self.size - 1 and len(send_buf) != 0:
803805
if cells_front > self.local_shape[self.axis]:
@@ -831,7 +833,9 @@ def add_ghost_cells(self, cells_front: Optional[int] = None,
831833
recv_shape[self.axis] = total_cells_back[self.rank]
832834
recv_buf = ncp.zeros(recv_shape, dtype=ghosted_array.dtype)
833835
if len(recv_buf) != 0:
834-
ghosted_array = ncp.append(ghosted_array, self._recv(recv_buf, source=self.rank + 1, tag=0),
836+
ghosted_array = ncp.append(ghosted_array, self._recv(self.base_comm, self.base_comm_nccl,
837+
recv_buf, source=self.rank + 1, tag=0,
838+
engine=self.engine),
835839
axis=self.axis)
836840
return ghosted_array
837841

pylops_mpi/utils/_mpi.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
from typing import Optional, Union
1010

11-
import numpy as np
1211
from mpi4py import MPI
1312
from pylops.utils import NDArray
1413
from pylops.utils.backend import get_module
@@ -37,12 +36,12 @@ def mpi_allgather(base_comm: MPI.Comm,
3736
a new buffer will be allocated with the appropriate shape.
3837
engine : :obj:`str`, optional
3938
Engine used to store array (``numpy`` or ``cupy``)
40-
39+
4140
Returns
4241
-------
4342
recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
4443
A buffer containing the gathered data from all ranks.
45-
44+
4645
"""
4746
if deps.cuda_aware_mpi_enabled or engine == "numpy":
4847
send_shapes = base_comm.allgather(send_buf.shape)

0 commit comments

Comments
 (0)