diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py
index a0e4eab7..8125f483 100644
--- a/pylops_mpi/DistributedArray.py
+++ b/pylops_mpi/DistributedArray.py
@@ -526,7 +526,7 @@ def _send(self, send_buf, dest, count=None, tag=None):
                 count = send_buf.size
             nccl_send(self.base_comm_nccl, send_buf, dest, count)
         else:
-            self.base_comm.Send(send_buf, dest, tag)
+            self.base_comm.send(send_buf, dest, tag)
 
     def _recv(self, recv_buf=None, source=0, count=None, tag=None):
         """ Receive operation
@@ -543,11 +543,8 @@ def _recv(self, recv_buf=None, source=0, count=None, tag=None):
             else:
                 raise ValueError("Using recv with NCCL must also supply receiver buffer ")
         else:
-            # MPI allows a receiver buffer to be optional
-            if recv_buf is None:
-                return self.base_comm.recv(source=source, tag=tag)
-            self.base_comm.Recv(buf=recv_buf, source=source, tag=tag)
-            return recv_buf
+            # MPI allows a receiver buffer to be optional and receives as a Python Object
+            return self.base_comm.recv(source=source, tag=tag)
 
     def _nccl_local_shapes(self, masked: bool):
         """Get the the list of shapes of every GPU in the communicator
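
For context, a minimal standalone sketch of the mpi4py semantics this change relies on: lowercase comm.send/comm.recv pickle arbitrary Python objects, so the receiver needs no pre-allocated buffer, whereas uppercase Comm.Send/Comm.Recv operate on buffer-like objects and require one on both sides. This is an illustrative example, not part of the diff; the script name and run command are assumptions.

# sketch.py (hypothetical name) -- run with e.g.: mpiexec -n 2 python sketch.py
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

if rank == 0:
    # Lowercase send: the array is pickled and shipped as a Python object.
    comm.send(np.arange(4, dtype=np.float64), dest=1, tag=0)
elif rank == 1:
    # Lowercase recv: no receiver buffer needed; the object is rebuilt
    # from the pickle stream, mirroring the simplified _recv path above.
    data = comm.recv(source=0, tag=0)
    print(rank, data)

Trading the buffer-based Recv for the pickle-based recv costs a serialization round-trip but lets _recv drop the optional-buffer branching entirely, since the NCCL path is the only one that still requires a caller-supplied buffer.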