Skip to content

Commit a700c63

Browse files
authored
Merge pull request #145 from tharittk/bug-fix-cupy-mpi
fix send/recv bug in CuPy + MPI
2 parents 750d41d + a56456d commit a700c63

File tree

1 file changed

+3
-6
lines changed

1 file changed

+3
-6
lines changed

pylops_mpi/DistributedArray.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,7 @@ def _send(self, send_buf, dest, count=None, tag=None):
526526
count = send_buf.size
527527
nccl_send(self.base_comm_nccl, send_buf, dest, count)
528528
else:
529-
self.base_comm.Send(send_buf, dest, tag)
529+
self.base_comm.send(send_buf, dest, tag)
530530

531531
def _recv(self, recv_buf=None, source=0, count=None, tag=None):
532532
""" Receive operation
@@ -543,11 +543,8 @@ def _recv(self, recv_buf=None, source=0, count=None, tag=None):
543543
else:
544544
raise ValueError("Using recv with NCCL must also supply receiver buffer ")
545545
else:
546-
# MPI allows a receiver buffer to be optional
547-
if recv_buf is None:
548-
return self.base_comm.recv(source=source, tag=tag)
549-
self.base_comm.Recv(buf=recv_buf, source=source, tag=tag)
550-
return recv_buf
546+
# MPI allows a receiver buffer to be optional and receives as a Python Object
547+
return self.base_comm.recv(source=source, tag=tag)
551548

552549
def _nccl_local_shapes(self, masked: bool):
553550
"""Get the the list of shapes of every GPU in the communicator

0 commit comments

Comments
 (0)