Skip to content

Commit ae7190c

Browse files
committed
small fixes based on some of the PR comments
1 parent 58f1305 commit ae7190c

File tree

2 files changed

+12
-22
lines changed

2 files changed

+12
-22
lines changed

pylops_mpi/DistributedArray.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -495,7 +495,7 @@ def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
495495
def _allgather(self, send_buf, recv_buf=None):
496496
"""Allgather operation
497497
"""
498-
if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
498+
if deps.nccl_enabled and self.base_comm_nccl:
499499
return nccl_allgather(self.base_comm_nccl, send_buf, recv_buf)
500500
else:
501501
if recv_buf is None:
@@ -518,13 +518,16 @@ def _recv(self, recv_buf=None, source=0, count=None, tag=None):
518518
""" Receive operation
519519
"""
520520
# NCCL must be called with recv_buf. Size cannot be inferred from
521-
# other arguments and thus cannot dynamically allocated
521+
# other arguments and thus cannot be dynamically allocated
522522
if deps.nccl_enabled and getattr(self, "base_comm_nccl") and recv_buf is not None:
523-
if count is None:
524-
# assuming data will take a space of the whole buffer
525-
count = recv_buf.size
526-
nccl_recv(self.base_comm_nccl, recv_buf, source, count)
527-
return recv_buf
523+
if recv_buf is not None:
524+
if count is None:
525+
# assuming data will take a space of the whole buffer
526+
count = recv_buf.size
527+
nccl_recv(self.base_comm_nccl, recv_buf, source, count)
528+
return recv_buf
529+
else:
530+
raise ValueError("Using recv with NCCL must also supply receiver buffer ")
528531
else:
529532
# MPI allows a receiver buffer to be optional
530533
if recv_buf is None:
@@ -773,10 +776,9 @@ def add_ghost_cells(self, cells_front: Optional[int] = None,
773776
-------
774777
ghosted_array : :obj:`numpy.ndarray`
775778
Ghosted Array
776-
777779
"""
778780
ghosted_array = self.local_array.copy()
779-
ncp = get_module(getattr(self, "engine"))
781+
ncp = get_module(self.engine)
780782
if cells_front is not None:
781783
# cells_front is small array of int. Explicitly use MPI
782784
total_cells_front = self.base_comm.allgather(cells_front) + [0]
@@ -790,7 +792,7 @@ def add_ghost_cells(self, cells_front: Optional[int] = None,
790792
recv_shape = list(recv_shapes[self.rank - 1])
791793
recv_shape[self.axis] = total_cells_front[self.rank]
792794
recv_buf = ncp.zeros(recv_shape, dtype=ghosted_array.dtype)
793-
# Some communication can skip if len(recv_buf) = 0
795+
# Transfer of ghost cells can be skipped if len(recv_buf) = 0
794796
# Additionally, NCCL will hang if the buffer size is 0 so this optimization is somewhat mandatory
795797
if len(recv_buf) != 0:
796798
ghosted_array = ncp.concatenate([self._recv(recv_buf, source=self.rank - 1, tag=1), ghosted_array], axis=self.axis)

pylops_mpi/utils/_nccl.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -215,10 +215,6 @@ def nccl_bcast(nccl_comm, local_array, index, value) -> None:
215215
The index in the array to be broadcasted.
216216
value : :obj:`scalar`
217217
The value to broadcast (only used by the root GPU, rank 0).
218-
219-
Returns
220-
-------
221-
None
222218
"""
223219
if nccl_comm.rank_id() == 0:
224220
local_array[index] = value
@@ -304,10 +300,6 @@ def nccl_send(nccl_comm, send_buf, dest, count):
304300
The rank of the destination GPU device.
305301
count : :obj:`int`
306302
Number of elements to send from `send_buf`.
307-
308-
Returns
309-
-------
310-
None
311303
"""
312304
nccl_comm.send(send_buf.data.ptr,
313305
count,
@@ -331,10 +323,6 @@ def nccl_recv(nccl_comm, recv_buf, source, count=None):
331323
The rank of the source GPU device.
332324
count : :obj:`int`, optional
333325
Number of elements to receive.
334-
335-
Returns
336-
-------
337-
None
338326
"""
339327
nccl_comm.recv(recv_buf.data.ptr,
340328
count,

0 commit comments

Comments
 (0)